Files
saladeaula.digital/orders-events/app/events/start_fulfillment.py

394 lines
11 KiB
Python

from datetime import date, datetime, time, timedelta
from typing import Annotated, Sequence, TypedDict
from uuid import uuid4
import pytz
from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.data_classes import (
EventBridgeEvent,
event_source,
)
from aws_lambda_powertools.utilities.typing import LambdaContext
from layercake.batch import BatchProcessor, Status
from layercake.dateutils import now, ttl
from layercake.dynamodb import (
DynamoDBPersistenceLayer,
KeyChain,
KeyPair,
SortKey,
TransactKey,
)
from layercake.strutils import md5_hash
from pydantic import UUID4, BaseModel, BeforeValidator, Field
from boto3clients import dynamodb_client
from config import (
COURSE_TABLE,
DEDUP_WINDOW_OFFSET_DAYS,
ENROLLMENT_TABLE,
ORDER_TABLE,
TZ,
USER_TABLE,
)
# Powertools structured logger; ``service`` is Logger's first parameter.
logger = Logger(service=__name__)
# Persistence layer bound to the order table by default; writes to the
# enrollment/course/user tables pass ``table_name=`` explicitly.
dyn = DynamoDBPersistenceLayer(ORDER_TABLE, dynamodb_client)
# Runs a handler per enrollment and reports each record's success/failure.
processor = BatchProcessor()
class DeduplicationConflictError(Exception):
    """Used as ``exc_cls`` for the conditional puts of the user/course LOCK items."""


class EnrollmentConflictError(Exception):
    """Used as ``exc_cls`` for the PENDING -> EXECUTED update of an order's enrollment."""


class DeadlineExceededError(Exception):
    """Raised when a scheduled enrollment's start (midnight) is already in the past."""
class User(BaseModel):
    """User attached to an order's ENROLLMENT# item.

    Written back with ``model_dump()`` into the enrollment, scheduled-enrollment
    and order items created in this module.
    """

    # Declaration order is the key order produced by model_dump().
    id: str
    name: str
    email: str
    cpf: str
class Course(BaseModel):
    """Course attached to an enrollment or expanded from an order's line items."""

    id: str
    name: str
    # Length of access in days; used as ``timedelta(days=access_period)`` when
    # enrolling and as the base of the dedup lock TTL.
    access_period: int
class Enrollment(BaseModel):
    """An order's ``ENROLLMENT#<uuid>`` item as returned by the order-table query."""

    # Populated from the item's ``sk``; the 'ENROLLMENT#' prefix is stripped
    # before validating the remainder as a UUID4.
    id: Annotated[
        UUID4,
        BeforeValidator(lambda s: s.removeprefix('ENROLLMENT#')),
    ] = Field(alias='sk')
    user: User
    course: Course
    # None -> enroll immediately; a date -> schedule for that day's midnight.
    scheduled_for: date | None = None
class Org(BaseModel):
    """Organization that placed the order (id from the order, name from its record)."""

    id: str | UUID4
    name: str
