Files
saladeaula.digital/layercake/layercake/dynamodb.py

713 lines
20 KiB
Python

import json
import os
from abc import ABC, abstractmethod
from base64 import urlsafe_b64decode, urlsafe_b64encode
from dataclasses import dataclass
from datetime import date, datetime
from ipaddress import IPv4Address, IPv6Address
from typing import TYPE_CHECKING, Any, Type, TypedDict
from urllib.parse import quote, unquote
from uuid import UUID

from aws_lambda_powertools import Logger
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError

from .dateutils import now, timestamp
# Runtime configuration; every value can be overridden through the environment.
TZ = os.getenv('TZ', 'UTC')
PK = os.getenv('DYNAMODB_PARTITION_KEY', 'pk')
SK = os.getenv('DYNAMODB_SORT_KEY', 'sk')
LIMIT = int(os.environ.get('DYNAMODB_LIMIT', '25'))
DELIMITER = os.environ.get('DYNAMODB_DELIMITER', '#')

logger = Logger(__name__)

# Shared boto3 converters between Python values and DynamoDB attribute values.
serializer = TypeSerializer()
deserializer = TypeDeserializer()
def _serialize_python_type(value: Any) -> str | dict | list:
match value:
case datetime():
return value.isoformat()
case UUID():
return str(value)
case IPv4Address():
return str(value)
case list() | tuple():
return [_serialize_python_type(v) for v in value]
case dict():
return {k: _serialize_python_type(v) for k, v in value.items()}
case _:
return value
def serialize(value: dict) -> dict:
    """Convert a plain Python ``dict`` into DynamoDB's attribute-value format.

    Each value is first normalized by ``_serialize_python_type`` (for example,
    datetimes and UUIDs become strings). It is then encoded with the shared
    boto3 ``TypeSerializer``.
    """
    serialized = {}
    for name, raw in value.items():
        serialized[name] = serializer.serialize(_serialize_python_type(raw))
    return serialized
def deserialize(value: dict) -> dict:
    """Convert a DynamoDB attribute-value map back into plain Python values."""
    return dict(
        (name, deserializer.deserialize(attr)) for name, attr in value.items()
    )
if TYPE_CHECKING:
@dataclass
class ComposeKey(str):
"""Creates a composite key by joining string parts with a specified delimiter.
If a prefix is provided, it is added at the beginning of the key parts.
Example
-------
>>> ComposeKey(('abc', 'xyz'), prefix='examples', delimiter='#')
'examples#abc#xyz'
"""
keyparts: str | tuple[str, ...]
prefix: str | None = None
delimiter: str = '#'
else:
class ComposeKey(str):
def __new__(
cls,
keyparts: str | tuple[str, ...],
*,
prefix: str | None = None,
delimiter: str = '#',
) -> str:
if isinstance(keyparts, str):
keyparts = (keyparts,)
if prefix:
keyparts = (prefix,) + keyparts
return super().__new__(cls, delimiter.join(keyparts))
def __init__(
self,
keyparts: str | tuple[str, ...],
*,
prefix: str | None = None,
delimiter: str = '#',
) -> None:
# __init__ is used to store the parameters for later reference.
# For immutable types like str, __init__ cannot change the instance's value.
self.keyparts = keyparts
self.prefix = prefix
self.delimiter = delimiter
if TYPE_CHECKING:
@dataclass
class PrefixKey(str):
prefix: str
delimiter: str | None = '#'
else:
class PrefixKey(str):
def __new__(cls, prefix: str, delimiter: str | None = '#') -> str:
if not delimiter:
return super().__new__(cls, prefix)
return super().__new__(cls, prefix + delimiter)
def __init__(self, prefix: str, delimiter: str = '#') -> None:
# __init__ is used to store the parameters for later reference.
# For immutable types like str, __init__ cannot change the instance's value.
self.prefix = prefix
self.delimiter = delimiter
class Key(ABC, dict):
    """Base class for DynamoDB key objects.

    A ``Key`` is a plain ``dict`` of key attribute name -> value, so it can be
    serialized and merged like any item dict. It also provides the
    placeholder mappings needed to build a key condition expression.
    """

    @abstractmethod
    def expr_attr_name(self) -> dict:
        """Return the ``#placeholder`` -> attribute-name mapping for this key."""

    @abstractmethod
    def expr_attr_values(self) -> dict:
        """Return the ``:placeholder`` -> value mapping for this key."""
