1049 lines
30 KiB
Python
1049 lines
30 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from ipaddress import IPv4Address
|
|
from typing import TYPE_CHECKING, Any, Self, Type, TypedDict
|
|
from urllib.parse import quote, unquote
|
|
from uuid import UUID
|
|
|
|
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
|
|
from botocore.exceptions import ClientError
|
|
from glom import glom
|
|
|
|
from .dateutils import now, timestamp
|
|
from .funcs import omit, pick
|
|
|
|
# Runtime configuration, read once from the environment at import time.
TZ = os.environ.get('TZ', 'UTC')
PK = os.environ.get('DYNAMODB_PARTITION_KEY', 'pk')  # partition key attribute name
SK = os.environ.get('DYNAMODB_SORT_KEY', 'sk')  # sort key attribute name
LIMIT = int(os.environ.get('DYNAMODB_LIMIT', 25))  # default page size for queries
DELIMITER = os.environ.get('DYNAMODB_DELIMITER', '#')
LOG_LEVEL = os.environ.get('LOG_LEVEL', 'INFO').upper()

# LOG_LEVEL must name a level constant of the logging module (e.g. INFO, DEBUG).
logging.basicConfig(level=getattr(logging, LOG_LEVEL))

logger = logging.getLogger(__name__)

# Shared converters between Python values and the DynamoDB wire format.
serializer = TypeSerializer()
deserializer = TypeDeserializer()
|
|
|
|
|
|
def _serialize_to_basic_types(data: Any) -> str | dict | list:
|
|
match data:
|
|
case datetime():
|
|
return data.isoformat()
|
|
case UUID():
|
|
return str(data)
|
|
case IPv4Address():
|
|
return str(data)
|
|
case tuple():
|
|
return [_serialize_to_basic_types(v) for v in data]
|
|
case dict():
|
|
return {k: _serialize_to_basic_types(v) for k, v in data.items()}
|
|
case _:
|
|
return data
|
|
|
|
|
|
def serialize(data: dict) -> dict:
    """Convert a plain dict into DynamoDB attribute-value format.

    Each value is first reduced to basic types and then passed through the
    module-level ``serializer``; keys are kept as they are.
    """
    encoded: dict = {}
    for name, value in data.items():
        encoded[name] = serializer.serialize(_serialize_to_basic_types(value))
    return encoded
