import json import psycopg from aws_lambda_powertools import Logger, Tracer from aws_lambda_powertools.shared.json_encoder import Encoder from aws_lambda_powertools.utilities.data_classes import ( DynamoDBStreamEvent, event_source, ) from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import ( DynamoDBRecordEventName, ) from aws_lambda_powertools.utilities.typing import LambdaContext from layercake.strutils import md5_hash from psycopg import sql from config import ( POSTGRES_DB, POSTGRES_HOST, POSTGRES_PASSWORD, POSTGRES_PORT, POSTGRES_USER, ) from utils import table_from_arn class JSONEncoder(Encoder): def default(self, obj): if isinstance(obj, set): return list(obj) return super().default(obj) logger = Logger(__name__) tracer = Tracer() UPSERT_QUERY = sql.SQL(""" INSERT INTO {0} (_id, json) VALUES (%s, %s) ON CONFLICT (_id) DO UPDATE SET json = EXCLUDED.json """) DELETE_QUERY = sql.SQL(""" DELETE FROM {0} WHERE _id = ANY(%s); """) @event_source(data_class=DynamoDBStreamEvent) @logger.inject_lambda_context @tracer.capture_lambda_handler def lambda_handler(event: DynamoDBStreamEvent, context: LambdaContext): with psycopg.connect( dbname=POSTGRES_DB, user=POSTGRES_USER, password=POSTGRES_PASSWORD, host=POSTGRES_HOST, port=POSTGRES_PORT, ) as conn: with conn.cursor() as cur: upsert_batches = {} delete_batches = {} for record in event.records: new_image = record.dynamodb.new_image # type: ignore table_name = table_from_arn(record.event_source_arn) # type: ignore keys = record.dynamodb.keys # type: ignore _id = md5_hash(str(keys)) match record.event_name: case ( DynamoDBRecordEventName.INSERT | DynamoDBRecordEventName.MODIFY ): if table_name not in upsert_batches: upsert_batches[table_name] = [] upsert_batches[table_name].append( (_id, json.dumps(new_image, cls=JSONEncoder)), ) case DynamoDBRecordEventName.REMOVE: if table_name not in delete_batches: delete_batches[table_name] = [] delete_batches[table_name].append(_id) for table_name, rows in upsert_batches.items(): query = UPSERT_QUERY.format(sql.Identifier(table_name)).as_string(conn) cur.executemany(query, rows) for table_name, ids in delete_batches.items(): query = DELETE_QUERY.format(sql.Identifier(table_name)).as_string(conn) cur.execute(query, [ids]) conn.commit()