@event_source(data_class=EventBridgeEvent)
@logger.inject_lambda_context
def lambda_handler(event: EventBridgeEvent, context: LambdaContext) -> bool:
    """Start fulfilment of the order carried in ``event.detail['new_image']``.

    The new image must provide ``id``, ``org_id`` and ``sk``. The handler loads
    the order's ``ITEMS`` and ``CREATED_BY`` items plus the org record
    (``sk='0'`` in USER_TABLE), then queries the order's ``ENROLLMENT#`` items:

    * none: every purchased unit is released as a seat for the org;
    * some: enrollments without ``scheduled_for`` are created now and the rest
      are scheduled. For every enrollment that fails, a seat is released and
      its order item moves from PENDING to ROLLBACK with a short reason code.

    Finally the order item itself is marked COMPLETED.

    Returns:
        Whatever ``dyn.update_item`` returns for the COMPLETED update.
    """
    new_image = event.detail['new_image']
    now_ = now()
    order_id = new_image['id']
    org_id = new_image['org_id']

    order = dyn.collection.get_items(
        TransactKey(order_id)
        + SortKey('ITEMS', rename_key='items')
        + SortKey('CREATED_BY', rename_key='created_by')
        + KeyPair(
            pk=org_id,
            sk='0',
            rename_key='org',
            table_name=USER_TABLE,
        ),
    )
    r = dyn.collection.query(KeyPair(order_id, 'ENROLLMENT#'))
    enrollments = [Enrollment(**x) for x in r['items']]

    if not enrollments:
        courses = _items_to_courses(order['items'])
        _release_seats(courses, order_id=order_id, org_id=org_id)
    else:
        ctx: Context = {
            'order_id': order_id,
            'org': Org(id=org_id, name=order['org']['name']),
            'created_by': order['created_by'],
        }
        immediate = [e for e in enrollments if not e.scheduled_for]
        later = [e for e in enrollments if e.scheduled_for]

        with processor(immediate, _enroll_now, ctx) as batch:
            now_out = batch.process()
        with processor(later, _enroll_later, ctx) as batch:
            later_out = batch.process()

        failed = [x for x in now_out + later_out if x.status == Status.FAIL]
        # Only touch DynamoDB (and log) when something actually failed:
        # TransactWriteItems rejects a transaction with zero actions.
        if failed:
            # Give the seat of each failed enrollment back to the org.
            _release_seats(
                courses=[x.input_record.course for x in failed],
                order_id=order_id,
                org_id=org_id,
            )
            logger.debug('Some enrollments failed', failed=failed)

            with dyn.transact_writer() as transact:
                for x in failed:
                    reason = _friendly_reason(x.cause['type'])  # type: ignore
                    transact.update(
                        key=KeyPair(order_id, f'ENROLLMENT#{x.input_record.id}'),
                        update_expr='SET #status = :rollback, '
                        'rollback_at = :now, reason = :reason',
                        cond_expr='attribute_exists(sk) AND #status = :pending',
                        expr_attr_names={
                            '#status': 'status',
                        },
                        expr_attr_values={
                            ':pending': 'PENDING',
                            ':rollback': 'ROLLBACK',
                            ':reason': reason,
                            ':now': now_,
                        },
                    )

    return dyn.update_item(
        key=KeyPair(order_id, new_image['sk']),
        update_expr='SET #status = :completed, completed_at = :now',
        expr_attr_names={
            '#status': 'status',
        },
        expr_attr_values={
            ':completed': 'COMPLETED',
            ':now': now_,
        },
    )
class Item(TypedDict):
    """One order line item: a course id and how many units were bought."""

    id: str
    quantity: int
def _release_seats(
    courses: Sequence[Course],
    *,
    order_id: str,
    org_id: str,
) -> None:
    """Write one ``SEAT#ORG#<org_id>`` item per entry of ``courses``.

    A course repeated N times yields N seat items. Each sort key embeds a
    fresh ``uuid4``, so repeated calls write additional seats rather than
    overwriting. When ``courses`` is empty nothing is written, which avoids
    sending an empty transaction (TransactWriteItems requires at least one
    action).

    Args:
        courses: One entry per seat to release.
        order_id: Order the seats originate from (embedded in the sort key).
        org_id: Organization receiving the seats.
    """
    if not courses:
        return

    created_at = now()
    # NOTE(review): DynamoDB caps a transaction at 100 actions. Unless the
    # layercake writer splits larger batches, releasing more seats than that
    # in one call will fail here -- confirm.
    with dyn.transact_writer(table_name=ENROLLMENT_TABLE) as transact:
        for course in courses:
            transact.put(
                item={
                    'id': f'SEAT#ORG#{org_id}',
                    'sk': f'ORDER#{order_id}#ENROLLMENT#{uuid4()}',
                    'course': course.model_dump(),
                    'created_at': created_at,
                },
            )
def _items_to_courses(items: list[Item]) -> tuple[Course, ...]:
    """Expand order line items into one ``Course`` per purchased unit.

    Quantities of line items that share a course id are added together, so a
    course listed more than once still yields every unit bought. Ids for which
    ``_get_courses`` returns no course produce no entries.

    Args:
        items: The order's line items.

    Returns:
        A tuple with each course repeated ``quantity`` times, or ``()`` when
        ``items`` is empty.
    """
    quantities: dict[str, int] = {}
    for item in items:
        course_id = item['id']
        # int(): the stored quantity is presumably a DynamoDB number (Decimal).
        quantities[course_id] = quantities.get(course_id, 0) + int(item['quantity'])

    if not quantities:
        return ()

    courses = _get_courses(set(quantities))
    return tuple(
        course for course in courses for _ in range(quantities.get(course.id, 0))
    )
