update layercake version

This commit is contained in:
2025-05-28 17:52:15 -03:00
parent 42e62ec183
commit 797a325cb0
28 changed files with 692 additions and 566 deletions

View File

@@ -339,10 +339,10 @@ class TransactionCanceledException(Exception):
class TransactOperation:
def __init__(
self,
op: dict,
operation: dict,
exc_cls: type[Exception] | None = None,
) -> None:
self.op = op
self.operation = operation
self.exc_cls = exc_cls
@@ -352,23 +352,27 @@ else:
DynamoDBClient = object
class TransactItems:
class TransactWriter:
def __init__(
self,
table_name: str,
*,
flush_amount: int,
client: DynamoDBClient,
) -> None:
self._table_name = table_name
self._operations: list[TransactOperation] = []
self._items_buffer: list[TransactOperation] = []
self._flush_amount = flush_amount
self._client = client
def __enter__(self) -> Self:
"""Remove operations from previous execution."""
self._operations.clear()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
return False
def __exit__(self, *exc_details) -> None:
# When we exit, we need to keep flushing whatever's left
# until there's nothing left in our items buffer.
while self._items_buffer:
self._flush()
def put(
self,
@@ -386,7 +390,7 @@ class TransactItems:
if not table_name:
table_name = self._table_name
self._operations.append(
self._add_op_and_process(
TransactOperation(
{
'Put': dict(
@@ -406,8 +410,8 @@ class TransactItems:
update_expr: str,
cond_expr: str | None = None,
table_name: str | None = None,
expr_attr_names: dict = {},
expr_attr_values: dict = {},
expr_attr_names: dict | None = None,
expr_attr_values: dict | None = None,
exc_cls: Type[Exception] | None = None,
) -> None:
attrs: dict = {}
@@ -424,7 +428,7 @@ class TransactItems:
if not table_name:
table_name = self._table_name
self._operations.append(
self._add_op_and_process(
TransactOperation(
{
'Update': dict(
@@ -438,41 +442,14 @@ class TransactItems:
)
)
def get(
self,
*,
table_name: str | None = None,
key: dict,
expr_attr_names: str | None = None,
) -> None:
attrs: dict = {}
if expr_attr_names:
attrs['ExpressionAttributeNames'] = expr_attr_names
if not table_name:
table_name = self._table_name
self._operations.append(
TransactOperation(
{
'Get': dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
}
),
)
def delete(
self,
*,
key: dict,
table_name: str | None = None,
cond_expr: str | None = None,
expr_attr_names: dict = {},
expr_attr_values: dict = {},
expr_attr_names: dict | None = None,
expr_attr_values: dict | None = None,
exc_cls: Type[Exception] | None = None,
) -> None:
attrs: dict = {}
@@ -489,7 +466,7 @@ class TransactItems:
if not table_name:
table_name = self._table_name
self._operations.append(
self._add_op_and_process(
TransactOperation(
{
'Delete': dict(
@@ -508,8 +485,8 @@ class TransactItems:
key: dict,
cond_expr: str,
table_name: str | None = None,
expr_attr_names: dict = {},
expr_attr_values: dict = {},
expr_attr_names: dict | None = None,
expr_attr_values: dict | None = None,
exc_cls: Type[Exception] | None = None,
) -> None:
attrs: dict = {'ConditionExpression': cond_expr}
@@ -523,7 +500,7 @@ class TransactItems:
if not table_name:
table_name = self._table_name
self._operations.append(
self._add_op_and_process(
TransactOperation(
{
'ConditionCheck': dict(
@@ -536,13 +513,21 @@ class TransactItems:
)
)
def write_items(self) -> bool:
operations = self._operations.copy()
self._operations.clear()
def _add_op_and_process(self, op: TransactOperation) -> None:
self._items_buffer.append(op)
self._flush_if_needed()
def _flush_if_needed(self) -> None:
if len(self._items_buffer) >= self._flush_amount:
self._flush()
def _flush(self) -> bool:
items_to_send = self._items_buffer[: self._flush_amount]
self._items_buffer = self._items_buffer[self._flush_amount :]
try:
self._client.transact_write_items(
TransactItems=[item.op for item in operations] # type: ignore
TransactItems=[item.operation for item in items_to_send] # type: ignore
)
except ClientError as err:
error_msg = glom(err, 'response.Error.Message', default='')
@@ -553,7 +538,7 @@ class TransactItems:
if 'Message' not in reason:
continue
item = operations[idx]
item = items_to_send[idx]
if item.exc_cls:
raise item.exc_cls(error_msg)
@@ -562,7 +547,7 @@ class TransactItems:
{
'code': reason.get('Code'),
'message': reason.get('Message'),
'operation': item.op,
'operation': item.operation,
}
)
@@ -570,32 +555,11 @@ class TransactItems:
else:
return True
def get_items(self) -> list[dict[str, Any]]:
operations = self._operations.copy()
self._operations.clear()
try:
response = self._client.transact_get_items(
TransactItems=[item.op for item in operations] # type: ignore
)
except ClientError as err:
logger.exception(err)
raise
else:
return [
deserialize(response.get('Item', {}))
for response in response.get('Responses', [])
]
class DynamoDBPersistenceLayer:
def __init__(self, table_name: str, client: DynamoDBClient) -> None:
self._table_name = table_name
self._client = client
@property
def collect(self) -> 'DynamoDBCollection':
return DynamoDBCollection(self)
self.table_name = table_name
self.client = client
def query(
self,
@@ -625,7 +589,7 @@ class DynamoDBPersistenceLayer:
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
"""
attrs: dict = {
'TableName': self._table_name,
'TableName': self.table_name,
'KeyConditionExpression': key_cond_expr,
'ScanIndexForward': index_forward,
}
@@ -646,7 +610,7 @@ class DynamoDBPersistenceLayer:
attrs['Limit'] = limit
try:
response = self._client.query(**attrs)
response = self.client.query(**attrs)
except ClientError as err:
logger.info(attrs)
logger.exception(err)
@@ -665,12 +629,12 @@ class DynamoDBPersistenceLayer:
there will be no Item element in the response.
"""
attrs = {
'TableName': self._table_name,
'TableName': self.table_name,
'Key': serialize(key),
}
try:
response = self._client.get_item(**attrs)
response = self.client.get_item(**attrs)
except ClientError as err:
logger.info(attrs)
logger.exception(err)
@@ -680,7 +644,7 @@ class DynamoDBPersistenceLayer:
def put_item(self, item: dict, *, cond_expr: str | None = None) -> bool:
attrs = {
'TableName': self._table_name,
'TableName': self.table_name,
'Item': serialize(item),
}
@@ -688,7 +652,7 @@ class DynamoDBPersistenceLayer:
attrs['ConditionExpression'] = cond_expr
try:
self._client.put_item(**attrs)
self.client.put_item(**attrs)
except ClientError as err:
logger.info(attrs)
logger.exception(err)
@@ -706,7 +670,7 @@ class DynamoDBPersistenceLayer:
expr_attr_values: dict | None = None,
) -> bool:
attrs: dict = {
'TableName': self._table_name,
'TableName': self.table_name,
'Key': serialize(key),
'UpdateExpression': update_expr,
}
@@ -721,7 +685,7 @@ class DynamoDBPersistenceLayer:
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
try:
self._client.update_item(**attrs)
self.client.update_item(**attrs)
except ClientError as err:
logger.info(attrs)
logger.exception(err)
@@ -742,7 +706,7 @@ class DynamoDBPersistenceLayer:
or if it has an expected attribute value.
"""
attrs: dict = {
'TableName': self._table_name,
'TableName': self.table_name,
'Key': serialize(key),
}
@@ -756,7 +720,7 @@ class DynamoDBPersistenceLayer:
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
try:
self._client.delete_item(**attrs)
self.client.delete_item(**attrs)
except ClientError as err:
logger.info(attrs)
logger.exception(err)
@@ -764,8 +728,16 @@ class DynamoDBPersistenceLayer:
else:
return True
def transact_items(self) -> TransactItems:
return TransactItems(table_name=self._table_name, client=self._client)
@property
def collection(self) -> 'DynamoDBCollection':
return DynamoDBCollection(self)
def transact_writer(self, flush_amount: int = 50) -> TransactWriter:
return TransactWriter(
table_name=self.table_name,
client=self.client,
flush_amount=flush_amount,
)
def batch_writer(
self,
@@ -797,8 +769,8 @@ class DynamoDBPersistenceLayer:
DynamoDB.Table.batch_writer https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/table/batch_writer.html#DynamoDB.Table.batch_writer
"""
return BatchWriter(
table_name=table_name or self._table_name,
client=self._client,
table_name=table_name or self.table_name,
client=self.client,
overwrite_by_pkeys=overwrite_by_pkeys,
)
@@ -1033,15 +1005,23 @@ class DynamoDBCollection:
if not key.pairs:
return {}
items = []
sortkeys = key.pairs[1:] if flatten_top else key.pairs
client = self.persistence_layer.client
table_name = self.persistence_layer.table_name
with self.persistence_layer.transact_items() as transact:
# Add a get operation for each key for the transaction
for pair in key.pairs:
transact.get(key=pair)
# Add a get operation for each key for the transaction
transact_items = [
{
'Get': {
'TableName': getattr(pair, 'table_name', table_name),
'Key': serialize(pair),
}
}
for pair in key.pairs
]
items = transact.get_items()
response = client.transact_get_items(TransactItems=transact_items) # type: ignore
items = [deserialize(r.get('Item', {})) for r in response.get('Responses', [])]
if flatten_top:
head, *tail = items