add exception to transact

This commit is contained in:
2025-05-21 19:16:06 -03:00
parent 249116cc76
commit 4673a7b21c
3 changed files with 118 additions and 55 deletions

View File

@@ -1,11 +1,16 @@
import inspect
import logging
import os
from contextlib import AbstractContextManager
from enum import Enum
from typing import Any, Callable, NamedTuple, Self, Sequence
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class Status(Enum):

View File

@@ -1,4 +1,5 @@
import json
import logging
import os
from abc import ABC, abstractmethod
from base64 import urlsafe_b64decode, urlsafe_b64encode
@@ -9,9 +10,9 @@ from typing import TYPE_CHECKING, Any, Self, Type, TypedDict
from urllib.parse import quote, unquote
from uuid import UUID
from aws_lambda_powertools import Logger
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from botocore.exceptions import ClientError
from glom import glom
from .dateutils import now, timestamp
from .funcs import omit, pick
@@ -21,9 +22,12 @@ PK = os.getenv('DYNAMODB_PARTITION_KEY', 'pk')
SK = os.getenv('DYNAMODB_SORT_KEY', 'sk')
LIMIT = int(os.getenv('DYNAMODB_LIMIT', 25))
DELIMITER = os.getenv('DYNAMODB_DELIMITER', '#')
LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO').upper()
logger = Logger(__name__)
logging.basicConfig(level=getattr(logging, LOG_LEVEL))
logger = logging.getLogger(__name__)
serializer = TypeSerializer()
deserializer = TypeDeserializer()
@@ -201,9 +205,9 @@ class TransactKey:
sk: tuple[SortKey, ...] = ()
def __add__(self, sk: SortKey) -> 'TransactKey':
if isinstance(sk, SortKey):
return TransactKey(pk=self.pk, sk=self.sk + (sk,))
raise TypeError('Can only add a SortKey to a TransactKey')
if not isinstance(sk, SortKey):
raise TypeError('Can only add a SortKey to a TransactKey')
return TransactKey(pk=self.pk, sk=self.sk + (sk,))
class Key(ABC, dict):
@@ -267,10 +271,33 @@ class KeyPair(Key):
return cls(*pair)
class TransactionCanceledReason(TypedDict):
code: str
message: str
operation: dict
class TransactionCanceledException(Exception):
def __init__(self, msg: str = '', reasons: list = []) -> None:
super().__init__(msg)
self.msg = msg
self.reasons: list[TransactionCanceledReason] = reasons
class TransactOperation:
def __init__(
self,
op: dict,
exc_cls: type[Exception] | None = None,
) -> None:
self.op = op
self.exc_cls = exc_cls
class TransactItems:
def __init__(self, table_name: str) -> None:
self.table_name = table_name
self.items: list[dict] = []
self.items: list[TransactOperation] = []
def put(
self,
@@ -278,6 +305,7 @@ class TransactItems:
item: dict,
table_name: str | None = None,
cond_expr: str | None = None,
exc_cls: Type[Exception] | None = None,
) -> None:
attrs: dict = {}
@@ -288,13 +316,16 @@ class TransactItems:
table_name = self.table_name
self.items.append(
dict(
Put=dict(
TableName=table_name,
Item=serialize(item),
**attrs,
)
)
TransactOperation(
{
'Put': dict(
TableName=table_name,
Item=serialize(item),
**attrs,
)
},
exc_cls,
),
)
def update(
@@ -322,13 +353,15 @@ class TransactItems:
table_name = self.table_name
self.items.append(
dict(
Update=dict(
TableName=table_name,
Key=serialize(key),
UpdateExpression=update_expr,
**attrs,
)
TransactOperation(
{
'Update': dict(
TableName=table_name,
Key=serialize(key),
UpdateExpression=update_expr,
**attrs,
)
}
)
)
@@ -348,12 +381,14 @@ class TransactItems:
table_name = self.table_name
self.items.append(
dict(
Get=dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
TransactOperation(
{
'Get': dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
}
)
)
@@ -381,12 +416,14 @@ class TransactItems:
table_name = self.table_name
self.items.append(
dict(
Delete=dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
TransactOperation(
{
'Delete': dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
}
)
)
@@ -411,12 +448,14 @@ class TransactItems:
table_name = self.table_name
self.items.append(
dict(
ConditionCheck=dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
TransactOperation(
{
'ConditionCheck': dict(
TableName=table_name,
Key=serialize(key),
**attrs,
)
}
)
)
@@ -589,7 +628,7 @@ class DynamoDBPersistenceLayer:
def transact_get_items(self, transact_items: TransactItems) -> list[dict[str, Any]]:
try:
response = self.dynamodb_client.transact_get_items(
TransactItems=transact_items.items # type: ignore
TransactItems=[item.op for item in transact_items.items] # type: ignore
)
except ClientError as err:
logger.exception(err)
@@ -603,11 +642,30 @@ class DynamoDBPersistenceLayer:
def transact_write_items(self, transact_items: TransactItems) -> bool:
try:
self.dynamodb_client.transact_write_items(
TransactItems=transact_items.items # type: ignore
TransactItems=[item.op for item in transact_items.items] # type: ignore
)
except ClientError as err:
logger.exception(err)
raise
error_msg = glom(err, 'response.Error.Message', default='')
cancellation_reasons = err.response.get('CancellationReasons', [])
reasons = []
for idx, reason in enumerate(cancellation_reasons):
if reason.get('Code') == 'None':
continue
item = transact_items.items[idx]
if item.exc_cls:
raise item.exc_cls(error_msg)
reasons.append(
{
'code': reason.get('Code'),
'message': reason.get('Message'),
'operation': item.op,
}
)
raise TransactionCanceledException(error_msg, reasons)
else:
return True
@@ -637,7 +695,7 @@ class DynamoDBCollection:
----------
persistence_layer: DynamoDBPersistenceLayer
The persistence layer instance responsible for direct DynamoDB operations.
exception_cls: Type[Exception], optional
exc_cls: Type[Exception], optional
The exception class to be raised when a requested item is not found.
tz: str, optional
The timezone identifier used for date/time operations.
@@ -647,11 +705,11 @@ class DynamoDBCollection:
self,
persistence_layer: DynamoDBPersistenceLayer,
/,
exception_cls: Type[Exception] = MissingError,
exc_cls: Type[Exception] = MissingError,
tz: str = TZ,
) -> None:
self.persistence_layer = persistence_layer
self.exception_cls = exception_cls
self.exc_cls = exc_cls
self.tz = tz
def get_item(
@@ -660,9 +718,8 @@ class DynamoDBCollection:
path_spec: str | None = None,
/,
raise_on_error: bool = True,
exception_cls: Type[Exception] | None = None,
exc_cls: Type[Exception] | None = None,
default: Any = None,
delimiter: str = '#',
) -> Any:
"""Get an item with the given key.
@@ -686,12 +743,10 @@ class DynamoDBCollection:
A path specification for nested data extraction.
raise_on_error: bool, optional
If True, raises an exception when the item is not found.
exception_cls: Type[Exception], optional
exc_cls: Type[Exception], optional
Exception class to be used if the item is not found.
default: Any, optional
Default value returned if the item is not found.
delimiter: str, optional
Delimiter used in key composition.
Returns
-------
@@ -703,7 +758,7 @@ class DynamoDBCollection:
Exception
Raises the provided exception if the item is not found and raise_on_error is True.
"""
exc_cls = exception_cls or self.exception_cls
exc_cls = exc_cls or self.exc_cls
data = self.persistence_layer.get_item(key)
if raise_on_error and not data:

View File

@@ -3,7 +3,6 @@ from decimal import Decimal
from ipaddress import IPv4Address
import pytest
from botocore.exceptions import ClientError
from layercake.dateutils import ttl
from layercake.dynamodb import (
@@ -94,6 +93,8 @@ def test_transact_write_items(
dynamodb_seeds,
dynamodb_persistence_layer: DynamoDBPersistenceLayer,
):
class EmailConflictError(Exception): ...
transact = TransactItems(dynamodb_persistence_layer.table_name)
transact.put(item=KeyPair('5OxmMjL-ujoR5IMGegQz', '0'))
transact.put(item=KeyPair('cpf', '07879819908'))
@@ -107,10 +108,12 @@ def test_transact_write_items(
ComposeKey('sergio@somosbeta.com.br', 'emails'),
),
cond_expr='attribute_not_exists(sk)',
exc_cls=EmailConflictError,
)
with pytest.raises(ClientError):
with pytest.raises(EmailConflictError) as exc:
dynamodb_persistence_layer.transact_write_items(transact)
# print(exc.value)
def test_collection_get_item(
@@ -151,7 +154,7 @@ def test_collection_get_item(
with pytest.raises(NotFoundError):
collect.get_item(
KeyPair('5OxmMjL-ujoR5IMGegQz', 'notfound'),
exception_cls=NotFoundError,
exc_cls=NotFoundError,
)