def _get_courses(ids: set[str]) -> tuple[Course, ...]:
    """Fetch the ``sk='0'`` course records for ``ids`` from COURSE_TABLE.

    Returns ``()`` without calling DynamoDB when ``ids`` is empty. Requested
    ids that come back without a record are logged as a warning, because the
    caller would otherwise release zero seats for them silently.
    """
    if not ids:
        return ()

    keys = tuple(KeyPair(pk=idx, sk='0', table_name=COURSE_TABLE) for idx in ids)
    r = dyn.collection.get_items(KeyChain(keys), flatten_top=False)
    courses = tuple(Course(id=idx, **attrs) for idx, attrs in r.items())

    missing = set(ids) - {course.id for course in courses}
    if missing:
        logger.warning('Courses not found', missing=sorted(missing))
    return courses
def _friendly_reason(reason: str) -> str:
reasons = {
'DeduplicationConflictError': 'DEDUPLICATION',
'DeadlineExceededError': 'DEADLINE',
}
return reasons.get(reason, 'CONFLICT')
class CreatedBy(TypedDict):
    """The order's ``CREATED_BY`` item: who placed the order."""

    user_id: str
    name: str
class Context(TypedDict):
    """Per-order data handed by the batch processor to each enrollment handler."""

    order_id: str
    org: Org
    created_by: CreatedBy
def _enroll_now(enrollment: Enrollment, context: Context) -> None:
    """Create ``enrollment`` immediately in a single DynamoDB transaction.

    In ENROLLMENT_TABLE it writes the enrollment (``sk='0'``, status PENDING)
    and its ORG, CANCEL_POLICY, parent-order link, CREATED_BY and LOCK items,
    plus a global ``LOCK`` item for the user/course hash. In ORDER_TABLE it
    moves the order's ``ENROLLMENT#<id>`` item from PENDING to EXECUTED.

    Raises:
        EnrollmentConflictError: passed as ``exc_cls`` for the order update,
            presumably raised when that item is missing or not PENDING.
        DeduplicationConflictError: passed as ``exc_cls`` for the global lock
            put, presumably raised when the user/course lock already exists.
    """
    now_ = now()
    user = enrollment.user
    course = enrollment.course
    order_id = context['order_id']
    org = context['org']
    created_by = context['created_by']

    access_expires_at = now_ + timedelta(days=course.access_period)
    lock_hash = md5_hash(f'{user.id}{course.id}')
    # The dedup lock TTL is DEDUP_WINDOW_OFFSET_DAYS shorter than the access period.
    dedup_lock_ttl = ttl(
        start_dt=now_,
        days=course.access_period - DEDUP_WINDOW_OFFSET_DAYS,
    )

    with dyn.transact_writer(table_name=ENROLLMENT_TABLE) as transact:
        transact.put(
            item={
                'id': enrollment.id,
                'sk': '0',
                'score': None,
                'progress': 0,
                'status': 'PENDING',
                'user': user.model_dump(),
                'course': course.model_dump(),
                'access_expires_at': access_expires_at,
                'org_id': org.id,
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'ORG',
                'name': org.name,
                'org_id': org.id,
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'CANCEL_POLICY',
                'created_at': now_,
                'seat': {'order_id': order_id},
            }
        )
        # Enrollment should know where it comes from
        transact.put(
            item={
                'id': enrollment.id,
                'sk': f'LINKED_ENTITY#PARENT#ORDER#{order_id}',
                'created_at': now_,
            },
            cond_expr='attribute_not_exists(sk)',
        )
        transact.update(
            key=KeyPair(order_id, f'ENROLLMENT#{enrollment.id}'),
            update_expr='SET #status = :executed, executed_at = :now',
            cond_expr='attribute_exists(sk) AND #status = :pending',
            expr_attr_names={
                '#status': 'status',
            },
            expr_attr_values={
                ':pending': 'PENDING',
                ':executed': 'EXECUTED',
                ':now': now_,
            },
            table_name=ORDER_TABLE,
            exc_cls=EnrollmentConflictError,
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'CREATED_BY',
                'name': created_by['name'],
                'user_id': created_by['user_id'],
                'created_at': now_,
            }
        )
        transact.put(
            item={
                'id': enrollment.id,
                'sk': 'LOCK',
                'hash': lock_hash,
                'created_at': now_,
                'ttl': dedup_lock_ttl,
            },
        )
        # Global lock: a second enrollment of the same user in the same
        # course fails this condition while the lock item exists.
        transact.put(
            item={
                'id': 'LOCK',
                'sk': lock_hash,
                'enrollment_id': enrollment.id,
                'created_at': now_,
                'ttl': dedup_lock_ttl,
            },
            cond_expr='attribute_not_exists(sk)',
            exc_cls=DeduplicationConflictError,
        )
