713 lines
20 KiB
Python
713 lines
20 KiB
Python
import json
import os
from abc import ABC, abstractmethod
from base64 import urlsafe_b64decode, urlsafe_b64encode
from dataclasses import dataclass
from datetime import date, datetime
from ipaddress import IPv4Address, IPv6Address
from typing import TYPE_CHECKING, Any, Type, TypedDict
from urllib.parse import quote, unquote
from uuid import UUID

from aws_lambda_powertools import Logger
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError

from .dateutils import now, timestamp
|
|
|
|
# Configuration read from the environment once, at import time.
TZ = os.getenv('TZ', 'UTC')

# Attribute names of the table's partition and sort keys.
PK, SK = (
    os.getenv('DYNAMODB_PARTITION_KEY', 'pk'),
    os.getenv('DYNAMODB_SORT_KEY', 'sk'),
)

# Default page size used by DynamoDBCollection.get_items.
LIMIT = int(os.getenv('DYNAMODB_LIMIT', '25'))

# NOTE(review): ComposeKey/PrefixKey hard-code '#' as their default delimiter
# instead of reading this value — confirm whether that is intentional.
DELIMITER = os.getenv('DYNAMODB_DELIMITER', '#')
|
|
|
|
|
|
# Shared module-level helpers: a Powertools logger, plus boto3's converters
# between Python values and DynamoDB's typed attribute-value format.
logger = Logger(__name__)
serializer, deserializer = TypeSerializer(), TypeDeserializer()
|
|
|
|
|
|
def _serialize_python_type(value: Any) -> str | dict | list:
|
|
match value:
|
|
case datetime():
|
|
return value.isoformat()
|
|
case UUID():
|
|
return str(value)
|
|
case IPv4Address():
|
|
return str(value)
|
|
case list() | tuple():
|
|
return [_serialize_python_type(v) for v in value]
|
|
case dict():
|
|
return {k: _serialize_python_type(v) for k, v in value.items()}
|
|
case _:
|
|
return value
|
|
|
|
|
|
def serialize(value: dict) -> dict:
    """Convert a mapping of Python values into DynamoDB's typed wire format.

    Each value is first normalized by ``_serialize_python_type`` and then
    encoded by boto3's ``TypeSerializer``. Keys are kept unchanged.
    """
    encoded = {}
    for attr, raw in value.items():
        encoded[attr] = serializer.serialize(_serialize_python_type(raw))
    return encoded
|
|
|
|
|
|
def deserialize(value: dict) -> dict:
    """Convert a DynamoDB typed wire-format mapping back into Python values."""
    return dict(
        (attr, deserializer.deserialize(typed)) for attr, typed in value.items()
    )
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class ComposeKey(str):
|
|
"""Creates a composite key by joining string parts with a specified delimiter.
|
|
If a prefix is provided, it is added at the beginning of the key parts.
|
|
|
|
Example
|
|
-------
|
|
>>> ComposeKey(('abc', 'xyz'), prefix='examples', delimiter='#')
|
|
'examples#abc#xyz'
|
|
"""
|
|
|
|
keyparts: str | tuple[str, ...]
|
|
prefix: str | None = None
|
|
delimiter: str = '#'
|
|
else:
|
|
|
|
class ComposeKey(str):
|
|
def __new__(
|
|
cls,
|
|
keyparts: str | tuple[str, ...],
|
|
*,
|
|
prefix: str | None = None,
|
|
delimiter: str = '#',
|
|
) -> str:
|
|
if isinstance(keyparts, str):
|
|
keyparts = (keyparts,)
|
|
|
|
if prefix:
|
|
keyparts = (prefix,) + keyparts
|
|
|
|
return super().__new__(cls, delimiter.join(keyparts))
|
|
|
|
def __init__(
|
|
self,
|
|
keyparts: str | tuple[str, ...],
|
|
*,
|
|
prefix: str | None = None,
|
|
delimiter: str = '#',
|
|
) -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.keyparts = keyparts
|
|
self.prefix = prefix
|
|
self.delimiter = delimiter
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class PrefixKey(str):
|
|
prefix: str
|
|
delimiter: str | None = '#'
|
|
|
|
else:
|
|
|
|
class PrefixKey(str):
|
|
def __new__(cls, prefix: str, delimiter: str | None = '#') -> str:
|
|
if not delimiter:
|
|
return super().__new__(cls, prefix)
|
|
|
|
return super().__new__(cls, prefix + delimiter)
|
|
|
|
def __init__(self, prefix: str, delimiter: str = '#') -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.prefix = prefix
|
|
self.delimiter = delimiter
|
|
|
|
|
|
class Key(ABC, dict):
    """Base class for DynamoDB key dicts.

    Concrete keys are ``dict`` objects mapping key attribute names to values.
    They also supply the placeholder maps used in key condition expressions.
    """

    @abstractmethod
    def expr_attr_name(self) -> dict:
        """Return the ``ExpressionAttributeNames`` placeholders for this key."""

    @abstractmethod
    def expr_attr_values(self) -> dict:
        """Return the ``ExpressionAttributeValues`` placeholders for this key."""