|
|
|
|
|
|
def deserialize(data: dict) -> dict:
    """Convert a dict of DynamoDB attribute values back into Python values.

    Each value is passed through the module-level ``deserializer``; keys are
    kept as they are.
    """
    decoded: dict = {}
    for name, value in data.items():
        decoded[name] = deserializer.deserialize(value)
    return decoded
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class ComposeKey(str):
|
|
keyparts: str | tuple[str, ...]
|
|
prefix: str | None = None
|
|
delimiter: str = '#'
|
|
else:
|
|
|
|
class ComposeKey(str):
|
|
"""Creates a composite key by joining string parts with a specified delimiter.
|
|
If a prefix is provided, it is added at the beginning of the key parts.
|
|
|
|
Example
|
|
-------
|
|
ComposeKey(('abc', 'xyz'), prefix='examples', delimiter='#')
|
|
"""
|
|
|
|
def __new__(
|
|
cls,
|
|
keyparts: str | tuple[str, ...],
|
|
/,
|
|
prefix: str | None = None,
|
|
*,
|
|
delimiter: str = '#',
|
|
) -> str:
|
|
if isinstance(keyparts, str):
|
|
keyparts = (keyparts,)
|
|
|
|
if prefix:
|
|
keyparts = (prefix,) + keyparts
|
|
|
|
return super().__new__(cls, delimiter.join(keyparts))
|
|
|
|
def __init__(
|
|
self,
|
|
keyparts: str | tuple[str, ...],
|
|
/,
|
|
prefix: str | None = None,
|
|
*,
|
|
delimiter: str = '#',
|
|
) -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.keyparts = keyparts
|
|
self.prefix = prefix
|
|
self.delimiter = delimiter
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class PrefixKey(str):
|
|
prefix: str
|
|
delimiter: str | None = '#'
|
|
|
|
else:
|
|
|
|
class PrefixKey(str):
|
|
def __new__(cls, prefix: str, delimiter: str | None = '#') -> str:
|
|
if not delimiter:
|
|
return super().__new__(cls, prefix)
|
|
|
|
return super().__new__(cls, prefix + delimiter)
|
|
|
|
def __init__(self, prefix: str, delimiter: str = '#') -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.prefix = prefix
|
|
self.delimiter = delimiter
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@dataclass
|
|
class SortKey(str):
|
|
"""
|
|
SortKey encapsulates the sort key value and optionally stores additional attributes
|
|
for nested data extraction.
|
|
|
|
Parameters
|
|
----------
|
|
sk: str
|
|
The sort key value.
|
|
table_name: str, optional
|
|
Optional name of the table associated with the sort key.
|
|
path_spec: str, optional
|
|
Optional specification for nested data extraction.
|
|
"""
|
|
|
|
sk: str
|
|
table_name: str | None = None
|
|
path_spec: str | None = None
|
|
remove_prefix: str | None = None
|
|
else:
|
|
|
|
class SortKey(str):
|
|
"""
|
|
SortKey encapsulates the sort key value and optionally stores additional attributes
|
|
for nested data extraction.
|
|
|
|
Parameters
|
|
----------
|
|
sk: str
|
|
The sort key value.
|
|
table_name: str, optional
|
|
Optional name of the table associated with the sort key.
|
|
path_spec: str, optional
|
|
Optional specification for nested data extraction.
|
|
"""
|
|
|
|
def __new__(
|
|
cls,
|
|
sk: str,
|
|
*,
|
|
table_name: str | None = None,
|
|
path_spec: str | None = None,
|
|
remove_prefix: str | None = None,
|
|
) -> str:
|
|
return super().__new__(cls, sk)
|
|
|
|
def __init__(
|
|
self,
|
|
sk: str,
|
|
*,
|
|
table_name: str | None = None,
|
|
path_spec: str | None = None,
|
|
remove_prefix: str | None = None,
|
|
) -> None:
|
|
# __init__ is used to store the parameters for later reference.
|
|
# For immutable types like str, __init__ cannot change the instance's value.
|
|
self.sk = sk
|
|
self.table_name = table_name
|
|
self.path_spec = path_spec
|
|
self.remove_prefix = remove_prefix
|
|
|
|
|
|
@dataclass
class TransactKey:
    """
    A partition key together with the sort keys to look up under it.

    Sort keys are attached with ``+``; each addition returns a new
    TransactKey and leaves the original untouched.

    Example
    -------
        TransactKey('e9bb7dc6-c7b2-4d34-8931-d298353758ec')
        + SortKey('0')
        + SortKey('tenant')
    """

    pk: str
    sk: tuple[SortKey, ...] = ()

    def __add__(self, sk: SortKey) -> 'TransactKey':
        """Return a new TransactKey with ``sk`` appended to the sort keys."""
        if isinstance(sk, SortKey):
            extended = self.sk + (sk,)
            return TransactKey(pk=self.pk, sk=extended)

        raise TypeError('Can only add a SortKey to a TransactKey')
|
|
|
|
|
|
class Key(ABC, dict):
    """Base class for dict-backed keys that can describe themselves for expressions."""

    @abstractmethod
    def expr_attr_name(self) -> dict:
        """Return the mapping of expression name placeholders for this key."""

    @abstractmethod
    def expr_attr_values(self) -> dict:
        """Return the mapping of expression value placeholders for this key."""
|
|
|
|
|
|
class PartitionKey(Key):
    """A key holding only the partition key attribute, for DynamoDB queries."""

    def __init__(self, pk: str) -> None:
        # Stored under the configured partition key attribute name.
        super().__init__(**{PK: pk})

    def __repr__(self) -> str:
        return f'PartitionKey({self[PK]!r})'

    def expr_attr_name(self) -> dict:
        """Map the ``#pk`` placeholder to the partition key attribute name."""
        names = {'#pk': PK}
        return names

    def expr_attr_values(self) -> dict:
        """Map the ``:pk`` placeholder to this key's partition key value."""
        values = {':pk': self[PK]}
        return values