def _enroll_later(enrollment: Enrollment, context: Context) -> None:
    """Schedule ``enrollment`` to start at midnight (``TZ``) of ``scheduled_for``.

    One transaction writes:

    * to ENROLLMENT_TABLE, a ``SCHEDULED#ORG#<org_id>`` item keyed by
      ``<midnight ISO>#<lock hash>`` with everything needed to enroll later,
      its ``ttl`` derived from that midnight;
    * to ORDER_TABLE, a replacement of the order's ``ENROLLMENT#<id>`` item
      with status SCHEDULED;
    * to ENROLLMENT_TABLE, a ``LOCK#SCHEDULED`` item for the user/course hash,
      conditioned on not existing yet.

    Raises:
        DeadlineExceededError: the scheduled midnight is already in the past.
        DeduplicationConflictError: passed as ``exc_cls`` for the
            ``LOCK#SCHEDULED`` put, presumably raised when the same user/course
            pair is already scheduled.
    """
    current = now()
    member = enrollment.user
    offering = enrollment.course
    org = context['org']
    creator = context['created_by']
    order_id = context['order_id']
    # lambda_handler only routes enrollments that have a scheduled date here.
    run_at = _date_to_midnight(enrollment.scheduled_for)  # type: ignore
    lock_hash = md5_hash(f'{member.id}{offering.id}')

    if current > run_at:
        raise DeadlineExceededError('Deadline exceeded')

    with dyn.transact_writer(table_name=ENROLLMENT_TABLE) as transact:
        scheduled_pk = f'SCHEDULED#ORG#{org.id}'
        scheduled_sk = f'{run_at.isoformat()}#{lock_hash}'

        scheduled_item = {
            'id': scheduled_pk,
            'sk': scheduled_sk,
            'user': member.model_dump(),
            'course': offering.model_dump(),
            'org_name': org.name,
            'enrollment_id': enrollment.id,
            'created_by': {
                'id': creator['user_id'],
                'name': creator['name'],
            },
            'seat': {'order_id': order_id},
            'ttl': ttl(start_dt=run_at),
            'scheduled_at': current,
        }
        transact.put(item=scheduled_item)

        order_item = {
            'id': order_id,
            'sk': f'ENROLLMENT#{enrollment.id}',
            'user': member.model_dump(),
            'course': offering.model_dump(),
            'status': 'SCHEDULED',
            'scheduled_at': current,
            'created_at': current,
        }
        transact.put(item=order_item, table_name=ORDER_TABLE)

        lock_item = {
            'id': 'LOCK#SCHEDULED',
            'sk': lock_hash,
            'scheduled': {
                'id': scheduled_pk,
                'sk': scheduled_sk,
            },
            'ttl': ttl(start_dt=run_at),
            'created_at': current,
        }
        transact.put(
            item=lock_item,
            cond_expr='attribute_not_exists(sk)',
            exc_cls=DeduplicationConflictError,
        )
def _date_to_midnight(dt: date) -> datetime:
    """Return 00:00 of ``dt`` as a timezone-aware datetime in ``TZ``.

    pytz zones must be attached with ``localize``. ``replace(tzinfo=...)``
    uses the zone's first historical offset, usually LMT (about -03:06 for
    America/Sao_Paulo), instead of the offset in effect on ``dt``. That skews
    the deadline check and the schedule keys/TTLs built from this value.
    ``normalize`` then corrects the wall time should midnight fall in a DST gap.
    """
    tz = pytz.timezone(TZ)
    return tz.normalize(tz.localize(datetime.combine(dt, time(0, 0))))