update layercake version
This commit is contained in:
@@ -339,10 +339,10 @@ class TransactionCanceledException(Exception):
|
||||
class TransactOperation:
|
||||
def __init__(
|
||||
self,
|
||||
op: dict,
|
||||
operation: dict,
|
||||
exc_cls: type[Exception] | None = None,
|
||||
) -> None:
|
||||
self.op = op
|
||||
self.operation = operation
|
||||
self.exc_cls = exc_cls
|
||||
|
||||
|
||||
@@ -352,23 +352,27 @@ else:
|
||||
DynamoDBClient = object
|
||||
|
||||
|
||||
class TransactItems:
|
||||
class TransactWriter:
|
||||
def __init__(
|
||||
self,
|
||||
table_name: str,
|
||||
*,
|
||||
flush_amount: int,
|
||||
client: DynamoDBClient,
|
||||
) -> None:
|
||||
self._table_name = table_name
|
||||
self._operations: list[TransactOperation] = []
|
||||
self._items_buffer: list[TransactOperation] = []
|
||||
self._flush_amount = flush_amount
|
||||
self._client = client
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Remove operations from previous execution."""
|
||||
self._operations.clear()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
|
||||
return False
|
||||
def __exit__(self, *exc_details) -> None:
|
||||
# When we exit, we need to keep flushing whatever's left
|
||||
# until there's nothing left in our items buffer.
|
||||
while self._items_buffer:
|
||||
self._flush()
|
||||
|
||||
def put(
|
||||
self,
|
||||
@@ -386,7 +390,7 @@ class TransactItems:
|
||||
if not table_name:
|
||||
table_name = self._table_name
|
||||
|
||||
self._operations.append(
|
||||
self._add_op_and_process(
|
||||
TransactOperation(
|
||||
{
|
||||
'Put': dict(
|
||||
@@ -406,8 +410,8 @@ class TransactItems:
|
||||
update_expr: str,
|
||||
cond_expr: str | None = None,
|
||||
table_name: str | None = None,
|
||||
expr_attr_names: dict = {},
|
||||
expr_attr_values: dict = {},
|
||||
expr_attr_names: dict | None = None,
|
||||
expr_attr_values: dict | None = None,
|
||||
exc_cls: Type[Exception] | None = None,
|
||||
) -> None:
|
||||
attrs: dict = {}
|
||||
@@ -424,7 +428,7 @@ class TransactItems:
|
||||
if not table_name:
|
||||
table_name = self._table_name
|
||||
|
||||
self._operations.append(
|
||||
self._add_op_and_process(
|
||||
TransactOperation(
|
||||
{
|
||||
'Update': dict(
|
||||
@@ -438,41 +442,14 @@ class TransactItems:
|
||||
)
|
||||
)
|
||||
|
||||
def get(
|
||||
self,
|
||||
*,
|
||||
table_name: str | None = None,
|
||||
key: dict,
|
||||
expr_attr_names: str | None = None,
|
||||
) -> None:
|
||||
attrs: dict = {}
|
||||
|
||||
if expr_attr_names:
|
||||
attrs['ExpressionAttributeNames'] = expr_attr_names
|
||||
|
||||
if not table_name:
|
||||
table_name = self._table_name
|
||||
|
||||
self._operations.append(
|
||||
TransactOperation(
|
||||
{
|
||||
'Get': dict(
|
||||
TableName=table_name,
|
||||
Key=serialize(key),
|
||||
**attrs,
|
||||
)
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
*,
|
||||
key: dict,
|
||||
table_name: str | None = None,
|
||||
cond_expr: str | None = None,
|
||||
expr_attr_names: dict = {},
|
||||
expr_attr_values: dict = {},
|
||||
expr_attr_names: dict | None = None,
|
||||
expr_attr_values: dict | None = None,
|
||||
exc_cls: Type[Exception] | None = None,
|
||||
) -> None:
|
||||
attrs: dict = {}
|
||||
@@ -489,7 +466,7 @@ class TransactItems:
|
||||
if not table_name:
|
||||
table_name = self._table_name
|
||||
|
||||
self._operations.append(
|
||||
self._add_op_and_process(
|
||||
TransactOperation(
|
||||
{
|
||||
'Delete': dict(
|
||||
@@ -508,8 +485,8 @@ class TransactItems:
|
||||
key: dict,
|
||||
cond_expr: str,
|
||||
table_name: str | None = None,
|
||||
expr_attr_names: dict = {},
|
||||
expr_attr_values: dict = {},
|
||||
expr_attr_names: dict | None = None,
|
||||
expr_attr_values: dict | None = None,
|
||||
exc_cls: Type[Exception] | None = None,
|
||||
) -> None:
|
||||
attrs: dict = {'ConditionExpression': cond_expr}
|
||||
@@ -523,7 +500,7 @@ class TransactItems:
|
||||
if not table_name:
|
||||
table_name = self._table_name
|
||||
|
||||
self._operations.append(
|
||||
self._add_op_and_process(
|
||||
TransactOperation(
|
||||
{
|
||||
'ConditionCheck': dict(
|
||||
@@ -536,13 +513,21 @@ class TransactItems:
|
||||
)
|
||||
)
|
||||
|
||||
def write_items(self) -> bool:
|
||||
operations = self._operations.copy()
|
||||
self._operations.clear()
|
||||
def _add_op_and_process(self, op: TransactOperation) -> None:
|
||||
self._items_buffer.append(op)
|
||||
self._flush_if_needed()
|
||||
|
||||
def _flush_if_needed(self) -> None:
|
||||
if len(self._items_buffer) >= self._flush_amount:
|
||||
self._flush()
|
||||
|
||||
def _flush(self) -> bool:
|
||||
items_to_send = self._items_buffer[: self._flush_amount]
|
||||
self._items_buffer = self._items_buffer[self._flush_amount :]
|
||||
|
||||
try:
|
||||
self._client.transact_write_items(
|
||||
TransactItems=[item.op for item in operations] # type: ignore
|
||||
TransactItems=[item.operation for item in items_to_send] # type: ignore
|
||||
)
|
||||
except ClientError as err:
|
||||
error_msg = glom(err, 'response.Error.Message', default='')
|
||||
@@ -553,7 +538,7 @@ class TransactItems:
|
||||
if 'Message' not in reason:
|
||||
continue
|
||||
|
||||
item = operations[idx]
|
||||
item = items_to_send[idx]
|
||||
|
||||
if item.exc_cls:
|
||||
raise item.exc_cls(error_msg)
|
||||
@@ -562,7 +547,7 @@ class TransactItems:
|
||||
{
|
||||
'code': reason.get('Code'),
|
||||
'message': reason.get('Message'),
|
||||
'operation': item.op,
|
||||
'operation': item.operation,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -570,32 +555,11 @@ class TransactItems:
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_items(self) -> list[dict[str, Any]]:
|
||||
operations = self._operations.copy()
|
||||
self._operations.clear()
|
||||
|
||||
try:
|
||||
response = self._client.transact_get_items(
|
||||
TransactItems=[item.op for item in operations] # type: ignore
|
||||
)
|
||||
except ClientError as err:
|
||||
logger.exception(err)
|
||||
raise
|
||||
else:
|
||||
return [
|
||||
deserialize(response.get('Item', {}))
|
||||
for response in response.get('Responses', [])
|
||||
]
|
||||
|
||||
|
||||
class DynamoDBPersistenceLayer:
|
||||
def __init__(self, table_name: str, client: DynamoDBClient) -> None:
|
||||
self._table_name = table_name
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def collect(self) -> 'DynamoDBCollection':
|
||||
return DynamoDBCollection(self)
|
||||
self.table_name = table_name
|
||||
self.client = client
|
||||
|
||||
def query(
|
||||
self,
|
||||
@@ -625,7 +589,7 @@ class DynamoDBPersistenceLayer:
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
|
||||
"""
|
||||
attrs: dict = {
|
||||
'TableName': self._table_name,
|
||||
'TableName': self.table_name,
|
||||
'KeyConditionExpression': key_cond_expr,
|
||||
'ScanIndexForward': index_forward,
|
||||
}
|
||||
@@ -646,7 +610,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['Limit'] = limit
|
||||
|
||||
try:
|
||||
response = self._client.query(**attrs)
|
||||
response = self.client.query(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -665,12 +629,12 @@ class DynamoDBPersistenceLayer:
|
||||
there will be no Item element in the response.
|
||||
"""
|
||||
attrs = {
|
||||
'TableName': self._table_name,
|
||||
'TableName': self.table_name,
|
||||
'Key': serialize(key),
|
||||
}
|
||||
|
||||
try:
|
||||
response = self._client.get_item(**attrs)
|
||||
response = self.client.get_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -680,7 +644,7 @@ class DynamoDBPersistenceLayer:
|
||||
|
||||
def put_item(self, item: dict, *, cond_expr: str | None = None) -> bool:
|
||||
attrs = {
|
||||
'TableName': self._table_name,
|
||||
'TableName': self.table_name,
|
||||
'Item': serialize(item),
|
||||
}
|
||||
|
||||
@@ -688,7 +652,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['ConditionExpression'] = cond_expr
|
||||
|
||||
try:
|
||||
self._client.put_item(**attrs)
|
||||
self.client.put_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -706,7 +670,7 @@ class DynamoDBPersistenceLayer:
|
||||
expr_attr_values: dict | None = None,
|
||||
) -> bool:
|
||||
attrs: dict = {
|
||||
'TableName': self._table_name,
|
||||
'TableName': self.table_name,
|
||||
'Key': serialize(key),
|
||||
'UpdateExpression': update_expr,
|
||||
}
|
||||
@@ -721,7 +685,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
try:
|
||||
self._client.update_item(**attrs)
|
||||
self.client.update_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -742,7 +706,7 @@ class DynamoDBPersistenceLayer:
|
||||
or if it has an expected attribute value.
|
||||
"""
|
||||
attrs: dict = {
|
||||
'TableName': self._table_name,
|
||||
'TableName': self.table_name,
|
||||
'Key': serialize(key),
|
||||
}
|
||||
|
||||
@@ -756,7 +720,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
try:
|
||||
self._client.delete_item(**attrs)
|
||||
self.client.delete_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -764,8 +728,16 @@ class DynamoDBPersistenceLayer:
|
||||
else:
|
||||
return True
|
||||
|
||||
def transact_items(self) -> TransactItems:
|
||||
return TransactItems(table_name=self._table_name, client=self._client)
|
||||
@property
|
||||
def collection(self) -> 'DynamoDBCollection':
|
||||
return DynamoDBCollection(self)
|
||||
|
||||
def transact_writer(self, flush_amount: int = 50) -> TransactWriter:
|
||||
return TransactWriter(
|
||||
table_name=self.table_name,
|
||||
client=self.client,
|
||||
flush_amount=flush_amount,
|
||||
)
|
||||
|
||||
def batch_writer(
|
||||
self,
|
||||
@@ -797,8 +769,8 @@ class DynamoDBPersistenceLayer:
|
||||
DynamoDB.Table.batch_writer https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/table/batch_writer.html#DynamoDB.Table.batch_writer
|
||||
"""
|
||||
return BatchWriter(
|
||||
table_name=table_name or self._table_name,
|
||||
client=self._client,
|
||||
table_name=table_name or self.table_name,
|
||||
client=self.client,
|
||||
overwrite_by_pkeys=overwrite_by_pkeys,
|
||||
)
|
||||
|
||||
@@ -1033,15 +1005,23 @@ class DynamoDBCollection:
|
||||
if not key.pairs:
|
||||
return {}
|
||||
|
||||
items = []
|
||||
sortkeys = key.pairs[1:] if flatten_top else key.pairs
|
||||
client = self.persistence_layer.client
|
||||
table_name = self.persistence_layer.table_name
|
||||
|
||||
with self.persistence_layer.transact_items() as transact:
|
||||
# Add a get operation for each key for the transaction
|
||||
for pair in key.pairs:
|
||||
transact.get(key=pair)
|
||||
# Add a get operation for each key for the transaction
|
||||
transact_items = [
|
||||
{
|
||||
'Get': {
|
||||
'TableName': getattr(pair, 'table_name', table_name),
|
||||
'Key': serialize(pair),
|
||||
}
|
||||
}
|
||||
for pair in key.pairs
|
||||
]
|
||||
|
||||
items = transact.get_items()
|
||||
response = client.transact_get_items(TransactItems=transact_items) # type: ignore
|
||||
items = [deserialize(r.get('Item', {})) for r in response.get('Responses', [])]
|
||||
|
||||
if flatten_top:
|
||||
head, *tail = items
|
||||
|
||||
Reference in New Issue
Block a user