"""Forward DynamoDB stream records to Amazon EventBridge.

Every record in a DynamoDB Streams batch is published as one EventBridge
event. The detail type is the stream record's ``eventName``, or ``EXPIRE``
for a REMOVE whose stored ``ttl`` has already elapsed.
"""

import decimal
import json
import math
from typing import TYPE_CHECKING

import boto3
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import (
    BatchProcessor,
    EventType,
    process_partial_response,
)
from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import (
    DynamoDBRecord,
    DynamoDBRecordEventName,
)
from aws_lambda_powertools.utilities.typing import LambdaContext
from layercake.dateutils import now, ttl

from utils import diff, table_from_arn

if TYPE_CHECKING:
    from mypy_boto3_events.client import EventBridgeClient
else:
    EventBridgeClient = object

client: EventBridgeClient = boto3.client('events')
processor = BatchProcessor(event_type=EventType.DynamoDBStreams)
tracer = Tracer()
logger = Logger()


class EventPublishError(Exception):
    """Raised when EventBridge reports a ``put_events`` entry as failed."""


@tracer.capture_method
def record_handler(record: DynamoDBRecord):
    """Publish one DynamoDB stream record as an EventBridge event.

    The event detail carries the record's keys, its new and old images, and
    the attribute diff produced by ``utils.diff``.

    Args:
        record: A single DynamoDB Streams record from the batch.

    Raises:
        EventPublishError: If EventBridge reports the entry as failed. Raising
            lets the batch processor mark this record as a partial-batch
            failure instead of losing the event.
    """
    detail_type = record.raw_event['eventName']
    table_name: str = table_from_arn(record.event_source_arn)  # type: ignore
    new_image: dict = record.dynamodb.new_image  # type: ignore
    old_image: dict = record.dynamodb.old_image  # type: ignore
    # The old image can be missing (e.g. on INSERT). Depending on the
    # Powertools version the property may then be None, so guard the lookup.
    record_ttl: int | None = (old_image or {}).get('ttl')
    changes = diff(new_image, old_image)
    now_ = now()

    # Should be EXPIRE if event is REMOVE and TTL has elapsed.
    # NOTE(review): ttl() is assumed to return the current time in the same
    # epoch units as the stored 'ttl' attribute -- confirm in layercake.
    # DynamoDB also tags TTL deletions with a service userIdentity in the
    # stream record. Using it would separate a manual delete of an
    # already-expired item from a real TTL sweep.
    if (
        record.event_name is DynamoDBRecordEventName.REMOVE
        and record_ttl
        and ttl() >= record_ttl
    ):
        detail_type = 'EXPIRE'

    detail = {
        'keys': record.dynamodb.keys,  # type: ignore
        'new_image': new_image,
        'old_image': old_image,
        'changes': changes,
        # Post-migration: remove the following line
        'modified': changes,
    }

    result = client.put_events(
        Entries=[
            {
                'Source': record.event_source,  # type: ignore
                'DetailType': detail_type,
                'Resources': [table_name],
                'Detail': json.dumps(detail, cls=JSONEncoder),
                'Time': now_,
            }
        ]
    )

    logger.info('Event result', result=result)

    # put_events does not raise on per-entry failures; it only counts them in
    # FailedEntryCount. Raise so this record is reported as failed (and
    # retried) instead of being silently dropped.
    if result.get('FailedEntryCount'):
        raise EventPublishError(
            f'Failed to publish {detail_type} event for {table_name}: '
            f'{result.get("Entries")}'
        )


@tracer.capture_lambda_handler
def lambda_handler(event: dict, context: LambdaContext):
    """Lambda entry point: run ``record_handler`` over every stream record.

    Returns the partial-batch response built by ``process_partial_response``,
    so only the records whose handler raised are reported as failures.
    """
    return process_partial_response(
        event=event,
        record_handler=record_handler,
        processor=processor,
        context=context,
    )


class JSONEncoder(json.JSONEncoder):
    """JSON encoder for the Decimal and set values found in deserialized items.

    - ``Decimal`` NaN becomes ``math.nan``; infinities become ``float``.
    - Integral ``Decimal`` values become ``int``.
    - Other ``Decimal`` values are rounded to two decimal places (using the
      current decimal context's rounding mode) and emitted as ``float``.
    - ``set`` values become lists.
    """

    def default(self, o):
        if isinstance(o, decimal.Decimal):
            if o.is_nan():
                return math.nan
            if o.is_infinite():
                return float(o)
            # Compare with the integral value rather than using ``o % 1``.
            # The remainder signals InvalidOperation once the integer part
            # exceeds the context precision (28 digits by default).
            if o != o.to_integral_value():
                # Widen the precision so quantize() can hold every integer
                # digit plus the two decimal places without InvalidOperation.
                with decimal.localcontext() as ctx:
                    ctx.prec = max(ctx.prec, o.adjusted() + 3)
                    return float(o.quantize(decimal.Decimal('0.00')))
            return int(o)
        if isinstance(o, set):
            return list(o)
        return super().default(o)