|
|
|
|
|
|
class KeyPair(Key):
    """A key holding both the partition key and the sort key, for DynamoDB queries."""

    def __init__(self, pk: str, sk: str) -> None:
        # Stored under the configured partition/sort key attribute names.
        super().__init__(**{PK: pk, SK: sk})

    def __repr__(self) -> str:
        pk, sk = self.values()
        return f'KeyPair({pk!r}, {sk!r})'

    def expr_attr_name(self) -> dict:
        """Map the ``#pk`` / ``#sk`` placeholders to the key attribute names."""
        return {'#pk': PK, '#sk': SK}

    def expr_attr_values(self) -> dict:
        """Map the ``:pk`` / ``:sk`` placeholders to this key's values."""
        return {':pk': self[PK], ':sk': self[SK]}

    @classmethod
    def parse_obj(cls, obj: Any) -> Self | None:
        """Build a KeyPair from a dict or from an iterable of ``(pk, sk)``.

        Returns ``None`` for any falsy input. A dict is reduced to its
        partition and sort key entries via ``pick``; anything else is
        unpacked directly into the constructor.
        """
        if not obj:
            return None

        if isinstance(obj, dict):
            # Relies on pick returning the values in (PK, SK) order.
            pair = pick((PK, SK), obj).values()
        else:
            pair = obj

        return cls(*pair)
|
|
|
|
|
|
class TransactionCanceledReason(TypedDict):
    """One entry describing why an operation of a transaction was canceled."""

    # Reason code reported for the operation.
    code: str
    # Human-readable message reported for the operation.
    message: str
    # The transaction operation the reason refers to.
    operation: dict
