1261 lines
36 KiB
Python
1261 lines
36 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import warnings
|
|
from abc import ABC, abstractmethod
|
|
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
|
from dataclasses import dataclass
|
|
from datetime import date, datetime
|
|
from ipaddress import IPv4Address
|
|
from typing import TYPE_CHECKING, Any, Self, Type, TypedDict
|
|
from urllib.parse import quote, unquote
|
|
from uuid import UUID
|
|
|
|
from boto3.dynamodb.table import BatchWriter
|
|
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
|
from botocore.exceptions import ClientError
|
|
from glom import glom
|
|
|
|
from .dateutils import timestamp
|
|
from .funcs import omit, pick
|
|
|
|
# Runtime configuration — every value can be overridden via environment variables.
TZ = os.getenv('TZ', 'UTC')  # default timezone identifier
PK = os.getenv('DYNAMODB_PARTITION_KEY', 'pk')  # partition key attribute name
SK = os.getenv('DYNAMODB_SORT_KEY', 'sk')  # sort key attribute name
LIMIT = int(os.getenv('DYNAMODB_LIMIT', 25))  # default query page size
DELIMITER = os.getenv('DYNAMODB_DELIMITER', '#')  # composite-key part separator
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()


# Module-level logging setup; LOG_LEVEL must name a standard logging level.
logging.basicConfig(level=getattr(logging, LOG_LEVEL))

logger = logging.getLogger(__name__)

# Shared boto3 helpers translating between plain Python values and the
# DynamoDB wire format ({'S': ...}, {'N': ...}, ...). Stateless, so safe
# to share at module level.
serializer = TypeSerializer()

deserializer = TypeDeserializer()
|
|
|
|
|
|
def _serialize_to_basic_types(data: Any) -> str | dict | set | list:
|
|
match data:
|
|
case datetime():
|
|
return data.isoformat()
|
|
|
|
case date():
|
|
return data.isoformat()
|
|
|
|
case UUID():
|
|
return str(data)
|
|
|
|
case IPv4Address():
|
|
return str(data)
|
|
|
|
case tuple() | list():
|
|
if not data:
|
|
return []
|
|
|
|
serialized = [_serialize_to_basic_types(v) for v in data]
|
|
|
|
if any(isinstance(v, (dict, list)) for v in serialized):
|
|
return serialized
|
|
|
|
try:
|
|
return set(serialized)
|
|
except TypeError:
|
|
return serialized
|
|
|
|
case set():
|
|
if not data:
|
|
return []
|
|
|
|
return set(_serialize_to_basic_types(v) for v in data)
|
|
|
|
case dict():
|
|
return {
|
|
k: _serialize_to_basic_types(v)
|
|
for k, v in data.items()
|
|
if v is not None
|
|
}
|
|
|
|
case _:
|
|
return data
|
|
|
|
|
|
def serialize(data: dict) -> dict:
    """Turn a plain item dict into its DynamoDB wire-format representation.

    ``None`` values are dropped (DynamoDB rejects them); every remaining
    value is coerced to basic types and wrapped by the boto3 serializer.
    """
    result: dict = {}
    for attr, value in data.items():
        if value is None:
            continue
        result[attr] = serializer.serialize(_serialize_to_basic_types(value))
    return result
