from datetime import date, datetime, time, timedelta
from decimal import Decimal
from http import HTTPStatus
from typing import Annotated, TypedDict
from uuid import uuid4

import pytz
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Router
from aws_lambda_powertools.event_handler.exceptions import ServiceError
from aws_lambda_powertools.event_handler.openapi.params import Body
from aws_lambda_powertools.shared.functions import extract_event_from_common_models
from layercake.batch import BatchProcessor
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.extra_types import CpfStr, NameStr
from layercake.strutils import md5_hash
from pydantic import UUID4, BaseModel, EmailStr, Field, FutureDate

from boto3clients import dynamodb_client
from config import DEDUP_WINDOW_OFFSET_DAYS, ENROLLMENT_TABLE, TZ, USER_TABLE
from exceptions import ConflictError
from middlewares.authentication_middleware import User as Authenticated

# Module-level singletons, created once at import time.
logger = Logger(__name__)
router = Router()
# Persistence layer bound to the enrollment table; items in other tables are
# addressed explicitly (e.g. via KeyPair(table_name=USER_TABLE)).
dyn = DynamoDBPersistenceLayer(ENROLLMENT_TABLE, dynamodb_client)
processor = BatchProcessor()


class SubscriptionRequiredError(ServiceError):
    """Service error for an organization that is not subscribed.

    Carries HTTP status 406 (Not Acceptable).
    """

    def __init__(self, msg: str | dict):
        super().__init__(HTTPStatus.NOT_ACCEPTABLE, msg)


class DeduplicationConflictError(ConflictError):
    """Conflict raised when a deduplication lock for the same user/course exists."""
class User(BaseModel):
    """User to be enrolled, as received in the request body."""

    id: str | UUID4
    name: NameStr
    email: EmailStr
    cpf: CpfStr


class Course(BaseModel):
    """Course the user is being enrolled in."""

    id: UUID4
    name: str
    # Number of days of access; added to the enrollment time to compute
    # ``access_expires_at`` and used for the deduplication lock TTL.
    access_period: int
    # Required on input but excluded from ``model_dump()``, so it is not
    # written into the enrollment records.
    unit_price: Decimal = Field(exclude=True)


class DeduplicationWindow(BaseModel):
    """Per-enrollment override of the deduplication offset."""

    # Days subtracted from the course access period when computing the
    # deduplication lock TTL (defaults to DEDUP_WINDOW_OFFSET_DAYS).
    offset_days: int


class Subscription(BaseModel):
    """Subscription metadata of the organization."""

    billing_day: int


class Enrollment(BaseModel):
    """A single enrollment request.

    Without ``scheduled_for`` it is created immediately (``enroll_now``);
    with it, a scheduled record is written instead (``enroll_later``).
    """

    id: UUID4 = Field(default_factory=uuid4)
    user: User
    course: Course
    scheduled_for: FutureDate | None = None
    deduplication_window: DeduplicationWindow | None = None


class Org(BaseModel):
    """Organization on whose behalf enrollments are created."""

    id: str | UUID4
    name: str


@router.post('/')
def enroll(
    org_id: Annotated[str | UUID4, Body(embed=True)],
    enrollments: Annotated[tuple[Enrollment, ...], Body(embed=True)],
):
    """Enroll users in courses on behalf of an organization.

    Enrollments without ``scheduled_for`` are processed by ``enroll_now``;
    the others by ``enroll_later``. Each group runs through the batch
    processor, and a summary of the per-record results is stored (best
    effort) as a submission record, which is also returned.

    Raises:
        SubscriptionRequiredError: the organization is not subscribed.
    """
    now_ = now()
    created_by: Authenticated = router.context['user']
    # Load the org record, its subscription metadata and the subscription
    # marker together; items with ``rename_key`` appear under that key.
    org = dyn.collection.get_items(
        KeyPair(
            pk=str(org_id),
            sk='0',
            table_name=USER_TABLE,
        )
        + KeyPair(
            pk=str(org_id),
            sk='METADATA#SUBSCRIPTION',
            rename_key='subscription',
            table_name=USER_TABLE,
        )
        + KeyPair(
            pk='SUBSCRIPTION',
            sk=f'ORG#{org_id}',
            rename_key='subscribed',
            table_name=USER_TABLE,
        )
    )

    if 'subscribed' not in org:
        raise SubscriptionRequiredError('Organization not subscribed')

    # NOTE(review): a subscribed org without a METADATA#SUBSCRIPTION item
    # raises KeyError below — confirm that item always exists.
    ctx: Context = {
        'org': Org.model_validate(org),
        'created_by': created_by,
        'subscription': Subscription.model_validate(org['subscription']),
    }

    immediate = [e for e in enrollments if not e.scheduled_for]
    later = [e for e in enrollments if e.scheduled_for]

    with processor(immediate, enroll_now, ctx) as batch:
        now_out = batch.process()

    with processor(later, enroll_later, ctx) as batch:
        later_out = batch.process()

    def fmt(r):
        """Summarize one batch result for the submission record."""
        return {
            'status': r.status.value,
            'input_record': extract_event_from_common_models(r.input_record),
            'output': extract_event_from_common_models(r.output),
            'cause': r.cause,
        }

    item = {
        'id': f'SUBMISSION#ORG#{org_id}',
        'sk': now_,
        'enrolled': list(map(fmt, now_out)) if now_out else None,
        'scheduled': list(map(fmt, later_out)) if later_out else None,
        'ttl': ttl(start_dt=now_, days=30 * 3),
        'created_by': {
            'id': created_by.id,
            'name': created_by.name,
        },
    }

    try:
        dyn.put_item(item=item)
    except Exception as exc:
        # Best effort: the enrollments were already processed, so failing to
        # store the summary is logged instead of failing the request.
        logger.exception(exc)

    # Return outside of ``finally``: a ``return`` there would silently
    # discard any exception still propagating.
    return item


