remove prefix

This commit is contained in:
2025-03-25 22:38:30 -03:00
parent cd6fdd58ad
commit 02f0b317ae
4 changed files with 186 additions and 43 deletions

View File

@@ -1,8 +1,12 @@
import json
import os
from abc import ABC, abstractmethod
from base64 import urlsafe_b64decode, urlsafe_b64encode
from dataclasses import dataclass
from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Type, TypedDict
from typing import TYPE_CHECKING, Any, Type, TypedDict
from urllib.parse import quote, unquote
from uuid import UUID
from aws_lambda_powertools import Logger
@@ -49,34 +53,70 @@ def deserialize(value: dict) -> dict:
return {k: deserializer.deserialize(v) for k, v in value.items()}
def ComposeKey(
keyparts: str | tuple[str, ...],
*,
prefix: str | None = None,
delimiter: str = DELIMITER,
) -> str:
"""Creates a composite key by joining string parts with a specified delimiter.
If a prefix is provided, it is added at the beginning of the key parts.
if TYPE_CHECKING:
Example
-------
>>> ComposeKey(('abc', 'xyz'), prefix='examples', delimiter='#')
'examples#abc#xyz'
"""
@dataclass
class ComposeKey(str):
"""Creates a composite key by joining string parts with a specified delimiter.
If a prefix is provided, it is added at the beginning of the key parts.
if not prefix and not isinstance(keyparts, tuple):
return keyparts
Example
-------
>>> ComposeKey(('abc', 'xyz'), prefix='examples', delimiter='#')
'examples#abc#xyz'
"""
if isinstance(keyparts, str):
keyparts = (keyparts,)
keyparts: str | tuple[str, ...]
prefix: str | None = None
delimiter: str = '#'
else:
if prefix:
keyparts = (prefix,) + keyparts
class ComposeKey(str):
def __new__(
cls,
keyparts: str | tuple[str, ...],
*,
prefix: str | None = None,
delimiter: str = '#',
) -> str:
if isinstance(keyparts, str):
keyparts = (keyparts,)
return delimiter.join(keyparts)
if prefix:
keyparts = (prefix,) + keyparts
return super().__new__(cls, delimiter.join(keyparts))
def __init__(
self,
keyparts: str | tuple[str, ...],
*,
prefix: str | None = None,
delimiter: str = '#',
) -> None:
# __init__ is used to store the parameters for later reference.
# For immutable types like str, __init__ cannot change the instance's value.
self.keyparts = keyparts
self.prefix = prefix
self.delimiter = delimiter
class PrimaryKey(ABC, dict):
if TYPE_CHECKING:
@dataclass
class PrefixKey(str):
prefix: str
else:
class PrefixKey(str):
def __init__(self, prefix: str | None = None) -> None:
# __init__ is used to store the parameters for later reference.
# For immutable types like str, __init__ cannot change the instance's value.
self.prefix = prefix
class Key(ABC, dict):
@abstractmethod
def expr_attr_name(self) -> dict: ...
@@ -84,11 +124,11 @@ class PrimaryKey(ABC, dict):
def expr_attr_values(self) -> dict: ...
class PartitionKey(PrimaryKey):
class PartitionKey(Key):
"""Represents a partition key for DynamoDB queries"""
def __init__(self, pk: str) -> None:
super().__init__(**{PK: pk})
super().__init__(**{PK: pk, SK: None})
def expr_attr_name(self) -> dict:
return {'#pk': PK}
@@ -97,7 +137,7 @@ class PartitionKey(PrimaryKey):
return {':pk': self[PK]}
class KeyPair(PrimaryKey):
class KeyPair(Key):
"""Represents a composite key (partition key and sort key) for DynamoDB queries"""
def __init__(self, pk: str, sk: str) -> None:
@@ -341,7 +381,7 @@ class DynamoDBPersistenceLayer:
attrs['Limit'] = limit
try:
response = self.dynamodb_client.query(
res = self.dynamodb_client.query(
TableName=self.table_name,
KeyConditionExpression=key_cond_expr,
**attrs,
@@ -351,17 +391,17 @@ class DynamoDBPersistenceLayer:
raise
else:
return dict(
items=[deserialize(v) for v in response.get('Items', [])],
last_key=response.get('LastEvaluatedKey', None),
items=[deserialize(v) for v in res.get('Items', [])],
last_key=res.get('LastEvaluatedKey', None),
)
def get_item(self, key: dict) -> dict:
"""The GetItem operation returns a set of attributes for the item with the given primary key.
If there is no matching item, GetItem does not return any data and there will be no Item element in the response.
If there is no matching item, GetItem does not return any data and there will be no Item element in the res.
"""
try:
response = self.dynamodb_client.get_item(
res = self.dynamodb_client.get_item(
TableName=self.table_name,
Key=serialize(key),
)
@@ -369,7 +409,7 @@ class DynamoDBPersistenceLayer:
logger.exception(err)
raise
else:
return deserialize(response.get('Item', {}))
return deserialize(res.get('Item', {}))
def put_item(self, item: dict, *, cond_expr: str | None = None) -> bool:
attrs = {}
@@ -456,17 +496,14 @@ class DynamoDBPersistenceLayer:
def transact_get_items(self, transact_items: TransactItems) -> list[dict]:
try:
response = self.dynamodb_client.transact_get_items(
res = self.dynamodb_client.transact_get_items(
TransactItems=transact_items.items
)
except ClientError as err:
logger.exception(err)
raise
else:
return [
deserialize(response.get('Item', {}))
for response in response.get('Responses', [])
]
return [deserialize(res.get('Item', {})) for res in res.get('ress', [])]
def transact_write_items(self, transact_items: TransactItems) -> bool:
try:
@@ -516,7 +553,7 @@ class DynamoDBCollection:
collect = DynamoDBCollection(...)
collect.get_items(
KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', 'emails),
KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', 'emails'),
)
"""
@@ -619,17 +656,37 @@ class DynamoDBCollection:
expr_attr_name.update(key.expr_attr_name())
expr_attr_values.update(key.expr_attr_values())
response = self.persistence_layer.query(
res = self.persistence_layer.query(
key_cond_expr=key_cond_expr,
expr_attr_name=expr_attr_name,
expr_attr_values=expr_attr_values,
filter_expr=filter_expr,
index_forward=index_forward,
limit=limit,
# start_key=start_key if start_key else {},
start_key=_startkey_b64decode(start_key) if start_key else {},
)
items = res['items']
last_key = _startkey_b64encode(res['last_key']) if res['last_key'] else None
# Remove prefix from Sort Key if `key[SK]` is a PrefixKey
if isinstance(key[SK], PrefixKey):
prefix = key[SK].prefix
items = [item | {SK: item[SK].removeprefix(prefix)} for item in items]
return {
'items': response['items'],
'last_key': response['last_key'] if 'last_key' in response else None,
'items': items,
'last_key': last_key,
}
def _startkey_b64encode(obj: dict) -> str:
s = json.dumps(obj)
b = urlsafe_b64encode(s.encode('utf-8')).decode('utf-8')
return quote(b)
def _startkey_b64decode(s: str) -> dict:
b = unquote(s).encode('utf-8')
s = urlsafe_b64decode(b).decode('utf-8')
return json.loads(s)