update
This commit is contained in:
@@ -346,10 +346,29 @@ class TransactOperation:
|
||||
self.exc_cls = exc_cls
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_dynamodb.client import DynamoDBClient
|
||||
else:
|
||||
DynamoDBClient = object
|
||||
|
||||
|
||||
class TransactItems:
|
||||
def __init__(self, table_name: str) -> None:
|
||||
self.table_name = table_name
|
||||
self.items: list[TransactOperation] = []
|
||||
def __init__(
|
||||
self,
|
||||
table_name: str,
|
||||
client: DynamoDBClient,
|
||||
) -> None:
|
||||
self._table_name = table_name
|
||||
self._operations: list[TransactOperation] = []
|
||||
self._client = client
|
||||
|
||||
def __enter__(self) -> Self:
|
||||
"""Remove operations from previous execution."""
|
||||
self._operations.clear()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
|
||||
return False
|
||||
|
||||
def put(
|
||||
self,
|
||||
@@ -365,9 +384,9 @@ class TransactItems:
|
||||
attrs['ConditionExpression'] = cond_expr
|
||||
|
||||
if not table_name:
|
||||
table_name = self.table_name
|
||||
table_name = self._table_name
|
||||
|
||||
self.items.append(
|
||||
self._operations.append(
|
||||
TransactOperation(
|
||||
{
|
||||
'Put': dict(
|
||||
@@ -403,9 +422,9 @@ class TransactItems:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
if not table_name:
|
||||
table_name = self.table_name
|
||||
table_name = self._table_name
|
||||
|
||||
self.items.append(
|
||||
self._operations.append(
|
||||
TransactOperation(
|
||||
{
|
||||
'Update': dict(
|
||||
@@ -432,9 +451,9 @@ class TransactItems:
|
||||
attrs['ExpressionAttributeNames'] = expr_attr_names
|
||||
|
||||
if not table_name:
|
||||
table_name = self.table_name
|
||||
table_name = self._table_name
|
||||
|
||||
self.items.append(
|
||||
self._operations.append(
|
||||
TransactOperation(
|
||||
{
|
||||
'Get': dict(
|
||||
@@ -468,9 +487,9 @@ class TransactItems:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
if not table_name:
|
||||
table_name = self.table_name
|
||||
table_name = self._table_name
|
||||
|
||||
self.items.append(
|
||||
self._operations.append(
|
||||
TransactOperation(
|
||||
{
|
||||
'Delete': dict(
|
||||
@@ -502,9 +521,9 @@ class TransactItems:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
if not table_name:
|
||||
table_name = self.table_name
|
||||
table_name = self._table_name
|
||||
|
||||
self.items.append(
|
||||
self._operations.append(
|
||||
TransactOperation(
|
||||
{
|
||||
'ConditionCheck': dict(
|
||||
@@ -517,17 +536,62 @@ class TransactItems:
|
||||
)
|
||||
)
|
||||
|
||||
def write_items(self) -> bool:
|
||||
operations = self._operations.copy()
|
||||
self._operations.clear()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mypy_boto3_dynamodb.client import DynamoDBClient
|
||||
else:
|
||||
DynamoDBClient = object
|
||||
try:
|
||||
self._client.transact_write_items(
|
||||
TransactItems=[item.op for item in operations] # type: ignore
|
||||
)
|
||||
except ClientError as err:
|
||||
error_msg = glom(err, 'response.Error.Message', default='')
|
||||
cancellations = err.response.get('CancellationReasons', [])
|
||||
reasons = []
|
||||
|
||||
for idx, reason in enumerate(cancellations):
|
||||
if 'Message' not in reason:
|
||||
continue
|
||||
|
||||
item = operations[idx]
|
||||
|
||||
if item.exc_cls:
|
||||
raise item.exc_cls(error_msg)
|
||||
|
||||
reasons.append(
|
||||
{
|
||||
'code': reason.get('Code'),
|
||||
'message': reason.get('Message'),
|
||||
'operation': item.op,
|
||||
}
|
||||
)
|
||||
|
||||
raise TransactionCanceledException(error_msg, reasons)
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_items(self) -> list[dict[str, Any]]:
|
||||
operations = self._operations.copy()
|
||||
self._operations.clear()
|
||||
|
||||
try:
|
||||
response = self._client.transact_get_items(
|
||||
TransactItems=[item.op for item in operations] # type: ignore
|
||||
)
|
||||
except ClientError as err:
|
||||
logger.exception(err)
|
||||
raise
|
||||
else:
|
||||
return [
|
||||
deserialize(response.get('Item', {}))
|
||||
for response in response.get('Responses', [])
|
||||
]
|
||||
|
||||
|
||||
class DynamoDBPersistenceLayer:
|
||||
def __init__(self, table_name: str, dynamodb_client: DynamoDBClient) -> None:
|
||||
self.table_name = table_name
|
||||
self.dynamodb_client = dynamodb_client
|
||||
def __init__(self, table_name: str, client: DynamoDBClient) -> None:
|
||||
self._table_name = table_name
|
||||
self._client = client
|
||||
|
||||
@property
|
||||
def collect(self) -> 'DynamoDBCollection':
|
||||
@@ -561,7 +625,7 @@ class DynamoDBPersistenceLayer:
|
||||
- https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/client/query.html
|
||||
"""
|
||||
attrs: dict = {
|
||||
'TableName': self.table_name,
|
||||
'TableName': self._table_name,
|
||||
'KeyConditionExpression': key_cond_expr,
|
||||
'ScanIndexForward': index_forward,
|
||||
}
|
||||
@@ -582,7 +646,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['Limit'] = limit
|
||||
|
||||
try:
|
||||
response = self.dynamodb_client.query(**attrs)
|
||||
response = self._client.query(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -601,12 +665,12 @@ class DynamoDBPersistenceLayer:
|
||||
there will be no Item element in the response.
|
||||
"""
|
||||
attrs = {
|
||||
'TableName': self.table_name,
|
||||
'TableName': self._table_name,
|
||||
'Key': serialize(key),
|
||||
}
|
||||
|
||||
try:
|
||||
response = self.dynamodb_client.get_item(**attrs)
|
||||
response = self._client.get_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -616,7 +680,7 @@ class DynamoDBPersistenceLayer:
|
||||
|
||||
def put_item(self, item: dict, *, cond_expr: str | None = None) -> bool:
|
||||
attrs = {
|
||||
'TableName': self.table_name,
|
||||
'TableName': self._table_name,
|
||||
'Item': serialize(item),
|
||||
}
|
||||
|
||||
@@ -624,7 +688,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['ConditionExpression'] = cond_expr
|
||||
|
||||
try:
|
||||
self.dynamodb_client.put_item(**attrs)
|
||||
self._client.put_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -642,7 +706,7 @@ class DynamoDBPersistenceLayer:
|
||||
expr_attr_values: dict | None = None,
|
||||
) -> bool:
|
||||
attrs: dict = {
|
||||
'TableName': self.table_name,
|
||||
'TableName': self._table_name,
|
||||
'Key': serialize(key),
|
||||
'UpdateExpression': update_expr,
|
||||
}
|
||||
@@ -657,7 +721,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
try:
|
||||
self.dynamodb_client.update_item(**attrs)
|
||||
self._client.update_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -678,7 +742,7 @@ class DynamoDBPersistenceLayer:
|
||||
or if it has an expected attribute value.
|
||||
"""
|
||||
attrs: dict = {
|
||||
'TableName': self.table_name,
|
||||
'TableName': self._table_name,
|
||||
'Key': serialize(key),
|
||||
}
|
||||
|
||||
@@ -692,7 +756,7 @@ class DynamoDBPersistenceLayer:
|
||||
attrs['ExpressionAttributeValues'] = serialize(expr_attr_values)
|
||||
|
||||
try:
|
||||
self.dynamodb_client.delete_item(**attrs)
|
||||
self._client.delete_item(**attrs)
|
||||
except ClientError as err:
|
||||
logger.info(attrs)
|
||||
logger.exception(err)
|
||||
@@ -700,50 +764,8 @@ class DynamoDBPersistenceLayer:
|
||||
else:
|
||||
return True
|
||||
|
||||
def transact_get_items(self, transact_items: TransactItems) -> list[dict[str, Any]]:
|
||||
try:
|
||||
response = self.dynamodb_client.transact_get_items(
|
||||
TransactItems=[item.op for item in transact_items.items] # type: ignore
|
||||
)
|
||||
except ClientError as err:
|
||||
logger.exception(err)
|
||||
raise
|
||||
else:
|
||||
return [
|
||||
deserialize(response.get('Item', {}))
|
||||
for response in response.get('Responses', [])
|
||||
]
|
||||
|
||||
def transact_write_items(self, transact_items: TransactItems) -> bool:
|
||||
try:
|
||||
self.dynamodb_client.transact_write_items(
|
||||
TransactItems=[item.op for item in transact_items.items] # type: ignore
|
||||
)
|
||||
except ClientError as err:
|
||||
error_msg = glom(err, 'response.Error.Message', default='')
|
||||
cancellations = err.response.get('CancellationReasons', [])
|
||||
reasons = []
|
||||
|
||||
for idx, reason in enumerate(cancellations):
|
||||
if 'Message' not in reason:
|
||||
continue
|
||||
|
||||
item = transact_items.items[idx]
|
||||
|
||||
if item.exc_cls:
|
||||
raise item.exc_cls(error_msg)
|
||||
|
||||
reasons.append(
|
||||
{
|
||||
'code': reason.get('Code'),
|
||||
'message': reason.get('Message'),
|
||||
'operation': item.op,
|
||||
}
|
||||
)
|
||||
|
||||
raise TransactionCanceledException(error_msg, reasons)
|
||||
else:
|
||||
return True
|
||||
def transact_items(self) -> TransactItems:
|
||||
return TransactItems(table_name=self._table_name, client=self._client)
|
||||
|
||||
def batch_writer(
|
||||
self,
|
||||
@@ -775,8 +797,8 @@ class DynamoDBPersistenceLayer:
|
||||
DynamoDB.Table.batch_writer https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb/table/batch_writer.html#DynamoDB.Table.batch_writer
|
||||
"""
|
||||
return BatchWriter(
|
||||
table_name=table_name or self.table_name,
|
||||
client=self.dynamodb_client,
|
||||
table_name=table_name or self._table_name,
|
||||
client=self._client,
|
||||
overwrite_by_pkeys=overwrite_by_pkeys,
|
||||
)
|
||||
|
||||
@@ -1011,15 +1033,15 @@ class DynamoDBCollection:
|
||||
if not key.pairs:
|
||||
return {}
|
||||
|
||||
table_name = self.persistence_layer.table_name
|
||||
items = []
|
||||
sortkeys = key.pairs[1:] if flatten_top else key.pairs
|
||||
transact = TransactItems(table_name)
|
||||
|
||||
# Add a get operation for each key for the transaction
|
||||
for pair in key.pairs:
|
||||
transact.get(key=pair)
|
||||
with self.persistence_layer.transact_items() as transact:
|
||||
# Add a get operation for each key for the transaction
|
||||
for pair in key.pairs:
|
||||
transact.get(key=pair)
|
||||
|
||||
items = self.persistence_layer.transact_get_items(transact)
|
||||
items = transact.get_items()
|
||||
|
||||
if flatten_top:
|
||||
head, *tail = items
|
||||
|
||||
Reference in New Issue
Block a user