|
|
|
|
|
|
class TransactionCanceledException(Exception):
|
|
def __init__(self, msg: str = '', reasons: list = []) -> None:
|
|
super().__init__(msg)
|
|
self.msg = msg
|
|
self.reasons: list[TransactionCanceledReason] = reasons
|
|
|
|
|
|
class TransactOperation:
|
|
def __init__(
|
|
self,
|
|
op: dict,
|
|
exc_cls: type[Exception] | None = None,
|
|
) -> None:
|
|
self.op = op
|
|
self.exc_cls = exc_cls
|
|
|
|
|
|
class TransactItems:
|
|
def __init__(self, table_name: str) -> None:
|
|
self.table_name = table_name
|
|
self.items: list[TransactOperation] = []
|
|
|
|
def put(
|
|
self,
|
|
*,
|
|
item: dict,
|
|
table_name: str | None = None,
|
|
cond_expr: str | None = None,
|
|
exc_cls: Type[Exception] | None = None,
|
|
) -> None:
|
|
attrs: dict = {}
|
|
|
|
if cond_expr:
|
|
attrs['ConditionExpression'] = cond_expr
|
|
|
|
if not table_name:
|
|
table_name = self.table_name
|
|
|
|
self.items.append(
|
|
TransactOperation(
|
|
{
|
|
'Put': dict(
|
|
TableName=table_name,
|
|
Item=serialize(item),
|
|
**attrs,
|
|
)
|
|
},
|
|
exc_cls,
|
|
),
|
|
)
|
|
|
|
def update(
|
|
self,
|
|
*,
|
|
key: dict,
|
|
update_expr: str,
|
|
cond_expr: str | None = None,
|
|
table_name: str | None = None,
|
|
expr_attr_names: dict = {},
|
|
expr_attr_values: dict = {},
|
|
exc_cls: Type[Exception] | None = None,
|
|
) -> None:
|
|
attrs: dict = {}
|
|
|
|
if cond_expr:
|
|
attrs['ConditionExpression'] = cond_expr
|
|
|
|
if expr_attr_names:
|
|
attrs['ExpressionAttributeNames'] = expr_attr_names
|
|
|
|
if expr_attr_values:
|
|
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
|
|
|
if not table_name:
|
|
table_name = self.table_name
|
|
|
|
self.items.append(
|
|
TransactOperation(
|
|
{
|
|
'Update': dict(
|
|
TableName=table_name,
|
|
Key=serialize(key),
|
|
UpdateExpression=update_expr,
|
|
**attrs,
|
|
)
|
|
},
|
|
exc_cls,
|
|
)
|
|
)
|
|
|
|
def get(
|
|
self,
|
|
*,
|
|
table_name: str | None = None,
|
|
key: dict,
|
|
expr_attr_names: str | None = None,
|
|
) -> None:
|
|
attrs: dict = {}
|
|
|
|
if expr_attr_names:
|
|
attrs['ExpressionAttributeNames'] = expr_attr_names
|
|
|
|
if not table_name:
|
|
table_name = self.table_name
|
|
|
|
self.items.append(
|
|
TransactOperation(
|
|
{
|
|
'Get': dict(
|
|
TableName=table_name,
|
|
Key=serialize(key),
|
|
**attrs,
|
|
)
|
|
}
|
|
),
|
|
)
|
|
|
|
def delete(
|
|
self,
|
|
*,
|
|
key: dict,
|
|
table_name: str | None = None,
|
|
cond_expr: str | None = None,
|
|
expr_attr_names: dict = {},
|
|
expr_attr_values: dict = {},
|
|
exc_cls: Type[Exception] | None = None,
|
|
) -> None:
|
|
attrs: dict = {}
|
|
|
|
if cond_expr:
|
|
attrs['ConditionExpression'] = cond_expr
|
|
|
|
if expr_attr_names:
|
|
attrs['ExpressionAttributeNames'] = expr_attr_names
|
|
|
|
if expr_attr_values:
|
|
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
|
|
|
if not table_name:
|
|
table_name = self.table_name
|
|
|
|
self.items.append(
|
|
TransactOperation(
|
|
{
|
|
'Delete': dict(
|
|
TableName=table_name,
|
|
Key=serialize(key),
|
|
**attrs,
|
|
)
|
|
},
|
|
exc_cls,
|
|
),
|
|
)
|
|
|
|
def condition(
|
|
self,
|
|
*,
|
|
key: dict,
|
|
cond_expr: str,
|
|
table_name: str | None = None,
|
|
expr_attr_names: dict = {},
|
|
expr_attr_values: dict = {},
|
|
exc_cls: Type[Exception] | None = None,
|
|
) -> None:
|
|
attrs: dict = {'ConditionExpression': cond_expr}
|
|
|
|
if expr_attr_names:
|
|
attrs['ExpressionAttributeNames'] = expr_attr_names
|
|
|
|
if expr_attr_values:
|
|
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
|
|
|
if not table_name:
|
|
table_name = self.table_name
|
|
|
|
self.items.append(
|
|
TransactOperation(
|
|
{
|
|
'ConditionCheck': dict(
|
|
TableName=table_name,
|
|
Key=serialize(key),
|
|
**attrs,
|
|
)
|
|
},
|
|
exc_cls,
|
|
)
|
|
)
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from mypy_boto3_dynamodb.client import DynamoDBClient
|
|
else:
|
|
DynamoDBClient = object
|
|
|
|
|
|
class DynamoDBPersistenceLayer:
    """Wrapper around a DynamoDB client bound to a single table.

    Items, keys and expression attribute values are converted with
    ``serialize`` before being sent, and returned items are converted back
    with ``deserialize``. A ``ClientError`` raised by the client is logged and
    re-raised, except in ``transact_write_items`` where it is translated (see
    that method).
    """

    def __init__(self, table_name: str, dynamodb_client: DynamoDBClient) -> None:
        self.table_name = table_name
        self.dynamodb_client = dynamodb_client

    def query(
        self,
        *,
        key_cond_expr: str,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: dict | None = None,
        filter_expr: str | None = None,
        limit: int | None = None,
        index_forward: bool = True,
    ) -> dict[str, Any]:
        """You must provide the name of the partition key attribute and a single value for that attribute.

        Query returns all items with that partition key value.
        Optionally, you can provide a sort key attribute and use a comparison operator to refine the search results.

        ...

        A `Query` operation always returns a result set. If no matching items are found, the result set will be empty.
        Queries that do not return results consume the minimum number of read capacity units for that type of read operation.

        Optional arguments are only sent when they are truthy. ``start_key``
        is passed to the client as given (it is not serialized here).

        Returns a dict with ``items`` (the deserialized items) and
        ``last_key`` (the response's ``LastEvaluatedKey`` or ``None``).

        - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
        """
        attrs: dict = {
            'TableName': self.table_name,
            'KeyConditionExpression': key_cond_expr,
            'ScanIndexForward': index_forward,
        }

        if expr_attr_name:
            attrs['ExpressionAttributeNames'] = expr_attr_name

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        if start_key:
            attrs['ExclusiveStartKey'] = start_key

        if filter_expr:
            attrs['FilterExpression'] = filter_expr

        if limit:
            attrs['Limit'] = limit

        try:
            response = self.dynamodb_client.query(**attrs)
        except ClientError as err:
            logger.info(attrs)
            logger.exception(err)
            raise
        else:
            return dict(
                items=[deserialize(v) for v in response.get('Items', [])],
                last_key=response.get('LastEvaluatedKey', None),
            )

    def get_item(self, key: dict) -> dict:
        """The GetItem operation returns a set of attributes for the item with the given primary key.
        If there is no matching item, GetItem does not return any data and there will be no Item element in the response.

        In that case an empty dict is returned.
        """
        attrs = {
            'TableName': self.table_name,
            'Key': serialize(key),
        }

        try:
            response = self.dynamodb_client.get_item(**attrs)
        except ClientError as err:
            logger.info(attrs)
            logger.exception(err)
            raise
        else:
            return deserialize(response.get('Item', {}))

    def put_item(self, item: dict, *, cond_expr: str | None = None) -> bool:
        """Put ``item`` into the table, optionally guarded by ``cond_expr``.

        Returns True on success; a ``ClientError`` is logged and re-raised.
        """
        attrs = {
            'TableName': self.table_name,
            'Item': serialize(item),
        }

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        try:
            self.dynamodb_client.put_item(**attrs)
        except ClientError as err:
            logger.info(attrs)
            logger.exception(err)
            raise
        else:
            return True

    def update_item(
        self,
        key: dict,
        *,
        update_expr: str,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Apply ``update_expr`` to the item identified by ``key``.

        Returns True on success; a ``ClientError`` is logged and re-raised.
        """
        attrs: dict = {
            'TableName': self.table_name,
            'Key': serialize(key),
            'UpdateExpression': update_expr,
        }

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        try:
            self.dynamodb_client.update_item(**attrs)
        except ClientError as err:
            logger.info(attrs)
            logger.exception(err)
            raise
        else:
            return True

    def delete_item(
        self,
        key: dict,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Deletes a single item in a table by primary key. You can perform a conditional delete operation that deletes
        the item if it exists, or if it has an expected attribute value.

        Returns True on success; a ``ClientError`` is logged and re-raised.
        """
        attrs: dict = {
            'TableName': self.table_name,
            'Key': serialize(key),
        }

        if cond_expr:
            attrs['ConditionExpression'] = cond_expr

        if expr_attr_names:
            attrs['ExpressionAttributeNames'] = expr_attr_names

        if expr_attr_values:
            attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)

        try:
            self.dynamodb_client.delete_item(**attrs)
        except ClientError as err:
            logger.info(attrs)
            logger.exception(err)
            raise
        else:
            return True

    def transact_get_items(self, transact_items: TransactItems) -> list[dict[str, Any]]:
        """Run the collected operations as one read transaction.

        Returns one deserialized item per entry of the response's
        ``Responses`` list; an entry without an ``Item`` yields an empty dict.
        """
        try:
            response = self.dynamodb_client.transact_get_items(
                TransactItems=[item.op for item in transact_items.items]  # type: ignore
            )
        except ClientError as err:
            logger.exception(err)
            raise
        else:
            # A distinct loop name avoids shadowing ``response``.
            return [
                deserialize(entry.get('Item', {}))
                for entry in response.get('Responses', [])
            ]

    def transact_write_items(self, transact_items: TransactItems) -> bool:
        """Run the collected operations as one write transaction.

        Returns True on success.

        Raises
        ------
        Exception
            The ``exc_cls`` of the first operation whose cancellation reason
            carries a message, when that operation has one.
        TransactionCanceledException
            Otherwise, with one reason per cancellation entry that carries a
            message.
        """
        try:
            self.dynamodb_client.transact_write_items(
                TransactItems=[item.op for item in transact_items.items]  # type: ignore
            )
        except ClientError as err:
            error_msg = glom(err, 'response.Error.Message', default='')
            cancellations = err.response.get('CancellationReasons', [])
            reasons = []

            # Cancellation reasons are matched to operations by position.
            for idx, reason in enumerate(cancellations):
                if 'Message' not in reason:
                    continue

                item = transact_items.items[idx]

                if item.exc_cls:
                    raise item.exc_cls(error_msg) from err

                reasons.append(
                    {
                        'code': reason.get('Code'),
                        'message': reason.get('Message'),
                        'operation': item.op,
                    }
                )

            raise TransactionCanceledException(error_msg, reasons) from err
        else:
            return True