|
|
|
|
|
|
def deserialize(data: dict) -> dict:
    """Turn a DynamoDB wire-format item back into plain Python values."""
    result: dict = {}
    for attr, typed_value in data.items():
        result[attr] = deserializer.deserialize(typed_value)
    return result
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class ComposeKey(str):
|
|
keyparts: str | tuple[str, ...]
|
|
prefix: str | None = None
|
|
delimiter: str = '#'
|
|
else:
|
|
|
|
class ComposeKey(str):
|
|
"""Creates a composite key by joining string parts with a specified delimiter.
|
|
If a prefix is provided, it is added at the beginning of the key parts.
|
|
|
|
Example
|
|
-------
|
|
ComposeKey(('abc', 'xyz'), prefix='examples', delimiter='#')
|
|
"""
|
|
|
|
def __new__(
|
|
cls,
|
|
keyparts: str | tuple[str, ...],
|
|
/,
|
|
prefix: str | None = None,
|
|
*,
|
|
delimiter: str = '#',
|
|
) -> str:
|
|
warnings.warn(
|
|
'ComposeKey() is deprecated and will be removed in the future.',
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
|
|
if isinstance(keyparts, str):
|
|
keyparts = (keyparts,)
|
|
|
|
if prefix:
|
|
keyparts = (prefix,) + keyparts
|
|
|
|
return super().__new__(cls, delimiter.join(keyparts))
|
|
|
|
def __init__(
|
|
self,
|
|
keyparts: str | tuple[str, ...],
|
|
/,
|
|
prefix: str | None = None,
|
|
*,
|
|
delimiter: str = '#',
|
|
) -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.keyparts = keyparts
|
|
self.prefix = prefix
|
|
self.delimiter = delimiter
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class PrefixKey(str):
|
|
prefix: str
|
|
delimiter: str | None = '#'
|
|
|
|
else:
|
|
|
|
class PrefixKey(str):
|
|
def __new__(cls, prefix: str, delimiter: str | None = '#') -> str:
|
|
warnings.warn(
|
|
'PrefixKey() is deprecated and will be removed in the future.',
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
|
|
if not delimiter:
|
|
return super().__new__(cls, prefix)
|
|
|
|
return super().__new__(cls, prefix + delimiter)
|
|
|
|
def __init__(self, prefix: str, delimiter: str = '#') -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.prefix = prefix
|
|
self.delimiter = delimiter
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class SortKey(str):
|
|
"""
|
|
SortKey encapsulates the sort key value and optionally stores additional
|
|
attributes for nested data extraction.
|
|
|
|
Parameters
|
|
----------
|
|
sk: str
|
|
The sort key value.
|
|
path_spec: str, optional
|
|
Optional specification for nested data extraction.
|
|
rename_key : str, optional
|
|
If provided, renames the sort key in the output.
|
|
remove_prefix: str, optional
|
|
Optional prefix to remove from the key when forming the result dict.
|
|
"""
|
|
|
|
sk: str
|
|
path_spec: str | None = None
|
|
rename_key: str | None = None
|
|
remove_prefix: str | None = None
|
|
|
|
else:
|
|
|
|
class SortKey(str):
|
|
"""
|
|
SortKey encapsulates the sort key value and optionally stores additional
|
|
attributes for nested data extraction.
|
|
|
|
Parameters
|
|
----------
|
|
sk: str
|
|
The sort key value.
|
|
path_spec: str, optional
|
|
Optional specification for nested data extraction.
|
|
rename_key : str, optional
|
|
If provided, renames the sort key in the output.
|
|
remove_prefix: str, optional
|
|
Optional prefix to remove from the key when forming the result dict.
|
|
"""
|
|
|
|
def __new__(
|
|
cls,
|
|
sk: str,
|
|
*,
|
|
path_spec: str | None = None,
|
|
rename_key: str | None = None,
|
|
remove_prefix: str | None = None,
|
|
) -> str:
|
|
return super().__new__(cls, sk)
|
|
|
|
def __init__(
|
|
self,
|
|
sk: str,
|
|
*,
|
|
path_spec: str | None = None,
|
|
rename_key: str | None = None,
|
|
remove_prefix: str | None = None,
|
|
) -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.sk = sk
|
|
self.path_spec = path_spec
|
|
self.rename_key = rename_key
|
|
self.remove_prefix = remove_prefix
|
|
|
|
if remove_prefix:
|
|
warnings.warn(
|
|
'SortKey.remove_prefix() is deprecated and will be removed in the future.',
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
|
|
|
|
class Key(ABC, dict):
    """Abstract base for DynamoDB key dicts.

    Subclasses behave as plain dicts holding the raw key attributes while
    also supplying the expression-attribute mappings used in queries.
    """

    @abstractmethod
    def expr_attr_name(self) -> dict:
        """Return the ExpressionAttributeNames mapping for this key."""

    @abstractmethod
    def expr_attr_values(self) -> dict:
        """Return the ExpressionAttributeValues mapping for this key."""
|
|
|
|
|
|
class PartitionKey(Key):
    """A partition-key-only key for DynamoDB queries."""

    def __init__(self, pk: str) -> None:
        super().__init__(**{PK: pk})

    def __repr__(self) -> str:
        return f'PartitionKey({self[PK]!r})'

    def expr_attr_name(self) -> dict:
        # '#pk' placeholder maps to the configured partition key attribute.
        return {'#pk': PK}

    def expr_attr_values(self) -> dict:
        return {':pk': self[PK]}
|
|
|
|
|
|
class KeyPair(Key):
    """A composite DynamoDB key: partition key plus sort key."""

    def __init__(
        self,
        pk: str,
        sk: str,
        *,
        rename_key: str | None = None,
        table_name: str | None = None,
    ) -> None:
        """
        Initializes a composite key using partition and sort key.

        Parameters
        ----------
        pk : str
            The partition key.
        sk : str
            The sort key.
        rename_key : str, optional
            If provided, renames the sort key in the output.
        table_name : str, optional
        """

        super().__init__(**{PK: pk, SK: sk})
        self._rename_key = rename_key
        self._table_name = table_name

    @property
    def rename_key(self) -> str | None:
        return self._rename_key

    @property
    def table_name(self) -> str | None:
        return self._table_name

    def __repr__(self) -> str:
        return f'KeyPair({self[PK]!r}, {self[SK]!r})'

    def expr_attr_name(self) -> dict:
        return {'#pk': PK, '#sk': SK}

    def expr_attr_values(self) -> dict:
        return {':pk': self[PK], ':sk': self[SK]}

    @classmethod
    def parse_obj(cls, obj: Any) -> Self | None:
        """Build a KeyPair from a dict (via PK/SK entries) or a 2-sequence."""
        if not obj:
            return None

        if isinstance(obj, dict):
            pair = pick((PK, SK), obj).values()  # Gotta keep them in order
        else:
            pair = obj

        return cls(*pair)

    def __add__(self, other: Self) -> 'KeyChain':
        # Chaining two pairs starts a KeyChain.
        return KeyChain(pairs=(self, other))
|
|
|
|
|
|
@dataclass(frozen=True)
class KeyChain:
    """An immutable, ordered, duplicate-free chain of KeyPair objects."""

    pairs: tuple[KeyPair, ...] = ()

    def __add__(self, other: KeyPair) -> 'KeyChain':
        if not isinstance(other, KeyPair):
            raise TypeError('Can only add a KeyPair to a KeyChain')

        # Appending an already-present pair is a no-op.
        extended = self.pairs if other in self.pairs else self.pairs + (other,)
        return KeyChain(pairs=extended)
|
|
|
|
|
|
@dataclass(frozen=True)
class TransactKey:
    """Accumulates key pairs sharing one partition key for a transact-get.

    Example
    -------
    TransactKey('e9bb7dc6-c7b2-4d34-8931-d298353758ec')
    + SortKey('0')
    + SortKey('tenant')
    """

    pk: str
    table_name: str | None = None
    pairs: tuple[KeyPair, ...] = ()

    def __add__(self, other: SortKey | KeyPair) -> 'TransactKey':
        if not isinstance(other, (SortKey, KeyPair)):
            raise TypeError('Can only add a SortKey and KeyPair to a TransactKey')

        # A bare SortKey is paired with this key's partition key.
        pair = (
            KeyPair(self.pk, other, table_name=self.table_name)
            if isinstance(other, SortKey)
            else other
        )

        # Appending an already-present pair is a no-op.
        extended = self.pairs if pair in self.pairs else self.pairs + (pair,)

        return TransactKey(
            pk=self.pk,
            table_name=self.table_name,
            pairs=extended,
        )
|
|
|
|
|
|
class TransactionCanceledReason(TypedDict):
    """Shape of one per-operation cancellation reason from a failed transaction."""

    # DynamoDB cancellation code, e.g. 'ConditionalCheckFailed'.
    code: str
    # Human-readable message returned by DynamoDB.
    message: str
    # The transact operation dict that triggered the cancellation.
    operation: dict
|
|
|
|
|
|
class TransactionCanceledException(Exception):
|
|
def __init__(self, msg: str = '', reasons: list = []) -> None:
|
|
super().__init__(msg)
|
|
self.msg = msg
|
|
self.reasons: list[TransactionCanceledReason] = reasons
|
|
|
|
|
|
class TransactOperation:
|
|
def __init__(
|
|
self,
|
|
operation: dict,
|
|
exc_cls: type[Exception] | None = None,
|
|
) -> None:
|
|
self.operation = operation
|
|
self.exc_cls = exc_cls
|
|
|
|
|
|
if TYPE_CHECKING:
    # Static type checkers get the real typed client stub.
    from mypy_boto3_dynamodb.client import DynamoDBClient
else:
    # At runtime the stub package is not installed; fall back to a plain
    # object so annotations referencing DynamoDBClient still resolve.
    DynamoDBClient = object
|
|
|
|
|
|
class TransactWriter:
    """Buffers DynamoDB transact-write operations and flushes them in batches.

    Put/update/delete/condition-check operations accumulate in an internal
    buffer and are sent via ``transact_write_items`` once ``flush_amount``
    operations are queued. Used as a context manager, any remaining buffered
    operations are flushed on exit.
    """

    def __init__(
        self,
        table_name: str,
        *,
        flush_amount: int,
        client: DynamoDBClient,
    ) -> None:
        # table_name: default table for operations that don't name their own.
        # flush_amount: buffered-operation count that triggers a flush.
        self._table_name = table_name
        self._items_buffer: list[TransactOperation] = []
        self._flush_amount = flush_amount
        self._client = client

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc_details) -> None:
        # When we exit, we need to keep flushing whatever's left
        # until there's nothing left in our items buffer.
        while self._items_buffer:
            self._flush()

    def put(
        self,
        *,
        item: dict,
        table_name: str | None = None,
        cond_expr: str | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a Put operation for *item*, serialized to wire format.

        Parameters
        ----------
        item: dict
            The full item to write.
        table_name: str, optional
            Overrides the writer's default table.
        cond_expr: str, optional
            ConditionExpression guarding the put.
        exc_cls: Type[Exception], optional
            Exception to raise if this operation cancels the transaction.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        self._add_op_and_process(
            TransactOperation(
                {
                    'Put': dict(
                        TableName=table_name or self._table_name,
                        Item=serialize(item),
                        **attrs,
                    )
                },
                exc_cls,
            ),
        )

    def update(
        self,
        *,
        key: dict,
        update_expr: str,
        cond_expr: str | None = None,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue an Update operation for the item identified by *key*.

        Parameters
        ----------
        key: dict
            Primary key of the item to update.
        update_expr: str
            UpdateExpression describing the modification.
        cond_expr: str, optional
            ConditionExpression guarding the update.
        table_name: str, optional
            Overrides the writer's default table.
        expr_attr_names: dict, optional
            ExpressionAttributeNames mapping.
        expr_attr_values: dict, optional
            ExpressionAttributeValues mapping (serialized to wire format).
        exc_cls: Type[Exception], optional
            Exception to raise if this operation cancels the transaction.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        self._add_op_and_process(
            TransactOperation(
                {
                    'Update': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        UpdateExpression=update_expr,
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def delete(
        self,
        *,
        key: dict,
        table_name: str | None = None,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a Delete operation for the item identified by *key*.

        Parameters
        ----------
        key: dict
            Primary key of the item to delete.
        table_name: str, optional
            Overrides the writer's default table.
        cond_expr: str, optional
            ConditionExpression guarding the delete.
        expr_attr_names: dict, optional
            ExpressionAttributeNames mapping.
        expr_attr_values: dict, optional
            ExpressionAttributeValues mapping (serialized to wire format).
        exc_cls: Type[Exception], optional
            Exception to raise if this operation cancels the transaction.
        """
        attrs: dict = {}

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        self._add_op_and_process(
            TransactOperation(
                {
                    'Delete': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        **attrs,
                    )
                },
                exc_cls,
            ),
        )

    def condition(
        self,
        *,
        key: dict,
        cond_expr: str,
        table_name: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        exc_cls: Type[Exception] | None = None,
    ) -> None:
        """Queue a ConditionCheck on the item identified by *key*.

        The transaction fails unless *cond_expr* holds for that item.

        Parameters
        ----------
        key: dict
            Primary key of the item to check.
        cond_expr: str
            ConditionExpression that must hold.
        table_name: str, optional
            Overrides the writer's default table.
        expr_attr_names: dict, optional
            ExpressionAttributeNames mapping.
        expr_attr_values: dict, optional
            ExpressionAttributeValues mapping (serialized to wire format).
        exc_cls: Type[Exception], optional
            Exception to raise if this operation cancels the transaction.
        """
        attrs: dict = {'ConditionExpression': cond_expr}

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        self._add_op_and_process(
            TransactOperation(
                {
                    'ConditionCheck': dict(
                        TableName=table_name or self._table_name,
                        Key=serialize(key),
                        **attrs,
                    )
                },
                exc_cls,
            )
        )

    def _add_op_and_process(self, op: TransactOperation) -> None:
        # Buffer the operation and flush once the threshold is reached.
        self._items_buffer.append(op)
        self._flush_if_needed()

    def _flush_if_needed(self) -> None:
        if len(self._items_buffer) >= self._flush_amount:
            self._flush()

    def _flush(self) -> bool:
        """Send up to ``flush_amount`` buffered operations in one transaction.

        Returns
        -------
        bool
            True on success.

        Raises
        ------
        TransactionCanceledException
            When the transaction is cancelled and none of the failing
            operations registered its own ``exc_cls``; carries per-operation
            cancellation reasons.
        """
        items_to_send = self._items_buffer[: self._flush_amount]
        self._items_buffer = self._items_buffer[self._flush_amount :]

        try:
            self._client.transact_write_items(
                TransactItems=[item.operation for item in items_to_send]  # type: ignore
            )
        except ClientError as err:
            error_msg = glom(err, 'response.Error.Message', default='')
            cancellations = err.response.get('CancellationReasons', [])
            reasons = []

            # CancellationReasons is index-aligned with the submitted
            # operations; entries without a 'Message' did not fail.
            for idx, reason in enumerate(cancellations):
                if 'Message' not in reason:
                    continue

                item = items_to_send[idx]

                # A failing operation registered with its own exception
                # class takes precedence over the aggregate exception.
                if item.exc_cls:
                    raise item.exc_cls(error_msg)

                reasons.append(
                    {
                        'code': reason.get('Code'),
                        'message': reason.get('Message'),
                        'operation': item.operation,
                    }
                )

            raise TransactionCanceledException(error_msg, reasons)
        else:
            return True
|
|
|
|
|
|
class DynamoDBPersistenceLayer:
    """Thin wrapper around the low-level DynamoDB client.

    Handles (de)serialization to/from the DynamoDB wire format and supplies
    a default table name for every operation.
    """

    def __init__(self, table_name: str, client: DynamoDBClient) -> None:
        self.table_name = table_name
        self.client = client

    def query(
        self,
        *,
        key_cond_expr: str,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: dict | None = None,
        filter_expr: str | None = None,
        projection_expr: str | None = None,
        limit: int | None = None,
        index_forward: bool = True,
        table_name: str | None = None,
    ) -> dict[str, Any]:
        """You must provide the name of the partition key attribute
        and a single value for that attribute.

        Query returns all items with that partition key value.
        Optionally, you can provide a sort key attribute and use a comparison
        operator to refine the search results.

        A `Query` operation always returns a result set. If no matching items
        are found, the result set will be empty.
        Queries that do not return results consume the minimum number
        of read capacity units for that type of read operation.

        - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html

        Returns
        -------
        dict[str, Any]
            ``items``: deserialized items; ``last_key``: raw
            LastEvaluatedKey dict for pagination, or None.
        """
        # NOTE: defaults are None (not mutable {}) to avoid the shared
        # mutable-default-argument pitfall; falsy checks below cover both.
        attrs: dict = {
            'TableName': table_name or self.table_name,
            'KeyConditionExpression': key_cond_expr,
            'ScanIndexForward': index_forward,
        }

        if expr_attr_name:
            attrs['ExpressionAttributeNames'] = expr_attr_name

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        if start_key:
            attrs['ExclusiveStartKey'] = start_key

        if filter_expr:
            attrs['FilterExpression'] = filter_expr

        if projection_expr:
            attrs['ProjectionExpression'] = projection_expr

        if limit:
            attrs['Limit'] = limit

        # ClientError propagates unchanged (the previous try/raise was a no-op).
        r = self.client.query(**attrs)

        return dict(
            items=[deserialize(v) for v in r.get('Items', [])],
            last_key=r.get('LastEvaluatedKey', None),
        )

    def get_item(self, key: dict, table_name: str | None = None) -> dict:
        """The GetItem operation returns a set of attributes for the item
        with the given primary key.

        If there is no matching item, GetItem does not return any data and
        there will be no Item element in the response; an empty dict is
        returned in that case.
        """
        attrs = {
            'TableName': table_name or self.table_name,
            'Key': serialize(key),
        }

        try:
            r = self.client.get_item(**attrs)
        except ClientError:
            # Log the request we built for easier debugging, then re-raise
            # with the original traceback.
            logger.info(attrs)
            raise
        else:
            return deserialize(r.get('Item', {}))

    def put_item(
        self,
        item: dict,
        *,
        cond_expr: str | None = None,
        table_name: str | None = None,
    ) -> bool:
        """Create a new item or replace an existing one.

        Parameters
        ----------
        item: dict
            The full item to write (serialized to wire format).
        cond_expr: str, optional
            ConditionExpression guarding the put.
        table_name: str, optional
            Overrides the default table.

        Returns
        -------
        bool
            True on success; ClientError propagates on failure.
        """
        attrs = {
            'TableName': table_name or self.table_name,
            'Item': serialize(item),
        }

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        try:
            self.client.put_item(**attrs)
        except ClientError:
            logger.info(attrs)
            raise
        else:
            return True

    def update_item(
        self,
        key: dict,
        *,
        update_expr: str,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        table_name: str | None = None,
    ) -> bool:
        """Edit an existing item's attributes (or create it, per DynamoDB
        UpdateItem semantics) using *update_expr*.

        Returns
        -------
        bool
            True on success; ClientError propagates on failure.
        """
        attrs: dict = {
            'TableName': table_name or self.table_name,
            'Key': serialize(key),
            'UpdateExpression': update_expr,
        }

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        try:
            self.client.update_item(**attrs)
        except ClientError:
            logger.info(attrs)
            raise
        else:
            return True

    def delete_item(
        self,
        key: dict,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
        table_name: str | None = None,
    ) -> bool:
        """Deletes a single item in a table by primary key. You can perform
        a conditional delete operation that deletes the item if it exists,
        or if it has an expected attribute value.

        Returns
        -------
        bool
            True on success; ClientError propagates on failure.
        """
        attrs: dict = {
            'TableName': table_name or self.table_name,
            'Key': serialize(key),
        }

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        try:
            self.client.delete_item(**attrs)
        except ClientError:
            logger.info(attrs)
            raise
        else:
            return True

    @property
    def collection(self) -> 'DynamoDBCollection':
        """High-level collection facade bound to this persistence layer."""
        return DynamoDBCollection(self)

    def transact_writer(
        self,
        flush_amount: int = 50,
        table_name: str | None = None,
    ) -> TransactWriter:
        """Create a buffered transact-write context manager."""
        return TransactWriter(
            table_name=table_name or self.table_name,
            client=self.client,
            flush_amount=flush_amount,
        )

    def batch_writer(
        self,
        table_name: str | None = None,
        overwrite_by_pkeys: list[str] | None = None,
    ) -> BatchWriter:
        """Create a batch writer object.

        This method creates a context manager for writing
        objects to Amazon DynamoDB in batch.

        The batch writer will automatically handle buffering and sending items
        in batches. In addition, the batch writer will also automatically
        handle any unprocessed items and resend them as needed. All you need
        to do is call ``put_item`` for any items you want to add, and
        ``delete_item`` for any items you want to delete.

        Example usage::

            with table.batch_writer() as batch:
                for _ in range(1000000):
                    batch.put_item(Item={'HashKey': '...',
                                         'Otherstuff': '...'})
                # You can also delete_items in a batch.
                batch.delete_item(Key={'HashKey': 'SomeHashKey'})

        See Also
        --------
        DynamoDB.Table.batch_writer https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/table/batch_writer.html#DynamoDB.Table.batch_writer
        """
        # None default (instead of a shared mutable []) — normalized here to
        # preserve the original behavior of passing a list through.
        return BatchWriter(
            table_name=table_name or self.table_name,
            client=self.client,
            overwrite_by_pkeys=overwrite_by_pkeys or [],
        )
|
|
|
|
|
|
class MissingError(Exception):
    """Default exception raised when a requested item does not exist."""
|
|
|
|
|
|
class PaginatedResult(TypedDict):
    """One page of query results."""

    # Deserialized items for this page.
    items: list[dict]
    # Opaque, URL-safe pagination token for the next page, or None when
    # there are no further pages.
    last_key: str | None
|
|
|
|
|
|
class DynamoDBCollection:
|
|
"""
|
|
DynamoDBCollection provides a high-level abstraction for performing common
|
|
CRUD operations and queries on a DynamoDB table.
|
|
|
|
It leverages an underlying persistence layer to handle
|
|
serialization and deserialization of data, key composition, transaction operations,
|
|
and TTL management.
|
|
|
|
This collection class simplifies interaction with DynamoDB items, allowing users to:
|
|
- Retrieve a single item or multiple items via transactions.
|
|
- Insert (put) items with optional TTL (time-to-live) settings.
|
|
- Delete items based on keys and conditions.
|
|
- Query items using partition keys or composite key pairs with
|
|
optional filtering and pagination.
|
|
|
|
Parameters
|
|
----------
|
|
persistence_layer: DynamoDBPersistenceLayer
|
|
The persistence layer instance responsible for direct DynamoDB operations.
|
|
exc_cls: Type[Exception], optional
|
|
The exception class to be raised when a requested item is not found.
|
|
tz: str, optional
|
|
The timezone identifier used for date/time operations.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
persistence_layer: DynamoDBPersistenceLayer,
|
|
/,
|
|
exc_cls: Type[Exception] = MissingError,
|
|
tz: str = TZ,
|
|
) -> None:
|
|
self.persistence_layer = persistence_layer
|
|
self.exc_cls = exc_cls
|
|
self.tz = tz
|
|
|
|
def get_item(
|
|
self,
|
|
key: Key,
|
|
/,
|
|
raise_on_error: bool = True,
|
|
exc_cls: Type[Exception] | None = None,
|
|
default: Any = None,
|
|
) -> Any:
|
|
"""Get an item with the given key.
|
|
|
|
Example
|
|
-------
|
|
**Get an item using a composed sort key**
|
|
|
|
collect = DynamoDBCollection(...)
|
|
collect.get_item(
|
|
KeyPair(
|
|
'b3511b5a-cb32-4833-a373-f8223f2088d4',
|
|
ComposeKey('username@domain.com', prefix='emails'),
|
|
),
|
|
)
|
|
|
|
Parameters
|
|
----------
|
|
key: Key
|
|
Key of the item to be retrieved.
|
|
raise_on_error: bool, optional
|
|
If True, raises an exception when the item is not found.
|
|
exc_cls: Type[Exception], optional
|
|
Exception class to be used if the item is not found.
|
|
default: Any, optional
|
|
Default value returned if the item is not found.
|
|
|
|
Returns
|
|
-------
|
|
Any
|
|
Data of the retrieved item or the default value if not found.
|
|
|
|
Raises
|
|
------
|
|
Exception
|
|
Raises the provided exception if the item is not found
|
|
and raise_on_error is True.
|
|
"""
|
|
table_name = getattr(key, 'table_name', None)
|
|
path_spec = getattr(key[SK], 'path_spec', None)
|
|
r = self.persistence_layer.get_item(key, table_name)
|
|
|
|
if raise_on_error and not r:
|
|
exc_cls = exc_cls or self.exc_cls
|
|
raise exc_cls(f'Item with {key} not found.')
|
|
|
|
if path_spec and r:
|
|
from glom import glom
|
|
|
|
return glom(r, path_spec, default=default)
|
|
|
|
return r or default
|
|
|
|
def put_item(
|
|
self,
|
|
key: Key,
|
|
*,
|
|
ttl: int | datetime | None = None,
|
|
cond_expr: str | None = None,
|
|
**kwargs: Any,
|
|
) -> bool:
|
|
"""Creates a new item, or replaces an old item with a new item.
|
|
|
|
Parameters
|
|
----------
|
|
key: Key
|
|
Key for the item to be inserted or updated.
|
|
ttl: int or datetime, optional
|
|
Time-to-live for the item, specified as a timestamp integer
|
|
or datetime object.
|
|
cond_expr: str, optional
|
|
Conditional expression for deletion.
|
|
**kwargs
|
|
Additional data to be stored with the item.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if the operation is successful, False otherwise.
|
|
"""
|
|
table_name = getattr(key, 'table_name', None)
|
|
|
|
if isinstance(ttl, int):
|
|
kwargs.update({'ttl': ttl})
|
|
|
|
if isinstance(ttl, datetime):
|
|
kwargs.update({'ttl': timestamp(ttl)})
|
|
|
|
return self.persistence_layer.put_item(
|
|
item=key | kwargs,
|
|
cond_expr=cond_expr,
|
|
table_name=table_name,
|
|
)
|
|
|
|
def delete_item(
|
|
self,
|
|
key: Key,
|
|
*,
|
|
cond_expr: str | None = None,
|
|
expr_attr_names: dict | None = None,
|
|
expr_attr_values: dict | None = None,
|
|
) -> bool:
|
|
"""Deletes a single item in a table by key.
|
|
|
|
Parameters
|
|
----------
|
|
key: Key
|
|
Key of the item to be deleted.
|
|
cond_expr: str, optional
|
|
Conditional expression for deletion.
|
|
expr_attr_names: dict, optional
|
|
Mapping of attribute names for the expression.
|
|
expr_attr_values: dict, optional
|
|
Mapping of attribute values for the expression.
|
|
|
|
Returns
|
|
-------
|
|
bool
|
|
True if the item is successfully deleted, False otherwise.
|
|
"""
|
|
table_name = getattr(key, 'table_name', None)
|
|
|
|
return self.persistence_layer.delete_item(
|
|
key=key,
|
|
cond_expr=cond_expr,
|
|
expr_attr_names=expr_attr_names,
|
|
expr_attr_values=expr_attr_values,
|
|
table_name=table_name,
|
|
)
|
|
|
|
def get_items(
|
|
self,
|
|
key: TransactKey | KeyChain,
|
|
flatten_top: bool = True,
|
|
) -> dict[str, Any]:
|
|
"""Get multiple items via a transaction based on the provided TransactKey.
|
|
|
|
Example
|
|
-------
|
|
**Get items using chained sort keys**
|
|
|
|
key = (
|
|
TransactKey('b3511b5a-cb32-4833-a373-f8223f2088d4')
|
|
+ SortKey('sk-1')
|
|
+ SortKey('sk-2')
|
|
)
|
|
collect = DynamoDBCollection(...)
|
|
items = collect.get_items(key)
|
|
|
|
**Get items using chained key pairs**
|
|
|
|
key = (
|
|
KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', '0')
|
|
+ KeyPair('cpf', '07879819908')
|
|
+ KeyPair('email', 'user@example.com')
|
|
)
|
|
collect = DynamoDBCollection(...)
|
|
items = collect.get_items(key)
|
|
|
|
Parameters
|
|
----------
|
|
key: TransactKey or KeyChain
|
|
A `TransactKey` is used when you want to define a partition key (`pk`)
|
|
and append multiple `SortKey` instances using the `+` operator. Each
|
|
`SortKey` is internally paired with the `pk` to form a `KeyPair`.
|
|
|
|
Alternatively, a `KeyChain` can be created by chaining two or more `KeyPair`
|
|
instances using the `+` operator. For example:
|
|
`KeyPair(pk1, sk1) + KeyPair(pk2, sk2)` creates a `KeyChain` with two pairs.
|
|
Further additions to the chain extend it.
|
|
|
|
If no sort keys (in `TransactKey`) or no key pairs (in `KeyChain`)
|
|
are provided, the operation is skipped.
|
|
|
|
flatten_top: bool, optional
|
|
Determines whether the first nested item in the transaction result
|
|
should be flattened,
|
|
i.e., extracted to serve as the primary item at the top level of
|
|
the returned dict.
|
|
If True, the nested item is promoted to the top level.
|
|
|
|
Returns
|
|
-------
|
|
dict[str, Any]
|
|
A dict of items retrieved from the transaction.
|
|
"""
|
|
|
|
# If no sort key is provided, the query is skipped
|
|
if not key.pairs:
|
|
return {}
|
|
|
|
pairs = key.pairs[1:] if flatten_top else key.pairs
|
|
client = self.persistence_layer.client
|
|
table_name = self.persistence_layer.table_name
|
|
|
|
# Add a get operation for each key for the transaction
|
|
transact_items = [
|
|
{
|
|
'Get': {
|
|
'TableName': pair.table_name or table_name,
|
|
'Key': serialize(pair),
|
|
}
|
|
}
|
|
for pair in key.pairs
|
|
]
|
|
|
|
r = client.transact_get_items(TransactItems=transact_items) # type: ignore
|
|
items = [deserialize(r.get('Item', {})) for r in r.get('Responses', [])]
|
|
|
|
if flatten_top:
|
|
head, *tail = items
|
|
else:
|
|
head, tail = {}, items
|
|
|
|
def _extract_sk_values(pair: KeyPair, obj: dict) -> dict:
|
|
obj = omit((PK, SK), obj)
|
|
sk = pair[SK]
|
|
path_spec = getattr(sk, 'path_spec', None)
|
|
|
|
if path_spec:
|
|
from glom import glom
|
|
|
|
return glom(obj, path_spec)
|
|
return obj
|
|
|
|
def _map_key(pair: KeyPair) -> str:
|
|
pk = pair[PK]
|
|
sk = pair[SK]
|
|
|
|
if pair.rename_key:
|
|
return pair.rename_key
|
|
|
|
if getattr(sk, 'rename_key', None):
|
|
return sk.rename_key
|
|
|
|
if isinstance(sk, SortKey):
|
|
warnings.warn(
|
|
'SortKey.remove_prefix() is deprecated and will be removed '
|
|
'in the future.',
|
|
DeprecationWarning,
|
|
stacklevel=2,
|
|
)
|
|
return sk.removeprefix(sk.remove_prefix or '')
|
|
|
|
return pk
|
|
|
|
return head | {
|
|
_map_key(pair): _extract_sk_values(pair, obj)
|
|
for pair, obj in zip(pairs, tail)
|
|
if obj
|
|
}
|
|
|
|
def query(
|
|
self,
|
|
key: PartitionKey | KeyPair,
|
|
*,
|
|
expr_attr_name: dict = {},
|
|
expr_attr_values: dict = {},
|
|
start_key: str | None = None,
|
|
filter_expr: str | None = None,
|
|
projection_expr: str | None = None,
|
|
index_forward: bool = False,
|
|
limit: int = LIMIT,
|
|
table_name: str | None = None,
|
|
) -> PaginatedResult:
|
|
"""Query returns all items with that partition key or key pair.
|
|
|
|
Example
|
|
-------
|
|
**Query using a composed partition key**
|
|
|
|
collect = DynamoDBCollection(...)
|
|
collect.query(
|
|
PartitionKey(
|
|
ComposeKey('b3511b5a-cb32-4833-a373-f8223f2088d4', prefix='logs')
|
|
),
|
|
)
|
|
|
|
**Query using a key pair**
|
|
|
|
collect = DynamoDBCollection(...)
|
|
collect.query(
|
|
KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', 'emails'),
|
|
)
|
|
|
|
Parameters
|
|
----------
|
|
key: PartitionKey or KeyPair
|
|
Partition key or Key pair used for the query.
|
|
expr_attr_name: dict, optional
|
|
Additional mapping for attribute names.
|
|
expr_attr_values: dict, optional
|
|
Additional mapping for attribute values.
|
|
start_key: str, optional
|
|
Starting key for pagination.
|
|
filter_expr: str, optional
|
|
Filter expression for the query.
|
|
index_forward: bool, optional
|
|
Order of the results; True for ascending order.
|
|
limit: int, optional
|
|
Maximum number of items to return.
|
|
|
|
Returns
|
|
-------
|
|
PaginatedResult
|
|
Dict containing the queried items and the key for the next batch.
|
|
|
|
See Also
|
|
--------
|
|
DynamoDB.Client.query: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
|
|
"""
|
|
key_cond_expr = (
|
|
'#pk = :pk AND begins_with(#sk, :sk)'
|
|
if isinstance(key, KeyPair)
|
|
else '#pk = :pk'
|
|
)
|
|
|
|
r = self.persistence_layer.query(
|
|
key_cond_expr=key_cond_expr,
|
|
expr_attr_name=key.expr_attr_name() | expr_attr_name,
|
|
expr_attr_values=key.expr_attr_values() | expr_attr_values,
|
|
filter_expr=filter_expr,
|
|
projection_expr=projection_expr,
|
|
index_forward=index_forward,
|
|
limit=limit,
|
|
start_key=_startkey_b64decode(start_key) if start_key else {},
|
|
table_name=table_name,
|
|
)
|
|
|
|
items = r['items']
|
|
last_key = _startkey_b64encode(r['last_key']) if r['last_key'] else None
|
|
|
|
def _removeprefix(
|
|
items: list[dict[str, Any]], /, key: str, prefix: str
|
|
) -> list[dict[str, Any]]:
|
|
return [x | {key: x[key].removeprefix(prefix)} for x in items]
|
|
|
|
match key.get(PK), key.get(SK):
|
|
case ComposeKey(), _: # Remove prefix from Partition Key
|
|
items = _removeprefix(items, PK, key[PK].prefix + key[PK].delimiter)
|
|
case _, PrefixKey(): # Remove prefix from Sort Key
|
|
items = _removeprefix(items, SK, key[SK])
|
|
|
|
return {
|
|
'items': items,
|
|
'last_key': last_key,
|
|
}
|
|
|
|
|
|
def _startkey_b64encode(obj: dict) -> str:
|
|
s = json.dumps(obj)
|
|
b = urlsafe_b64encode(s.encode('utf-8')).decode('utf-8')
|
|
return quote(b)
|
|
|
|
|
|
def _startkey_b64decode(s: str) -> dict:
|
|
b = unquote(s).encode('utf-8')
|
|
s = urlsafe_b64decode(b).decode('utf-8')
|
|
return json.loads(s)
|