improve get items

This commit is contained in:
2025-05-22 23:18:59 -03:00
parent 585bcfcc2a
commit 812470aae4
8 changed files with 179 additions and 79 deletions

View File

@@ -14,7 +14,7 @@ from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError
from glom import glom
from .dateutils import now, timestamp
from .dateutils import timestamp
from .funcs import omit, pick
TZ = os.getenv('TZ', 'UTC')
@@ -142,14 +142,13 @@ if TYPE_CHECKING:
----------
sk: str
The sort key value.
table_name: str, optional
Optional name of the table associated with the sort key.
path_spec: str, optional
Optional specification for nested data extraction.
remove_prefix: str, optional
Optional prefix to remove from the key when forming the result dict.
"""
sk: str
table_name: str | None = None
path_spec: str | None = None
remove_prefix: str | None = None
else:
@@ -163,17 +162,16 @@ else:
----------
sk: str
The sort key value.
table_name: str, optional
Optional name of the table associated with the sort key.
path_spec: str, optional
Optional specification for nested data extraction.
remove_prefix: str, optional
Optional prefix to remove from the key when forming the result dict.
"""
def __new__(
cls,
sk: str,
*,
table_name: str | None = None,
path_spec: str | None = None,
remove_prefix: str | None = None,
) -> str:
@@ -183,37 +181,16 @@ else:
self,
sk: str,
*,
table_name: str | None = None,
path_spec: str | None = None,
remove_prefix: str | None = None,
) -> None:
# __init__ is used to store the parameters for later reference.
# For immutable types like str, __init__ cannot change the instance's value.
self.sk = sk
self.table_name = table_name
self.path_spec = path_spec
self.remove_prefix = remove_prefix
@dataclass
class TransactKey:
"""
Example
-------
TransactKey('e9bb7dc6-c7b2-4d34-8931-d298353758ec')
+ SortKey('0')
+ SortKey('tenant')
"""
pk: str
sk: tuple[SortKey, ...] = ()
def __add__(self, sk: SortKey) -> 'TransactKey':
if not isinstance(sk, SortKey):
raise TypeError('Can only add a SortKey to a TransactKey')
return TransactKey(pk=self.pk, sk=self.sk + (sk,))
class Key(ABC, dict):
@abstractmethod
def expr_attr_name(self) -> dict: ...
@@ -274,6 +251,41 @@ class KeyPair(Key):
return cls(*pair)
def __add__(self, other: Self) -> 'KeyChain':
return KeyChain(pairs=(self, other))
@dataclass(frozen=True)
class KeyChain:
pairs: tuple[KeyPair, ...] = ()
def __add__(self, other: KeyPair) -> 'KeyChain':
if not isinstance(other, KeyPair):
raise TypeError('Can only add a KeyPair to a KeyChain')
return KeyChain(pairs=self.pairs + (other,))
@dataclass(frozen=True)
class TransactKey:
"""
Example
-------
TransactKey('e9bb7dc6-c7b2-4d34-8931-d298353758ec')
+ SortKey('0')
+ SortKey('tenant')
"""
pk: str
pairs: tuple[KeyPair, ...] = ()
def __add__(self, other: SortKey) -> 'TransactKey':
if not isinstance(other, SortKey):
raise TypeError('Can only add a SortKey to a TransactKey')
pair = KeyPair(self.pk, other)
return TransactKey(pk=self.pk, pairs=self.pairs + (pair,))
class TransactionCanceledReason(TypedDict):
code: str
@@ -481,6 +493,10 @@ class DynamoDBPersistenceLayer:
self.table_name = table_name
self.dynamodb_client = dynamodb_client
@property
def collect(self) -> 'DynamoDBCollection':
return DynamoDBCollection(self)
def query(
self,
*,
@@ -810,13 +826,10 @@ class DynamoDBCollection:
bool
True if the operation is successful, False otherwise.
"""
now_ = now(self.tz)
if isinstance(ttl, int):
kwargs.update(
{
'ttl': ttl,
'ttl_date': datetime.fromtimestamp(ttl, now_.tzinfo),
}
)
@@ -824,7 +837,6 @@ class DynamoDBCollection:
kwargs.update(
{
'ttl': timestamp(ttl),
'ttl_date': ttl,
}
)
@@ -865,7 +877,7 @@ class DynamoDBCollection:
def get_items(
self,
key: TransactKey,
key: TransactKey | KeyChain,
flatten_top: bool = True,
) -> dict[str, Any]:
"""Get multiple items via a transaction based on the provided TransactKey.
@@ -899,19 +911,16 @@ class DynamoDBCollection:
"""
# If no sort key is provided, the query is skipped
if not key.sk:
if not key.pairs:
return {}
table_name = self.persistence_layer.table_name
sortkeys = key.sk[1:] if flatten_top else key.sk
sortkeys = key.pairs[1:] if flatten_top else key.pairs
transact = TransactItems(table_name)
# Add a get operation for each sort key for the transaction
for sk in key.sk:
transact.get(
key=KeyPair(key.pk, sk),
table_name=sk.table_name,
)
# Add a get operation for each key for the transaction
for pair in key.pairs:
transact.get(key=pair)
items = self.persistence_layer.transact_get_items(transact)
@@ -920,20 +929,29 @@ class DynamoDBCollection:
else:
head, tail = {}, items
def _getin(sk: SortKey, v: dict) -> dict:
def _getin(pair: KeyPair, v: dict) -> dict:
v = omit((PK, SK), v)
sk = pair[SK]
path_spec = getattr(sk, 'path_spec', None)
if sk.path_spec:
if path_spec:
from glom import glom
return glom(v, sk.path_spec)
return glom(v, path_spec)
return v
def _removeprefix(sk: SortKey) -> str:
return sk.removeprefix(sk.remove_prefix) if sk.remove_prefix else sk
def _removeprefix(pair: KeyPair) -> str:
sk = pair[SK]
if not isinstance(sk, SortKey):
return pair[PK]
return sk.removeprefix(sk.remove_prefix or '')
return head | {
_removeprefix(k): _getin(k, item) for k, item in zip(sortkeys, tail) if item
_removeprefix(pair): _getin(pair, item)
for pair, item in zip(sortkeys, tail)
if item
}
def query(