|
|
|
|
|
|
class PartitionKey(Key):
    """Partition-key-only key for DynamoDB queries.

    The dict holds a single entry, ``{PK: pk}``. Its placeholders are
    ``#pk`` (name) and ``:pk`` (value).
    """

    def __init__(self, pk: str) -> None:
        super().__init__({PK: pk})

    def __repr__(self) -> str:
        return f'PartitionKey({self[PK]!r})'

    def expr_attr_name(self) -> dict:
        names = {'#pk': PK}
        return names

    def expr_attr_values(self) -> dict:
        values = {':pk': self[PK]}
        return values
|
|
|
|
|
|
class KeyPair(Key):
    """Represents a composite key (partition key and sort key) for DynamoDB queries.

    The dict holds ``{PK: pk, SK: sk}``. Its placeholders are ``#pk``/``#sk``
    (names) and ``:pk``/``:sk`` (values).
    """

    def __init__(self, pk: str, sk: str) -> None:
        super().__init__(**{PK: pk, SK: sk})

    def __repr__(self) -> str:
        # Look the values up by attribute name rather than unpacking
        # self.values(). Unpacking breaks if the dict does not contain exactly
        # two entries, e.g. after mutation or when PK and SK are configured to
        # the same name.
        return f'KeyPair({self[PK]!r}, {self[SK]!r})'

    def expr_attr_name(self) -> dict:
        return {
            '#pk': PK,
            '#sk': SK,
        }

    def expr_attr_values(self) -> dict:
        return {
            ':pk': self[PK],
            ':sk': self[SK],
        }