Context = TypedDict(
    'Context',
    {
        'org': Org,
        'subscription': Subscription,
        'created_by': Authenticated,
    },
)


def enroll_now(enrollment: Enrollment, context: Context) -> Enrollment:
    """Create an enrollment immediately in a single DynamoDB transaction.

    Writes the enrollment record plus its ORG, CANCEL_POLICY,
    METADATA#SUBSCRIPTION_COVERED, CREATED_BY and LOCK items, a global
    ``LOCK`` item keyed by the user/course hash and, when an explicit
    deduplication window was given, METADATA#DEDUPLICATION_WINDOW.

    Raises:
        DeduplicationConflictError: passed as ``exc_cls`` for the conditional
            put of the global lock, i.e. when a lock for the same user and
            course already exists.

    Returns the given enrollment unchanged.
    """
    now_ = now()
    user = enrollment.user
    course = enrollment.course
    org: Org = context['org']
    subscription: Subscription = context['subscription']
    created_by: Authenticated = context['created_by']
    lock_hash = md5_hash(f'{user.id}{course.id}')
    access_expires_at = now_ + timedelta(days=course.access_period)
    deduplication_window = enrollment.deduplication_window
    offset_days = (
        int(deduplication_window.offset_days)
        if deduplication_window
        else DEDUP_WINDOW_OFFSET_DAYS
    )
    # NOTE(review): when offset_days >= course.access_period this TTL falls at
    # or before now_, so the lock may expire right away — confirm intended.
    dedup_lock_ttl = ttl(
        start_dt=now_,
        days=course.access_period - offset_days,
    )

    with dyn.transact_writer() as transact:
        transact.put(
            item={
                'id': enrollment.id,
                'sk': '0',
                'score': None,
                'progress': 0,
                'status': 'PENDING',
                'user': user.model_dump(),
                'course': course.model_dump(),
                'access_expires_at': access_expires_at,
                'subscription_covered': True,
                'org_id': org.id,
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'ORG',
                'name': org.name,
                'org_id': org.id,
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'CANCEL_POLICY',
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'METADATA#SUBSCRIPTION_COVERED',
                'org_id': org.id,
                'billing_day': subscription.billing_day,
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'CREATED_BY',
                'name': created_by.name,
                'user_id': created_by.id,
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'LOCK',
                'hash': lock_hash,
                'created_at': now_,
                'ttl': dedup_lock_ttl,
            },
        )
        # Global lock: the conditional put fails if the same user/course hash
        # is already locked, aborting the whole transaction.
        transact.put(
            item={
                'id': 'LOCK',
                'sk': lock_hash,
                'enrollment_id': enrollment.id,
                'created_at': now_,
                'ttl': dedup_lock_ttl,
            },
            cond_expr='attribute_not_exists(sk)',
            exc_cls=DeduplicationConflictError,
        )
        # The deduplication window can be recalculated based on user settings.
        if deduplication_window:
            transact.put(
                item={
                    'id': enrollment.id,
                    'sk': 'METADATA#DEDUPLICATION_WINDOW',
                    'offset_days': offset_days,
                    'created_at': now_,
                },
            )

    return enrollment


def enroll_later(enrollment: Enrollment, context: Context) -> Enrollment:
    """Record an enrollment scheduled for a future date.

    Writes a ``SCHEDULED#ORG#<org id>`` item whose sort key is the scheduled
    midnight (in ``TZ``) plus the user/course hash, together with a
    ``LOCK#SCHEDULED`` item pointing back to it.

    Raises:
        DeduplicationConflictError: passed as ``exc_cls`` for the conditional
            put of the scheduled lock, i.e. when the same user and course
            are already scheduled.

    Returns the given enrollment unchanged.
    """
    now_ = now()
    user = enrollment.user
    course = enrollment.course
    # Only called for enrollments with ``scheduled_for`` set (see ``enroll``).
    scheduled_for = date_to_midnight(enrollment.scheduled_for)  # type: ignore
    deduplication_window = enrollment.deduplication_window
    org: Org = context['org']
    subscription: Subscription = context['subscription']
    created_by: Authenticated = context['created_by']
    lock_hash = md5_hash(f'{user.id}{course.id}')

    with dyn.transact_writer() as transact:
        pk = f'SCHEDULED#ORG#{org.id}'
        sk = f'{scheduled_for.isoformat()}#{lock_hash}'

        transact.put(
            item={
                'id': pk,
                'sk': sk,
                'user': user.model_dump(),
                'course': course.model_dump(),
                'org_name': org.name,
                'created_by': {
                    'id': created_by.id,
                    'name': created_by.name,
                },
                'subscription_billing_day': subscription.billing_day,
                'ttl': ttl(start_dt=scheduled_for),
                'created_at': now_,
            }
            | (
                {'dedup_window_offset_days': deduplication_window.offset_days}
                if deduplication_window
                else {}
            ),
        )
        transact.put(
            item={
                'id': 'LOCK#SCHEDULED',
                'sk': lock_hash,
                'scheduled': {
                    'id': pk,
                    'sk': sk,
                },
                'created_at': now_,
            },
            cond_expr='attribute_not_exists(sk)',
            exc_cls=DeduplicationConflictError,
        )

    return enrollment


def date_to_midnight(dt: date) -> datetime:
    """Return midnight at the start of *dt* as an aware datetime in ``TZ``.

    pytz zones must be attached with ``localize()``. Passing them through
    ``replace(tzinfo=...)`` uses the zone's first historical offset (usually
    Local Mean Time), which is minutes off the real UTC offset.
    """
    return pytz.timezone(TZ).localize(datetime.combine(dt, time(0, 0)))