add flatten top

This commit is contained in:
2025-04-08 17:33:50 -03:00
parent 6450e5fa7c
commit 5b5e381191
3 changed files with 42 additions and 15 deletions

View File

@@ -709,7 +709,11 @@ class DynamoDBCollection:
expr_attr_values=expr_attr_values, expr_attr_values=expr_attr_values,
) )
def get_items(self, key: TransactKey) -> dict[str, Any]: def get_items(
self,
key: TransactKey,
flatten_top: bool = True,
) -> dict[str, Any]:
"""Get multiple items via a transaction based on the provided TransactKey. """Get multiple items via a transaction based on the provided TransactKey.
Example Example
@@ -726,6 +730,12 @@ class DynamoDBCollection:
A TransactKey instance that contains a partition key and one or more sort keys. A TransactKey instance that contains a partition key and one or more sort keys.
If no sort key is provided, the transaction is skipped. If no sort key is provided, the transaction is skipped.
flatten_top : bool, optional
Determines whether the first nested item in the transaction result should be flattened,
i.e., extracted to serve as the primary item at the top level of the returned dict.
If True, the nested item is promoted to the top level.
The default is True.
Returns Returns
------- -------
dict[str, Any] dict[str, Any]
@@ -737,6 +747,7 @@ class DynamoDBCollection:
return {} return {}
table_name = self.persistence_layer.table_name table_name = self.persistence_layer.table_name
sortkeys = key.sk[1:] if flatten_top else key.sk
transact = TransactItems(table_name) transact = TransactItems(table_name)
# Add a get operation for each sort key for the transaction # Add a get operation for each sort key for the transaction
@@ -746,17 +757,15 @@ class DynamoDBCollection:
table_name=sk.table_name, table_name=sk.table_name,
) )
data, *rest = self.persistence_layer.transact_get_items(transact) items = self.persistence_layer.transact_get_items(transact)
return data | {
k: omit( if flatten_top:
( head, *tail = items
PK, else:
SK, head, tail = {}, items
),
item, return head | {
) k: omit((PK, SK), item) for k, item in zip(sortkeys, tail) if item
for k, item in zip(key.sk[1:], rest)
if item
} }
def query( def query(

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "layercake" name = "layercake"
version = "0.2.0" version = "0.2.1"
description = "Packages shared dependencies to optimize deployment and ensure consistency across functions." description = "Packages shared dependencies to optimize deployment and ensure consistency across functions."
readme = "README.md" readme = "README.md"
authors = [ authors = [

View File

@@ -246,12 +246,11 @@ def test_collection_get_items(
dynamodb_persistence_layer: DynamoDBPersistenceLayer, dynamodb_persistence_layer: DynamoDBPersistenceLayer,
): ):
collect = DynamoDBCollection(dynamodb_persistence_layer) collect = DynamoDBCollection(dynamodb_persistence_layer)
doc = collect.get_items( doc = collect.get_items(
TransactKey('cJtK9SsnJhKPyxESe7g3DG') TransactKey('cJtK9SsnJhKPyxESe7g3DG')
+ SortKey('0') + SortKey('0')
+ SortKey('billing_policy') + SortKey('billing_policy')
+ SortKey('payment_policy') + SortKey('payment_policy'),
) )
assert doc == { assert doc == {
@@ -263,3 +262,22 @@ def test_collection_get_items(
'billing_policy': {'billing_day': Decimal('1'), 'payment_method': 'PIX'}, 'billing_policy': {'billing_day': Decimal('1'), 'payment_method': 'PIX'},
'payment_policy': {'due_days': Decimal('90')}, 'payment_policy': {'due_days': Decimal('90')},
} }
def test_collection_get_items_unflatten(
dynamodb_seeds,
dynamodb_persistence_layer: DynamoDBPersistenceLayer,
):
collect = DynamoDBCollection(dynamodb_persistence_layer)
doc = collect.get_items(
TransactKey('cJtK9SsnJhKPyxESe7g3DG')
+ SortKey('billing_policy')
+ SortKey('payment_policy'),
flatten_top=False,
)
assert doc == {
'billing_policy': {'billing_day': Decimal('1'), 'payment_method': 'PIX'},
'payment_policy': {'due_days': Decimal('90')},
}