"""Middleware that resolves the tenant of a request into the resolver context."""

from http import HTTPStatus

from aws_lambda_powertools.event_handler.api_gateway import (
    APIGatewayHttpResolver,
    Response,
)
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, ServiceError
from aws_lambda_powertools.event_handler.middlewares import (
    BaseMiddlewareHandler,
    NextMiddleware,
)
from layercake.dynamodb import ComposeKey, DynamoDBCollection, KeyPair
from pydantic import UUID4, BaseModel

from auth import AuthFlowType

from .authorizer_middleware import User


class Tenant(BaseModel):
    """Tenant stored under the ``tenant`` key of the resolver context."""

    # Either a UUID4 or a plain string; `_tenant` uses the literal '*' as an id.
    id: UUID4 | str
    name: str


class TenantMiddleware(BaseMiddlewareHandler):
    """Attach a :class:`Tenant` to the resolver context based on the auth flow.

    The ``auth_flow_type`` entry of the Lambda authorizer context selects how
    the tenant is obtained:

    * ``AuthFlowType.API_AUTH`` -- built from the ``tenant`` entry of the
      authorizer context.
    * ``AuthFlowType.USER_AUTH`` -- resolved by :func:`_tenant` from the
      ``x-tenant`` request header and the ``user`` entry of the resolver
      context.

    For any other value no tenant is appended and the request is passed on
    unchanged.
    """

    def __init__(self, collect: DynamoDBCollection) -> None:
        """Store the collection used for the tenant lookups.

        Args:
            collect: Collection handed to :func:`_tenant` for every
                ``USER_AUTH`` request.
        """
        self.collect = collect

    def handler(
        self,
        app: APIGatewayHttpResolver,
        next_middleware: NextMiddleware,
    ) -> Response:
        """Append the tenant to ``app``'s context, then call the next middleware.

        Args:
            app: Resolver handling the current event.
            next_middleware: Next callable in the middleware chain.

        Returns:
            Whatever ``next_middleware(app)`` returns.

        Raises:
            BadRequestError: Propagated from :func:`_tenant` when the
                ``x-tenant`` header is missing or empty (``USER_AUTH`` only).
        """
        context = app.current_event.request_context.authorizer.get_lambda
        auth_flow_type = context.get('auth_flow_type')

        if auth_flow_type == AuthFlowType.API_AUTH:
            app.append_context(tenant=Tenant(**context['tenant']))

        if auth_flow_type == AuthFlowType.USER_AUTH:
            app.append_context(
                tenant=_tenant(
                    app.current_event.headers.get('x-tenant'),
                    # NOTE(review): `.get` yields None when no `user` was put
                    # in the context, and `_tenant` then fails on `user.id`.
                    # Presumably an earlier middleware always sets it --
                    # confirm against the middleware ordering.
                    app.context.get('user'),  # type: ignore
                    collect=self.collect,
                )
            )

        return next_middleware(app)


class ForbiddenError(ServiceError):
    """Service error carrying ``HTTPStatus.FORBIDDEN``."""

    def __init__(self, msg: str):
        """Create the error with the given message.

        Args:
            msg: Message forwarded to ``ServiceError``.
        """
        super().__init__(HTTPStatus.FORBIDDEN, msg)


def _tenant(
    tenant_id: str | None,
    user: User,
    /,
    collect: DynamoDBCollection,
) -> Tenant:
    """Resolve the tenant requested by ``user``.

    Args:
        tenant_id: Requested tenant id; the caller passes the value of the
            ``x-tenant`` header.
        user: User whose ``id`` keys the ACL lookup.
        collect: Collection used for the ACL and tenant lookups.

    Returns:
        ``Tenant(id='*', name='default')`` when ``tenant_id`` is ``'*'``,
        otherwise the tenant parsed from the item stored under
        ``KeyPair(tenant_id, '0')``.

    Raises:
        BadRequestError: If ``tenant_id`` is ``None`` or empty.
    """
    if not tenant_id:
        raise BadRequestError('Missing tenant')

    # ACL check: only the lookup itself matters, the returned item is unused.
    # NOTE(review): presumably `get_item` raises `exception_cls`
    # (ForbiddenError) when the item does not exist -- confirm against
    # layercake's DynamoDBCollection.get_item.
    collect.get_item(
        KeyPair(user.id, ComposeKey(tenant_id, prefix='acls')),
        exception_cls=ForbiddenError,
    )

    # The wildcard id is answered directly, without a second lookup.
    if tenant_id == '*':
        return Tenant(id=tenant_id, name='default')

    obj = collect.get_item(KeyPair(tenant_id, '0'))
    return Tenant.parse_obj(obj)