diff --git a/layercake/layercake/dynamodb.py b/layercake/layercake/dynamodb.py index 412c18b..a806ef5 100644 --- a/layercake/layercake/dynamodb.py +++ b/layercake/layercake/dynamodb.py @@ -709,7 +709,11 @@ class DynamoDBCollection: expr_attr_values=expr_attr_values, ) - def get_items(self, key: TransactKey) -> dict[str, Any]: + def get_items( + self, + key: TransactKey, + flatten_top: bool = True, + ) -> dict[str, Any]: """Get multiple items via a transaction based on the provided TransactKey. Example @@ -726,6 +730,12 @@ class DynamoDBCollection: A TransactKey instance that contains a partition key and one or more sort keys. If no sort key is provided, the transaction is skipped. + flatten_top : bool, optional + Determines whether the first nested item in the transaction result should be flattened, + i.e., extracted to serve as the primary item at the top level of the returned dict. + If True, the nested item is promoted to the top level. + The default is True. + Returns ------- dict[str, Any] @@ -737,6 +747,7 @@ class DynamoDBCollection: return {} table_name = self.persistence_layer.table_name + sortkeys = key.sk[1:] if flatten_top else key.sk transact = TransactItems(table_name) # Add a get operation for each sort key for the transaction @@ -746,17 +757,15 @@ class DynamoDBCollection: table_name=sk.table_name, ) - data, *rest = self.persistence_layer.transact_get_items(transact) - return data | { - k: omit( - ( - PK, - SK, - ), - item, - ) - for k, item in zip(key.sk[1:], rest) - if item + items = self.persistence_layer.transact_get_items(transact) + + if flatten_top: + head, *tail = items + else: + head, tail = {}, items + + return head | { + k: omit((PK, SK), item) for k, item in zip(sortkeys, tail) if item } def query( diff --git a/layercake/pyproject.toml b/layercake/pyproject.toml index 9c52ef3..862392b 100644 --- a/layercake/pyproject.toml +++ b/layercake/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "layercake" -version = "0.2.0" +version = "0.2.1" description = "Packages shared dependencies to optimize deployment and ensure consistency across functions." readme = "README.md" authors = [ diff --git a/layercake/tests/test_dynamodb.py b/layercake/tests/test_dynamodb.py index 679a162..f986c8e 100644 --- a/layercake/tests/test_dynamodb.py +++ b/layercake/tests/test_dynamodb.py @@ -246,12 +246,11 @@ def test_collection_get_items( dynamodb_persistence_layer: DynamoDBPersistenceLayer, ): collect = DynamoDBCollection(dynamodb_persistence_layer) - doc = collect.get_items( TransactKey('cJtK9SsnJhKPyxESe7g3DG') + SortKey('0') + SortKey('billing_policy') - + SortKey('payment_policy') + + SortKey('payment_policy'), ) assert doc == { @@ -263,3 +262,22 @@ def test_collection_get_items( 'billing_policy': {'billing_day': Decimal('1'), 'payment_method': 'PIX'}, 'payment_policy': {'due_days': Decimal('90')}, } + + +def test_collection_get_items_unflatten( + dynamodb_seeds, + dynamodb_persistence_layer: DynamoDBPersistenceLayer, +): + collect = DynamoDBCollection(dynamodb_persistence_layer) + + doc = collect.get_items( + TransactKey('cJtK9SsnJhKPyxESe7g3DG') + + SortKey('billing_policy') + + SortKey('payment_policy'), + flatten_top=False, + ) + + assert doc == { + 'billing_policy': {'billing_day': Decimal('1'), 'payment_method': 'PIX'}, + 'payment_policy': {'due_days': Decimal('90')}, + }