class PartitionKey(Key):
    """Represents a partition key for DynamoDB queries.

    The dict holds a single entry keyed by the configured partition-key
    attribute name (``PK``).
    """

    def __init__(self, pk: str) -> None:
        super().__init__({PK: pk})

    def __repr__(self) -> str:
        return f'PartitionKey({self[PK]!r})'

    def expr_attr_name(self) -> dict:
        """Map the ``#pk`` placeholder to the partition-key attribute name."""
        names = {'#pk': PK}
        return names

    def expr_attr_values(self) -> dict:
        """Map the ``:pk`` placeholder to this key's value."""
        values = {':pk': self[PK]}
        return values
class KeyPair(Key):
    """Represents a composite key (partition key and sort key) for DynamoDB queries.

    The dict holds ``{PK: pk, SK: sk}``. The attribute names come from the
    module-level ``PK``/``SK`` settings.
    """

    def __init__(self, pk: str, sk: str) -> None:
        super().__init__(**{PK: pk, SK: sk})

    def __repr__(self) -> str:
        # Look values up by attribute name rather than unpacking self.values().
        # Unpacking raises as soon as the dict holds anything but exactly two
        # entries.
        return f'KeyPair({self[PK]!r}, {self[SK]!r})'

    def expr_attr_name(self) -> dict:
        """Map the ``#pk``/``#sk`` placeholders to the key attribute names."""
        return {
            '#pk': PK,
            '#sk': SK,
        }

    def expr_attr_values(self) -> dict:
        """Map the ``:pk``/``:sk`` placeholders to this key's values."""
        return {
            ':pk': self[PK],
            ':sk': self[SK],
        }