|
|
|
|
|
|
class TransactItems:
    """Builder for the ``TransactItems`` list of a DynamoDB transaction.

    Each ``put``/``update``/``get``/``delete``/``condition`` call appends one
    low-level (client API) request to ``self.items``. Items, keys and
    expression attribute values are converted to DynamoDB wire format with
    ``serialize``. Every operation targets ``self.table_name`` unless a
    ``table_name`` is passed for that call.

    Per the AWS API reference, TransactGetItems accepts only ``Get`` entries.
    TransactWriteItems accepts only ``Put``/``Update``/``Delete``/
    ``ConditionCheck`` entries.

    Documentation:
    --------------
    - https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_TransactWriteItems.html
    - https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_TransactGetItems.html
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_write_items
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_get_items
    """

    def __init__(self, table_name: str) -> None:
        self.table_name = table_name
        self.items: list[dict] = []

    @staticmethod
    def _expr_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> dict:
        """Build the optional ConditionExpression/ExpressionAttribute* params.

        Falsy arguments are omitted. Attribute values are serialized.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        return attrs

    def put(
        self,
        *,
        item: dict,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Append a ``Put`` of *item*, optionally guarded by *cond_expr*.

        *expr_attr_names*/*expr_attr_values* supply placeholders used by
        *cond_expr*.
        """
        self.items.append(
            {
                'Put': {
                    'TableName': table_name or self.table_name,
                    'Item': serialize(item),
                    **self._expr_attrs(
                        cond_expr=cond_expr,
                        expr_attr_names=expr_attr_names,
                        expr_attr_values=expr_attr_values,
                    ),
                }
            }
        )

    def update(
        self,
        *,
        key: dict,
        update_expr: str,
        cond_expr: str | None = None,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Append an ``Update`` of the item at *key* using *update_expr*."""
        self.items.append(
            {
                'Update': {
                    'TableName': table_name or self.table_name,
                    'Key': serialize(key),
                    'UpdateExpression': update_expr,
                    **self._expr_attrs(
                        cond_expr=cond_expr,
                        expr_attr_names=expr_attr_names,
                        expr_attr_values=expr_attr_values,
                    ),
                }
            }
        )

    def get(
        self,
        *,
        table_name: str | None = None,
        key: dict,
        expr_attr_names: dict | None = None,
        projection_expr: str | None = None,
    ) -> None:
        """Append a ``Get`` of the item at *key* (for TransactGetItems).

        *projection_expr* limits the returned attributes. *expr_attr_names*
        supplies the ``#name`` placeholders that expression uses.
        """
        attrs: dict = {}

        if projection_expr:
            attrs['ProjectionExpression'] = projection_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        self.items.append(
            {
                'Get': {
                    'TableName': table_name or self.table_name,
                    'Key': serialize(key),
                    **attrs,
                }
            }
        )

    def delete(
        self,
        *,
        key: dict,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Append a ``Delete`` of the item at *key*, optionally conditional."""
        self.items.append(
            {
                'Delete': {
                    'TableName': table_name or self.table_name,
                    'Key': serialize(key),
                    **self._expr_attrs(
                        cond_expr=cond_expr,
                        expr_attr_names=expr_attr_names,
                        expr_attr_values=expr_attr_values,
                    ),
                }
            }
        )

    def condition(
        self,
        *,
        key: dict,
        cond_expr: str,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Append a ``ConditionCheck`` of *cond_expr* against the item at *key*."""
        self.items.append(
            {
                'ConditionCheck': {
                    'TableName': table_name or self.table_name,
                    'Key': serialize(key),
                    # The condition is the whole point of a ConditionCheck, so
                    # it is always included.
                    'ConditionExpression': cond_expr,
                    **self._expr_attrs(
                        expr_attr_names=expr_attr_names,
                        expr_attr_values=expr_attr_values,
                    ),
                }
            }
        )
|
|
|
|
|
|
class DynamoDBPersistenceLayer:
    """Low-level access to one DynamoDB table through a boto3 client.

    Values sent to DynamoDB (items, keys, expression attribute values) are
    converted with ``serialize``. Returned items are converted back with
    ``deserialize``. Every ``ClientError`` raised by the client is logged via
    ``logger.exception`` and re-raised unchanged.

    Documentation:
    --------------
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.query
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.get_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.put_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.update_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.delete_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_get_items
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_write_items
    """

    def __init__(self, table_name: str, dynamodb_client) -> None:
        # dynamodb_client is presumably a low-level client such as
        # boto3.client('dynamodb'), since calls pass TableName and wire-format
        # keys. It is not type-checked here.
        self.table_name = table_name
        self.dynamodb_client = dynamodb_client

    @staticmethod
    def _expr_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> dict:
        """Build the optional ConditionExpression/ExpressionAttribute* params.

        Falsy arguments are omitted. Attribute values are serialized.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        return attrs

    def query(
        self,
        *,
        key_cond_expr: str,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: dict | None = None,
        filter_expr: str | None = None,
        limit: int | None = None,
        index_forward: bool = True,
    ) -> dict[str, Any]:
        """You must provide the name of the partition key attribute and a single value for that attribute.

        Query returns all items with that partition key value.
        Optionally, you can provide a sort key attribute and use a comparison operator to refine the search results.

        A `Query` operation always returns a result set. If no matching items are found, the result set will be empty.
        Queries that do not return results consume the minimum number of read capacity units for that type of read operation.

        Parameters
        ----------
        start_key:
            ``ExclusiveStartKey``, passed through unchanged. It must already be
            in wire format, i.e. a previous ``last_key``.
        limit:
            ``Limit``; falsy values (None/0) are not sent.
        index_forward:
            ``ScanIndexForward``; False returns items in descending sort-key order.

        Returns
        -------
        ``{'items': [...], 'last_key': ...}``: the deserialized items and the raw
        ``LastEvaluatedKey`` (``None`` on the last page).

        - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
        """
        attrs: dict = {
            'ScanIndexForward': index_forward,
            **self._expr_attrs(
                expr_attr_names=expr_attr_name,
                expr_attr_values=expr_attr_values,
            ),
        }

        if start_key:
            attrs['ExclusiveStartKey'] = start_key

        if filter_expr:
            attrs['FilterExpression'] = filter_expr

        if limit:
            attrs['Limit'] = limit

        try:
            res = self.dynamodb_client.query(
                TableName=self.table_name,
                KeyConditionExpression=key_cond_expr,
                **attrs,
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            return dict(
                items=[deserialize(v) for v in res.get('Items', [])],
                last_key=res.get('LastEvaluatedKey', None),
            )

    def get_item(self, key: dict) -> dict:
        """The GetItem operation returns a set of attributes for the item with the given primary key.

        If there is no matching item, GetItem does not return any data and there
        will be no Item element in the response; an empty dict is returned then.
        """
        try:
            res = self.dynamodb_client.get_item(
                TableName=self.table_name,
                Key=serialize(key),
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            return deserialize(res.get('Item', {}))

    def put_item(
        self,
        item: dict,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Write *item*, optionally guarded by *cond_expr*. Returns True on success.

        *expr_attr_names*/*expr_attr_values* supply placeholders used by
        *cond_expr*.
        """
        attrs = self._expr_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
        )

        try:
            self.dynamodb_client.put_item(
                TableName=self.table_name,
                Item=serialize(item),
                **attrs,
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            return True

    def update_item(
        self,
        key: dict,
        *,
        update_expr: str,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Apply *update_expr* to the item at *key*. Returns True on success."""
        attrs = self._expr_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
        )

        try:
            self.dynamodb_client.update_item(
                TableName=self.table_name,
                Key=serialize(key),
                UpdateExpression=update_expr,
                **attrs,
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            return True

    def delete_item(
        self,
        key: dict,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Deletes a single item in a table by primary key. You can perform a conditional delete operation that deletes
        the item if it exists, or if it has an expected attribute value. Returns True on success.
        """
        attrs = self._expr_attrs(
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
        )

        try:
            self.dynamodb_client.delete_item(
                TableName=self.table_name, Key=serialize(key), **attrs
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            return True

    def transact_get_items(self, transact_items: TransactItems) -> list[dict]:
        """Run the ``Get`` entries of *transact_items* as one TransactGetItems call.

        Returns one deserialized item per request, in request order. An empty
        dict is returned for keys with no matching item.
        """
        try:
            res = self.dynamodb_client.transact_get_items(
                TransactItems=transact_items.items
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            # The response lists one ItemResponse per request under
            # 'Responses'. Reading any other key silently yields [].
            return [
                deserialize(item_response.get('Item', {}))
                for item_response in res.get('Responses', [])
            ]

    def transact_write_items(self, transact_items: TransactItems) -> bool:
        """Run *transact_items* as one TransactWriteItems call. Returns True on success."""
        try:
            self.dynamodb_client.transact_write_items(
                TransactItems=transact_items.items
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            return True
|
|
|
|
|
|
class MissingError(ValueError):
    """Default error raised by ``DynamoDBCollection.get_item`` when no item matches the key."""
|
|
|
|
|
|
class PaginatedResult(TypedDict):
|
|
items: list[dict]
|
|
last_key: str | None
|
|
|
|
|
|
class DynamoDBCollection:
    """
    Higher-level item operations on top of a ``DynamoDBPersistenceLayer``.

    Example
    -------
    **Get an item using a composed sort key**

        collect = DynamoDBCollection(...)
        collect.get_item(
            KeyPair(
                'b3511b5a-cb32-4833-a373-f8223f2088d4',
                ComposeKey('sergio@somosbeta.com.br', prefix='emails'),
            ),
        )

    **Get items using a composed partition key**

        collect = DynamoDBCollection(...)
        collect.get_items(
            PartitionKey(
                ComposeKey('b3511b5a-cb32-4833-a373-f8223f2088d4', prefix='logs')
            ),
        )

    **Get items using a key pair**

        collect = DynamoDBCollection(...)
        collect.get_items(
            KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', 'emails'),
        )
    """

    def __init__(
        self,
        persistence_layer: DynamoDBPersistenceLayer,
        exception_cls: Type[ValueError] = MissingError,
        tz: str = TZ,
    ) -> None:
        if not issubclass(exception_cls, ValueError):
            raise TypeError(
                f'exception_cls must be a subclass of ValueError, got {exception_cls}'
            )

        self.persistence_layer = persistence_layer
        self.exception_cls = exception_cls
        self.tz = tz

    def get_item(
        self,
        key: KeyPair,
        path_spec: str | None = None,
        raise_if_missing: bool = True,
        default: Any = None,
        delimiter: str = '#',
    ) -> Any:
        """Fetch the item at *key*.

        Parameters
        ----------
        path_spec:
            Optional ``glom`` spec applied to the item; *default* is returned
            when the spec does not match.
        raise_if_missing:
            Raise ``self.exception_cls`` when no item exists. Otherwise
            *default* is returned.
        delimiter:
            Accepted for backward compatibility; currently unused.

        Raises
        ------
        self.exception_cls
            When the item is missing and *raise_if_missing* is true.
        """
        data = self.persistence_layer.get_item(key)

        if raise_if_missing and not data:
            raise self.exception_cls(f'Item with {key} not found.')

        if path_spec and data:
            # Imported lazily so glom is only required when path_spec is used.
            from glom import glom

            return glom(data, path_spec, default=default)

        return data or default

    def put_item(
        self,
        key: KeyPair,
        ttl: int | datetime | None = None,
        **kwargs: Any,
    ) -> bool:
        """Store an item made of *key*, a ``create_date`` and *kwargs*.

        The item is ``key | {'create_date': now} | kwargs``, so *kwargs* win
        on conflicts. ``now`` is taken in ``self.tz``.

        When *ttl* is given, ``ttl`` and ``ttl_date`` attributes are added:
        - int: stored as ``ttl``; ``ttl_date`` is that timestamp as a
          datetime in the current time zone.
        - datetime: stored as ``ttl_date``; ``ttl`` is ``timestamp(ttl)``
          (presumably epoch seconds — see ``.dateutils``).
        """
        now_ = now(self.tz)

        if isinstance(ttl, int):
            kwargs.update(
                {
                    'ttl': ttl,
                    'ttl_date': datetime.fromtimestamp(ttl, now_.tzinfo),
                }
            )

        if isinstance(ttl, datetime):
            kwargs.update(
                {
                    'ttl': timestamp(ttl),
                    'ttl_date': ttl,
                }
            )

        return self.persistence_layer.put_item(
            item=key | {'create_date': now_} | kwargs
        )

    def delete_item(
        self,
        key: KeyPair,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Delete the item at *key*, optionally guarded by *cond_expr*."""
        return self.persistence_layer.delete_item(
            key=key,
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
        )

    def get_items(
        self,
        key: PartitionKey | KeyPair,
        *,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: str | None = None,
        filter_expr: str | None = None,
        index_forward: bool = False,
        limit: int = LIMIT,
    ) -> PaginatedResult:
        """Query one page of items for *key*.

        With a ``KeyPair``, the sort key is matched with ``begins_with``. With a
        ``PartitionKey``, all items of the partition are matched.

        *start_key* is the opaque cursor returned as ``last_key`` by a previous
        call. Composite prefixes are stripped from the returned items:
        - If the partition key is a ``ComposeKey`` with a prefix, that
          ``prefix + delimiter`` is removed from each item's partition key.
        - Otherwise, if the sort key is a ``PrefixKey``, it is removed from
          each item's sort key.
        """
        key_cond_expr = (
            '#pk = :pk AND begins_with(#sk, :sk)'
            if isinstance(key, KeyPair)
            else '#pk = :pk'
        )

        res = self.persistence_layer.query(
            key_cond_expr=key_cond_expr,
            expr_attr_name=key.expr_attr_name() | (expr_attr_name or {}),
            expr_attr_values=key.expr_attr_values() | (expr_attr_values or {}),
            filter_expr=filter_expr,
            index_forward=index_forward,
            limit=limit,
            start_key=_startkey_b64decode(start_key) if start_key else {},
        )

        items = res['items']
        last_key = _startkey_b64encode(res['last_key']) if res['last_key'] else None

        pk, sk = key.get(PK), key.get(SK)

        # A ComposeKey built without a prefix has nothing to strip. Previously
        # it crashed here on `None + delimiter`.
        # NOTE(review): as before, only one prefix is stripped. When the
        # partition key carries a prefix, a PrefixKey sort key keeps its
        # prefix. Confirm whether both should be removed.
        if isinstance(pk, ComposeKey) and pk.prefix:
            items = _remove_prefix(items, PK, pk.prefix + pk.delimiter)
        elif isinstance(sk, PrefixKey):
            items = _remove_prefix(items, SK, sk)

        return {
            'items': items,
            'last_key': last_key,
        }
|
|
|
|
|
|
def _remove_prefix(
|
|
items: list[dict[str, Any]],
|
|
/,
|
|
key: str,
|
|
prefix: str,
|
|
) -> list[dict[str, Any]]:
|
|
"""Remove the given prefix from the value associated with key in each item."""
|
|
return [x | {key: x[key].removeprefix(prefix)} for x in items]
|
|
|
|
|
|
def _startkey_b64encode(obj: dict) -> str:
|
|
s = json.dumps(obj)
|
|
b = urlsafe_b64encode(s.encode('utf-8')).decode('utf-8')
|
|
return quote(b)
|
|
|
|
|
|
def _startkey_b64decode(s: str) -> dict:
|
|
b = unquote(s).encode('utf-8')
|
|
s = urlsafe_b64decode(b).decode('utf-8')
|
|
return json.loads(s)
|