|
|
|
|
|
|
class MissingError(Exception):
    """Default exception type for a missing item."""
|
|
|
|
|
|
class PaginatedResult(TypedDict):
|
|
items: list[dict]
|
|
last_key: str | None
|
|
|
|
|
|
class DynamoDBCollection:
    """
    DynamoDBCollection provides a high-level abstraction for performing common CRUD operations
    and queries on a DynamoDB table. It leverages an underlying persistence layer to handle
    serialization and deserialization of data, key composition, transaction operations, and TTL management.

    This collection class simplifies interaction with DynamoDB items, allowing users to:
    - Retrieve a single item or multiple items via transactions.
    - Insert (put) items with optional TTL (time-to-live) settings.
    - Delete items based on keys and conditions.
    - Query items using partition keys or composite key pairs with optional filtering and pagination.

    Parameters
    ----------
    persistence_layer: DynamoDBPersistenceLayer
        The persistence layer instance responsible for direct DynamoDB operations.
    exc_cls: Type[Exception], optional
        The exception class to be raised when a requested item is not found.
    tz: str, optional
        The timezone identifier used for date/time operations.
    """

    def __init__(
        self,
        persistence_layer: DynamoDBPersistenceLayer,
        /,
        exc_cls: Type[Exception] = MissingError,
        tz: str = TZ,
    ) -> None:
        self.persistence_layer = persistence_layer
        self.exc_cls = exc_cls
        self.tz = tz

    def get_item(
        self,
        key: Key,
        path_spec: str | None = None,
        /,
        raise_on_error: bool = True,
        exc_cls: Type[Exception] | None = None,
        default: Any = None,
    ) -> Any:
        """Get an item with the given key.

        Example
        -------
        **Get an item using a composed sort key**

            collect = DynamoDBCollection(...)
            collect.get_item(
                KeyPair(
                    'b3511b5a-cb32-4833-a373-f8223f2088d4',
                    ComposeKey('username@domain.com', prefix='emails'),
                ),
            )

        Parameters
        ----------
        key: Key
            Key of the item to be retrieved.
        path_spec: str, optional
            A path specification for nested data extraction.
        raise_on_error: bool, optional
            If True, raises an exception when the item is not found.
        exc_cls: Type[Exception], optional
            Exception class to be used if the item is not found.
            Defaults to the collection's ``exc_cls``.
        default: Any, optional
            Default value returned if the item (or the requested path) is not found.

        Returns
        -------
        Any
            Data of the retrieved item or the default value if not found.

        Raises
        ------
        Exception
            Raises the provided exception if the item is not found and raise_on_error is True.
        """
        exc_cls = exc_cls or self.exc_cls
        data = self.persistence_layer.get_item(key)

        if raise_on_error and not data:
            raise exc_cls(f'Item with {key} not found.')

        if path_spec and data:
            return glom(data, path_spec, default=default)

        return data or default

    def put_item(
        self,
        key: Key,
        *,
        ttl: int | datetime | None = None,
        **kwargs: Any,
    ) -> bool:
        """Creates a new item, or replaces an old item with a new item.

        When ``ttl`` is given, both ``ttl`` (a timestamp) and ``ttl_date``
        (a datetime) are stored with the item.

        Parameters
        ----------
        key: Key
            Key for the item to be inserted or updated.
        ttl: int or datetime, optional
            Time-to-live for the item, specified as a timestamp integer or datetime object.
        **kwargs
            Additional data to be stored with the item.

        Returns
        -------
        bool
            True if the operation is successful; errors from the persistence
            layer propagate as exceptions.
        """
        now_ = now(self.tz)

        if isinstance(ttl, int):
            kwargs.update(
                {
                    'ttl': ttl,
                    # Interpreted in the collection's timezone.
                    'ttl_date': datetime.fromtimestamp(ttl, now_.tzinfo),
                }
            )
        elif isinstance(ttl, datetime):
            kwargs.update(
                {
                    'ttl': timestamp(ttl),
                    'ttl_date': ttl,
                }
            )

        return self.persistence_layer.put_item(item=key | kwargs)

    def delete_item(
        self,
        key: Key,
        *,
        cond_expr: str | None = None,
        expr_attr_names: dict | None = None,
        expr_attr_values: dict | None = None,
    ) -> bool:
        """Deletes a single item in a table by key.

        Parameters
        ----------
        key: Key
            Key of the item to be deleted.
        cond_expr: str, optional
            Conditional expression for deletion.
        expr_attr_names: dict, optional
            Mapping of attribute names for the expression.
        expr_attr_values: dict, optional
            Mapping of attribute values for the expression.

        Returns
        -------
        bool
            True if the item is successfully deleted; errors from the
            persistence layer propagate as exceptions.
        """
        return self.persistence_layer.delete_item(
            key=key,
            cond_expr=cond_expr,
            expr_attr_names=expr_attr_names,
            expr_attr_values=expr_attr_values,
        )

    def get_items(
        self,
        key: TransactKey,
        flatten_top: bool = True,
    ) -> dict[str, Any]:
        """Get multiple items via a transaction based on the provided TransactKey.

        Example
        -------
        **Get items using chained sort keys**

            key = (
                TransactKey('b3511b5a-cb32-4833-a373-f8223f2088d4')
                + SortKey('sk-1') + SortKey('sk-2')
            )
            collect = DynamoDBCollection(...)
            items = collect.get_items(key)

        Parameters
        ----------
        key: TransactKey
            A TransactKey instance that contains a partition key and one or more sort keys.
            If no sort key is provided, the transaction is skipped.

        flatten_top: bool, optional
            Determines whether the first nested item in the transaction result should be flattened,
            i.e., extracted to serve as the primary item at the top level of the returned dict.
            If True, the nested item is promoted to the top level.

        Returns
        -------
        dict[str, Any]
            A dict of items retrieved from the transaction. Items other than
            the flattened one are stored under their sort key (minus its
            ``remove_prefix``, when set), without their key attributes, and
            reduced to ``path_spec`` when the sort key defines one. Empty
            results are left out.
        """
        # If no sort key is provided, the query is skipped
        if not key.sk:
            return {}

        transact = TransactItems(self.persistence_layer.table_name)

        # Add a get operation for each sort key for the transaction
        for sk in key.sk:
            transact.get(
                key=KeyPair(key.pk, sk),
                table_name=sk.table_name,
            )

        items = self.persistence_layer.transact_get_items(transact)

        if flatten_top:
            # The first result becomes the top level; the rest are nested.
            head, *tail = items
            sortkeys = key.sk[1:]
        else:
            head, tail = {}, items
            sortkeys = key.sk

        def _getin(sk: SortKey, v: dict) -> dict:
            v = omit((PK, SK), v)

            if sk.path_spec:
                return glom(v, sk.path_spec)
            return v

        def _result_key(sk: SortKey) -> str:
            return sk.removeprefix(sk.remove_prefix) if sk.remove_prefix else sk

        return head | {
            _result_key(k): _getin(k, item) for k, item in zip(sortkeys, tail) if item
        }

    def query(
        self,
        key: PartitionKey | KeyPair,
        *,
        expr_attr_name: dict | None = None,
        expr_attr_values: dict | None = None,
        start_key: str | None = None,
        filter_expr: str | None = None,
        index_forward: bool = False,
        limit: int = LIMIT,
    ) -> PaginatedResult:
        """Query returns all items with that partition key or key pair.

        With a KeyPair, the sort key is matched with ``begins_with``.

        Example
        -------
        **Query using a composed partition key**

            collect = DynamoDBCollection(...)
            collect.query(
                PartitionKey(
                    ComposeKey('b3511b5a-cb32-4833-a373-f8223f2088d4', prefix='logs')
                ),
            )

        **Query using a key pair**

            collect = DynamoDBCollection(...)
            collect.query(
                KeyPair('b3511b5a-cb32-4833-a373-f8223f2088d4', 'emails'),
            )

        Parameters
        ----------
        key: PartitionKey or KeyPair
            Partition key or Key pair used for the query.
        expr_attr_name: dict, optional
            Additional mapping for attribute names.
        expr_attr_values: dict, optional
            Additional mapping for attribute values.
        start_key: str, optional
            Starting key for pagination (a ``last_key`` from a previous call).
        filter_expr: str, optional
            Filter expression for the query.
        index_forward: bool, optional
            Order of the results; True for ascending order.
        limit: int, optional
            Maximum number of items to return.

        Returns
        -------
        PaginatedResult
            Dict containing the queried items and the key for the next batch.

        See Also
        --------
        DynamoDB.Client.query: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
        """
        key_cond_expr = (
            '#pk = :pk AND begins_with(#sk, :sk)'
            if isinstance(key, KeyPair)
            else '#pk = :pk'
        )

        response = self.persistence_layer.query(
            key_cond_expr=key_cond_expr,
            expr_attr_name=key.expr_attr_name() | (expr_attr_name or {}),
            expr_attr_values=key.expr_attr_values() | (expr_attr_values or {}),
            filter_expr=filter_expr,
            index_forward=index_forward,
            limit=limit,
            start_key=_startkey_b64decode(start_key) if start_key else {},
        )

        items = response['items']
        last_key = (
            _startkey_b64encode(response['last_key']) if response['last_key'] else None
        )

        pk, sk = key.get(PK), key.get(SK)

        if isinstance(pk, ComposeKey) and pk.prefix:
            # Remove prefix from Partition Key. A ComposeKey built without a
            # prefix has nothing to strip (and prefix + delimiter would fail).
            items = _removeprefix(items, PK, pk.prefix + pk.delimiter)
        elif isinstance(sk, PrefixKey):
            # Remove prefix from Sort Key
            items = _removeprefix(items, SK, sk)

        return {
            'items': items,
            'last_key': last_key,
        }
|
|
|
|
|
|
def _removeprefix(
|
|
items: list[dict[str, Any]],
|
|
/,
|
|
key: str,
|
|
prefix: str,
|
|
) -> list[dict[str, Any]]:
|
|
"""Remove the given prefix from the value associated with key in each item."""
|
|
return [x | {key: x[key].removeprefix(prefix)} for x in items]
|
|
|
|
|
|
def _startkey_b64encode(obj: dict) -> str:
|
|
s = json.dumps(obj)
|
|
b = urlsafe_b64encode(s.encode('utf-8')).decode('utf-8')
|
|
return quote(b)
|
|
|
|
|
|
def _startkey_b64decode(s: str) -> dict:
|
|
b = unquote(s).encode('utf-8')
|
|
s = urlsafe_b64decode(b).decode('utf-8')
|
|
return json.loads(s)
|