class TransactItems:
    """Builder for the ``TransactItems`` list of a DynamoDB transaction.

    Each method appends one request (``Put``, ``Update``, ``Get``, ``Delete``
    or ``ConditionCheck``) to :attr:`items`. Keys, items and expression
    attribute values are converted with :func:`serialize`. Operations that do
    not pass ``table_name`` use the table given to the constructor.

    Documentation:
    --------------
    - https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_TransactWriteItems.html
    - https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_TransactGetItems.html
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_write_items
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_get_items
    """

    def __init__(self, table_name: str) -> None:
        # Default table for operations that don't pass ``table_name``.
        self.table_name = table_name
        # Request entries, in the order they were added.
        self.items: list[dict] = []

    @staticmethod
    def _expression_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        projection_expr: str | None = None,
    ) -> dict:
        """Collect a request's optional expression fields, skipping empty ones."""
        attrs: dict = {}
        if cond_expr:
            attrs['ConditionExpression'] = cond_expr
        if projection_expr:
            attrs['ProjectionExpression'] = projection_expr
        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names
        if expr_attr_values:
            # Values arrive as Python objects; the API expects attribute-value maps.
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
        return attrs

    def _append(self, operation: str, table_name: str | None, **request: Any) -> None:
        """Append ``{operation: {'TableName': ..., **request}}`` to :attr:`items`."""
        self.items.append(
            {operation: {'TableName': table_name or self.table_name, **request}}
        )

    def put(
        self,
        *,
        item: dict,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Queue a ``Put`` of *item*.

        ``expr_attr_names``/``expr_attr_values`` provide the placeholders
        referenced by ``cond_expr``.
        """
        self._append(
            'Put',
            table_name,
            Item=serialize(item),
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )

    def update(
        self,
        *,
        key: dict,
        update_expr: str,
        cond_expr: str | None = None,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Queue an ``Update`` of the item identified by *key*."""
        self._append(
            'Update',
            table_name,
            Key=serialize(key),
            UpdateExpression=update_expr,
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )

    def get(
        self,
        *,
        table_name: str | None = None,
        key: dict,
        expr_attr_names: dict | None = None,
        projection_expr: str | None = None,
    ) -> None:
        """Queue a ``Get`` of the item identified by *key*.

        ``projection_expr`` limits the returned attributes. ``expr_attr_names``
        provides the ``#name`` placeholders it references.
        """
        self._append(
            'Get',
            table_name,
            Key=serialize(key),
            **self._expression_attrs(
                expr_attr_names=expr_attr_names,
                projection_expr=projection_expr,
            ),
        )

    def delete(
        self,
        *,
        key: dict,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Queue a ``Delete`` of the item identified by *key*."""
        self._append(
            'Delete',
            table_name,
            Key=serialize(key),
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )

    def condition(
        self,
        *,
        key: dict,
        cond_expr: str,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> None:
        """Queue a ``ConditionCheck`` of *cond_expr* against the item at *key*."""
        self._append(
            'ConditionCheck',
            table_name,
            Key=serialize(key),
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )
class DynamoDBPersistenceLayer:
    """Thin wrapper around a DynamoDB client bound to a single table.

    Python values are converted with :func:`serialize` on the way in and
    :func:`deserialize` on the way out. ``ClientError`` exceptions are logged
    and re-raised unchanged.

    Documentation:
    --------------
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.query
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.get_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.put_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.update_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.delete_item
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_get_items
    - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.transact_write_items
    """

    def __init__(self, table_name: str, dynamodb_client) -> None:
        self.table_name = table_name
        # NOTE(review): presumably a low-level boto3 DynamoDB client, since
        # keys/items are passed in wire format. Confirm with callers.
        self.dynamodb_client = dynamodb_client

    @staticmethod
    def _expression_attrs(
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> dict:
        """Collect a request's optional expression fields, skipping empty ones."""
        attrs: dict = {}
        if cond_expr:
            attrs['ConditionExpression'] = cond_expr
        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names
        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
        return attrs

    @staticmethod
    def _call(operation: Any, **kwargs: Any) -> dict:
        """Invoke a bound client method, logging a ``ClientError`` before re-raising."""
        try:
            return operation(**kwargs)
        except ClientError as err:
            logger.exception(err)
            raise

    def query(
        self,
        *,
        key_cond_expr: str,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: dict | None = None,
        filter_expr: str | None = None,
        limit: int | None = None,
        index_forward: bool = True,
    ) -> dict[str, Any]:
        """You must provide the name of the partition key attribute and a single value for that attribute.

        Query returns all items with that partition key value.
        Optionally, you can provide a sort key attribute and use a comparison operator to refine the search results.
        ...
        A `Query` operation always returns a result set. If no matching items are found, the result set will be empty.
        Queries that do not return results consume the minimum number of read capacity units for that type of read operation.

        ``start_key`` is sent as ``ExclusiveStartKey`` without serialization, so
        pass a previous ``last_key`` unchanged.

        Returns ``{'items': [...deserialized items...], 'last_key': <raw
        LastEvaluatedKey or None>}``.

        - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
        """
        attrs: dict = {'ScanIndexForward': index_forward}
        attrs |= self._expression_attrs(
            expr_attr_names=expr_attr_name,
            expr_attr_values=expr_attr_values,
        )
        if start_key:
            attrs['ExclusiveStartKey'] = start_key
        if filter_expr:
            attrs['FilterExpression'] = filter_expr
        if limit:
            attrs['Limit'] = limit

        res = self._call(
            self.dynamodb_client.query,
            TableName=self.table_name,
            KeyConditionExpression=key_cond_expr,
            **attrs,
        )
        return dict(
            items=[deserialize(v) for v in res.get('Items', [])],
            # Left in wire format so it can be fed back as ``start_key``.
            last_key=res.get('LastEvaluatedKey', None),
        )

    def get_item(self, key: dict) -> dict:
        """The GetItem operation returns a set of attributes for the item with the given primary key.

        If there is no matching item, GetItem returns no ``Item`` element and
        this method returns an empty dict.
        """
        res = self._call(
            self.dynamodb_client.get_item,
            TableName=self.table_name,
            Key=serialize(key),
        )
        return deserialize(res.get('Item', {}))

    def put_item(
        self,
        item: dict,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Write *item*, optionally guarded by *cond_expr*.

        ``expr_attr_names``/``expr_attr_values`` provide the placeholders
        referenced by ``cond_expr``. Returns True on success.
        """
        self._call(
            self.dynamodb_client.put_item,
            TableName=self.table_name,
            Item=serialize(item),
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )
        return True

    def update_item(
        self,
        key: dict,
        *,
        update_expr: str,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Apply *update_expr* to the item at *key*. Returns True on success."""
        self._call(
            self.dynamodb_client.update_item,
            TableName=self.table_name,
            Key=serialize(key),
            UpdateExpression=update_expr,
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )
        return True

    def delete_item(
        self,
        key: dict,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Deletes a single item in a table by primary key.

        You can perform a conditional delete operation that deletes the item
        if it exists, or if it has an expected attribute value. Returns True
        on success.
        """
        self._call(
            self.dynamodb_client.delete_item,
            TableName=self.table_name,
            Key=serialize(key),
            **self._expression_attrs(
                cond_expr=cond_expr,
                expr_attr_names=expr_attr_names,
                expr_attr_values=expr_attr_values,
            ),
        )
        return True

    def transact_get_items(self, transact_items: TransactItems) -> list[dict]:
        """Run the queued ``Get`` requests in one transaction.

        Returns one deserialized item per request, in request order. A
        missing item yields an empty dict.
        """
        res = self._call(
            self.dynamodb_client.transact_get_items,
            TransactItems=transact_items.items,
        )
        # The API returns the results under 'Responses'. An entry may be
        # null/empty when the requested item was not found.
        return [
            deserialize((response or {}).get('Item', {}))
            for response in res.get('Responses', [])
        ]

    def transact_write_items(self, transact_items: TransactItems) -> bool:
        """Run the queued write requests atomically. Returns True on success."""
        self._call(
            self.dynamodb_client.transact_write_items,
            TransactItems=transact_items.items,
        )
        return True
class MissingError(ValueError):
    """Default exception raised by ``DynamoDBCollection.get_item`` when the item does not exist."""
class PaginatedResult(TypedDict):
items: list[dict]
last_key: str | None
class DynamoDBCollection:
    """Convenience layer over :class:`DynamoDBPersistenceLayer`.

    It works with :class:`PartitionKey`/:class:`KeyPair` keys, raises a
    configurable exception for missing items, stamps ``create_date`` (and
    optional TTL fields) on writes, and paginates with opaque string tokens.
    It also strips ``ComposeKey``/``PrefixKey`` prefixes from the key
    attributes of query results.

    Example
    -------
    **Get an item using a composed sort key**

        collect = DynamoDBCollection(...)
        collect.get_item(
            KeyPair(
                'b3511b5a-cb32-4833-a373-f8223f2088d4',
                ComposeKey('sergio@somosbeta.com.br', prefix='emails'),
            ),
        )

    **Get items using a composed partition key**

        collect = DynamoDBCollection(...)
        collect.get_items(
            PartitionKey(
                ComposeKey('b3511b5a-cb32-4833-a373-f8223f2088d4', prefix='logs')
            ),
        )

    **Get items using a key pair**

        collect = DynamoDBCollection(...)
        collect.get_items(
            KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', 'emails'),
        )
    """

    def __init__(
        self,
        persistence_layer: DynamoDBPersistenceLayer,
        exception_cls: Type[ValueError] = MissingError,
        tz: str = TZ,
    ) -> None:
        """
        :param persistence_layer: performs the actual DynamoDB calls.
        :param exception_cls: raised by ``get_item`` for missing items; must
            subclass ``ValueError``.
        :param tz: time zone passed to ``now`` when stamping ``create_date``.
        :raises TypeError: if ``exception_cls`` is not a ``ValueError`` subclass.
        """
        if not issubclass(exception_cls, ValueError):
            raise TypeError(
                f'exception_cls must be a subclass of ValueError, got {exception_cls}'
            )
        self.persistence_layer = persistence_layer
        self.exception_cls = exception_cls
        self.tz = tz

    def get_item(
        self,
        key: KeyPair,
        path_spec: str | None = None,
        raise_if_missing: bool = True,
        default: Any = None,
        delimiter: str = '#',
    ) -> Any:
        """Fetch the item at *key*.

        :param path_spec: optional ``glom`` spec applied to the item; its
            result is returned instead of the whole item.
        :param raise_if_missing: raise ``self.exception_cls`` when no item exists.
        :param default: returned when the item (or the glom path) is missing.
        :param delimiter: unused; kept for backward compatibility.
        """
        data = self.persistence_layer.get_item(key)

        if raise_if_missing and not data:
            raise self.exception_cls(f'Item with {key} not found.')

        if path_spec and data:
            # Imported lazily: only needed when a path_spec is given.
            from glom import glom

            return glom(data, path_spec, default=default)

        return data or default

    def put_item(
        self,
        key: KeyPair,
        ttl: int | datetime | None = None,
        **kwargs: Any,
    ) -> bool:
        """Write an item made of *key*, a ``create_date`` timestamp and *kwargs*.

        :param ttl: an ``int`` (stored as-is and interpreted as epoch seconds for
            ``ttl_date``) or a ``datetime`` (stored as ``timestamp(ttl)``). Either
            way a matching ``ttl_date`` attribute is added.
        :param kwargs: extra attributes. They take precedence over ``create_date``
            and the key attributes.
        """
        now_ = now(self.tz)

        if isinstance(ttl, int):
            kwargs.update(
                {
                    'ttl': ttl,
                    'ttl_date': datetime.fromtimestamp(ttl, now_.tzinfo),
                }
            )
        elif isinstance(ttl, datetime):
            kwargs.update(
                {
                    'ttl': timestamp(ttl),
                    'ttl_date': ttl,
                }
            )

        return self.persistence_layer.put_item(
            item=key | {'create_date': now_} | kwargs
        )

    def delete_item(
        self,
        key: KeyPair,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Delete the item at *key*, optionally guarded by *cond_expr*."""
        return self.persistence_layer.delete_item(
            key=key,
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
        )

    def get_items(
        self,
        key: PartitionKey | KeyPair,
        *,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: str | None = None,
        filter_expr: str | None = None,
        index_forward: bool = False,
        limit: int = LIMIT,
    ) -> PaginatedResult:
        """Query one page of items for *key*.

        A ``PartitionKey`` matches every item in the partition. A ``KeyPair``
        additionally requires the sort key to begin with its sort-key value.

        :param start_key: token from a previous page's ``last_key``.
        :returns: the page items and the token for the next page (or None).
        """
        key_cond_expr = (
            '#pk = :pk AND begins_with(#sk, :sk)'
            if isinstance(key, KeyPair)
            else '#pk = :pk'
        )
        res = self.persistence_layer.query(
            key_cond_expr=key_cond_expr,
            expr_attr_name=key.expr_attr_name() | (expr_attr_name or {}),
            expr_attr_values=key.expr_attr_values() | (expr_attr_values or {}),
            filter_expr=filter_expr,
            index_forward=index_forward,
            limit=limit,
            start_key=_startkey_b64decode(start_key) if start_key else {},
        )

        items = res['items']
        last_key = _startkey_b64encode(res['last_key']) if res['last_key'] else None

        pk_value, sk_value = key.get(PK), key.get(SK)
        if isinstance(pk_value, ComposeKey) and pk_value.prefix:
            # Remove the "<prefix><delimiter>" that ComposeKey prepended. A
            # ComposeKey without a prefix has nothing to strip; the previous
            # code crashed on `None + delimiter` in that case.
            items = _remove_prefix(items, PK, pk_value.prefix + pk_value.delimiter)
        elif isinstance(sk_value, PrefixKey):
            # Remove the prefix from the sort key.
            items = _remove_prefix(items, SK, sk_value)

        return {
            'items': items,
            'last_key': last_key,
        }
def _remove_prefix(
items: list[dict[str, Any]],
/,
key: str,
prefix: str,
) -> list[dict[str, Any]]:
"""Remove the given prefix from the value associated with key in each item."""
return [x | {key: x[key].removeprefix(prefix)} for x in items]
def _startkey_b64encode(obj: dict) -> str:
s = json.dumps(obj)
b = urlsafe_b64encode(s.encode('utf-8')).decode('utf-8')
return quote(b)
def _startkey_b64decode(s: str) -> dict:
b = unquote(s).encode('utf-8')
s = urlsafe_b64decode(b).decode('utf-8')
return json.loads(s)