add exception to transact

This commit is contained in:
2025-05-21 19:16:06 -03:00
parent 249116cc76
commit 4673a7b21c
3 changed files with 118 additions and 55 deletions

View File

@@ -1,11 +1,16 @@
import inspect import inspect
import logging import logging
import os
from contextlib import AbstractContextManager from contextlib import AbstractContextManager
from enum import Enum from enum import Enum
from typing import Any, Callable, NamedTuple, Self, Sequence from typing import Any, Callable, NamedTuple, Self, Sequence
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class Status(Enum): class Status(Enum):

View File

@@ -1,4 +1,5 @@
import json import json
import logging
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from base64 import urlsafe_b64decode, urlsafe_b64encode from base64 import urlsafe_b64decode, urlsafe_b64encode
@@ -9,9 +10,9 @@ from typing import TYPE_CHECKING, Any, Self, Type, TypedDict
from urllib.parse import quote, unquote from urllib.parse import quote, unquote
from uuid import UUID from uuid import UUID
from aws_lambda_powertools import Logger
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from glom import glom
from .dateutils import now, timestamp from .dateutils import now, timestamp
from .funcs import omit, pick from .funcs import omit, pick
@@ -21,9 +22,12 @@ PK = os.getenv('DYNAMODB_PARTITION_KEY', 'pk')
SK = os.getenv('DYNAMODB_SORT_KEY', 'sk') SK = os.getenv('DYNAMODB_SORT_KEY', 'sk')
LIMIT = int(os.getenv('DYNAMODB_LIMIT', 25)) LIMIT = int(os.getenv('DYNAMODB_LIMIT', 25))
DELIMITER = os.getenv('DYNAMODB_DELIMITER', '#') DELIMITER = os.getenv('DYNAMODB_DELIMITER', '#')
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
logger = Logger(__name__) logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)
serializer = TypeSerializer() serializer = TypeSerializer()
deserializer = TypeDeserializer() deserializer = TypeDeserializer()
@@ -201,9 +205,9 @@ class TransactKey:
sk: tuple[SortKey, ...] = () sk: tuple[SortKey, ...] = ()
def __add__(self, sk: SortKey) -> 'TransactKey': def __add__(self, sk: SortKey) -> 'TransactKey':
if isinstance(sk, SortKey): if not isinstance(sk, SortKey):
return TransactKey(pk=self.pk, sk=self.sk + (sk,))
raise TypeError('Can only add a SortKey to a TransactKey') raise TypeError('Can only add a SortKey to a TransactKey')
return TransactKey(pk=self.pk, sk=self.sk + (sk,))
class Key(ABC, dict): class Key(ABC, dict):
@@ -267,10 +271,33 @@ class KeyPair(Key):
return cls(*pair) return cls(*pair)
class TransactionCanceledReason(TypedDict):
code: str
message: str
operation: dict
class TransactionCanceledException(Exception):
def __init__(self, msg: str = '', reasons: list = []) -> None:
super().__init__(msg)
self.msg = msg
self.reasons: list[TransactionCanceledReason] = reasons
class TransactOperation:
def __init__(
self,
op: dict,
exc_cls: type[Exception] | None = None,
) -> None:
self.op = op
self.exc_cls = exc_cls
class TransactItems: class TransactItems:
def __init__(self, table_name: str) -> None: def __init__(self, table_name: str) -> None:
self.table_name = table_name self.table_name = table_name
self.items: list[dict] = [] self.items: list[TransactOperation] = []
def put( def put(
self, self,
@@ -278,6 +305,7 @@ class TransactItems:
item: dict, item: dict,
table_name: str | None = None, table_name: str | None = None,
cond_expr: str | None = None, cond_expr: str | None = None,
exc_cls: Type[Exception] | None = None,
) -> None: ) -> None:
attrs: dict = {} attrs: dict = {}
@@ -288,13 +316,16 @@ class TransactItems:
table_name = self.table_name table_name = self.table_name
self.items.append( self.items.append(
dict( TransactOperation(
Put=dict( {
'Put': dict(
TableName=table_name, TableName=table_name,
Item=serialize(item), Item=serialize(item),
**attrs, **attrs,
) )
) },
exc_cls,
),
) )
def update( def update(
@@ -322,13 +353,15 @@ class TransactItems:
table_name = self.table_name table_name = self.table_name
self.items.append( self.items.append(
dict( TransactOperation(
Update=dict( {
'Update': dict(
TableName=table_name, TableName=table_name,
Key=serialize(key), Key=serialize(key),
UpdateExpression=update_expr, UpdateExpression=update_expr,
**attrs, **attrs,
) )
}
) )
) )
@@ -348,12 +381,14 @@ class TransactItems:
table_name = self.table_name table_name = self.table_name
self.items.append( self.items.append(
dict( TransactOperation(
Get=dict( {
'Get': dict(
TableName=table_name, TableName=table_name,
Key=serialize(key), Key=serialize(key),
**attrs, **attrs,
) )
}
) )
) )
@@ -381,12 +416,14 @@ class TransactItems:
table_name = self.table_name table_name = self.table_name
self.items.append( self.items.append(
dict( TransactOperation(
Delete=dict( {
'Delete': dict(
TableName=table_name, TableName=table_name,
Key=serialize(key), Key=serialize(key),
**attrs, **attrs,
) )
}
) )
) )
@@ -411,12 +448,14 @@ class TransactItems:
table_name = self.table_name table_name = self.table_name
self.items.append( self.items.append(
dict( TransactOperation(
ConditionCheck=dict( {
'ConditionCheck': dict(
TableName=table_name, TableName=table_name,
Key=serialize(key), Key=serialize(key),
**attrs, **attrs,
) )
}
) )
) )
@@ -589,7 +628,7 @@ class DynamoDBPersistenceLayer:
def transact_get_items(self, transact_items: TransactItems) -> list[dict[str, Any]]: def transact_get_items(self, transact_items: TransactItems) -> list[dict[str, Any]]:
try: try:
response = self.dynamodb_client.transact_get_items( response = self.dynamodb_client.transact_get_items(
TransactItems=transact_items.items # type: ignore TransactItems=[item.op for item in transact_items.items] # type: ignore
) )
except ClientError as err: except ClientError as err:
logger.exception(err) logger.exception(err)
@@ -603,11 +642,30 @@ class DynamoDBPersistenceLayer:
def transact_write_items(self, transact_items: TransactItems) -> bool: def transact_write_items(self, transact_items: TransactItems) -> bool:
try: try:
self.dynamodb_client.transact_write_items( self.dynamodb_client.transact_write_items(
TransactItems=transact_items.items # type: ignore TransactItems=[item.op for item in transact_items.items] # type: ignore
) )
except ClientError as err: except ClientError as err:
logger.exception(err) error_msg = glom(err, 'response.Error.Message', default='')
raise cancellation_reasons = err.response.get('CancellationReasons', [])
reasons = []
for idx, reason in enumerate(cancellation_reasons):
if reason.get('Code') == 'None':
continue
item = transact_items.items[idx]
if item.exc_cls:
raise item.exc_cls(error_msg)
reasons.append(
{
'code': reason.get('Code'),
'message': reason.get('Message'),
'operation': item.op,
}
)
raise TransactionCanceledException(error_msg, reasons)
else: else:
return True return True
@@ -637,7 +695,7 @@ class DynamoDBCollection:
---------- ----------
persistence_layer: DynamoDBPersistenceLayer persistence_layer: DynamoDBPersistenceLayer
The persistence layer instance responsible for direct DynamoDB operations. The persistence layer instance responsible for direct DynamoDB operations.
exception_cls: Type[Exception], optional exc_cls: Type[Exception], optional
The exception class to be raised when a requested item is not found. The exception class to be raised when a requested item is not found.
tz: str, optional tz: str, optional
The timezone identifier used for date/time operations. The timezone identifier used for date/time operations.
@@ -647,11 +705,11 @@ class DynamoDBCollection:
self, self,
persistence_layer: DynamoDBPersistenceLayer, persistence_layer: DynamoDBPersistenceLayer,
/, /,
exception_cls: Type[Exception] = MissingError, exc_cls: Type[Exception] = MissingError,
tz: str = TZ, tz: str = TZ,
) -> None: ) -> None:
self.persistence_layer = persistence_layer self.persistence_layer = persistence_layer
self.exception_cls = exception_cls self.exc_cls = exc_cls
self.tz = tz self.tz = tz
def get_item( def get_item(
@@ -660,9 +718,8 @@ class DynamoDBCollection:
path_spec: str | None = None, path_spec: str | None = None,
/, /,
raise_on_error: bool = True, raise_on_error: bool = True,
exception_cls: Type[Exception] | None = None, exc_cls: Type[Exception] | None = None,
default: Any = None, default: Any = None,
delimiter: str = '#',
) -> Any: ) -> Any:
"""Get an item with the given key. """Get an item with the given key.
@@ -686,12 +743,10 @@ class DynamoDBCollection:
A path specification for nested data extraction. A path specification for nested data extraction.
raise_on_error: bool, optional raise_on_error: bool, optional
If True, raises an exception when the item is not found. If True, raises an exception when the item is not found.
exception_cls: Type[Exception], optional exc_cls: Type[Exception], optional
Exception class to be used if the item is not found. Exception class to be used if the item is not found.
default: Any, optional default: Any, optional
Default value returned if the item is not found. Default value returned if the item is not found.
delimiter: str, optional
Delimiter used in key composition.
Returns Returns
------- -------
@@ -703,7 +758,7 @@ class DynamoDBCollection:
Exception Exception
Raises the provided exception if the item is not found and raise_on_error is True. Raises the provided exception if the item is not found and raise_on_error is True.
""" """
exc_cls = exception_cls or self.exception_cls exc_cls = exc_cls or self.exc_cls
data = self.persistence_layer.get_item(key) data = self.persistence_layer.get_item(key)
if raise_on_error and not data: if raise_on_error and not data:

View File

@@ -3,7 +3,6 @@ from decimal import Decimal
from ipaddress import IPv4Address from ipaddress import IPv4Address
import pytest import pytest
from botocore.exceptions import ClientError
from layercake.dateutils import ttl from layercake.dateutils import ttl
from layercake.dynamodb import ( from layercake.dynamodb import (
@@ -94,6 +93,8 @@ def test_transact_write_items(
dynamodb_seeds, dynamodb_seeds,
dynamodb_persistence_layer: DynamoDBPersistenceLayer, dynamodb_persistence_layer: DynamoDBPersistenceLayer,
): ):
class EmailConflictError(Exception): ...
transact = TransactItems(dynamodb_persistence_layer.table_name) transact = TransactItems(dynamodb_persistence_layer.table_name)
transact.put(item=KeyPair('5OxmMjL-ujoR5IMGegQz', '0')) transact.put(item=KeyPair('5OxmMjL-ujoR5IMGegQz', '0'))
transact.put(item=KeyPair('cpf', '07879819908')) transact.put(item=KeyPair('cpf', '07879819908'))
@@ -107,10 +108,12 @@ def test_transact_write_items(
ComposeKey('sergio@somosbeta.com.br', 'emails'), ComposeKey('sergio@somosbeta.com.br', 'emails'),
), ),
cond_expr='attribute_not_exists(sk)', cond_expr='attribute_not_exists(sk)',
exc_cls=EmailConflictError,
) )
with pytest.raises(ClientError): with pytest.raises(EmailConflictError) as exc:
dynamodb_persistence_layer.transact_write_items(transact) dynamodb_persistence_layer.transact_write_items(transact)
# print(exc.value)
def test_collection_get_item( def test_collection_get_item(
@@ -151,7 +154,7 @@ def test_collection_get_item(
with pytest.raises(NotFoundError): with pytest.raises(NotFoundError):
collect.get_item( collect.get_item(
KeyPair('5OxmMjL-ujoR5IMGegQz', 'notfound'), KeyPair('5OxmMjL-ujoR5IMGegQz', 'notfound'),
exception_cls=NotFoundError, exc_cls=NotFoundError,
) )