from http import HTTPStatus

from aws_lambda_powertools.event_handler.api_gateway import (
    APIGatewayHttpResolver,
    Response,
)
from aws_lambda_powertools.event_handler.exceptions import (
    BadRequestError,
    NotFoundError,
    ServiceError,
)
from aws_lambda_powertools.event_handler.middlewares import (
    BaseMiddlewareHandler,
    NextMiddleware,
)
from layercake.dynamodb import (
    ComposeKey,
    DynamoDBCollection,
    KeyPair,
    SortKey,
    TransactKey,
)
from pydantic import UUID4, BaseModel

from auth import AuthFlowType

from .authentication_middleware import User


class Tenant(BaseModel):
    """Tenant attached to the request context by `TenantMiddleware`."""

    id: UUID4 | str
    name: str


class TenantMiddleware(BaseMiddlewareHandler):
    """Middleware that associates a Tenant instance with the request context
    based on the authentication flow.

    For API authentication (`AuthFlowType.API_AUTH`), it assigns tenant
    information directly from the authorizer context.
    For user authentication (`AuthFlowType.USER_AUTH`), it gets the Tenant ID
    from the specified request header and validates the user's access to it.
    For any other authentication flow, no tenant is added to the context.

    Parameters
    ----------
    collection : DynamoDBCollection
        The DynamoDB collection used to validate user access and retrieve
        tenant information. Positional-only.
    header : str, optional
        The request header name containing the tenant ID.
        Defaults to `'X-Tenant'`.
    """

    def __init__(
        self,
        collection: DynamoDBCollection,
        /,
        header: str = 'X-Tenant',
    ) -> None:
        self.collection = collection
        self.header = header

    def handler(
        self,
        app: APIGatewayHttpResolver,
        next_middleware: NextMiddleware,
    ) -> Response:
        """Append `tenant` to `app.context`, then invoke the next middleware.

        Raises
        ------
        BadRequestError, ForbiddenError, NotFoundError
            Propagated from `_tenant` for the user authentication flow.
        """
        # Context values supplied by the Lambda authorizer.
        context = app.current_event.request_context.authorizer.get_lambda
        auth_flow_type = context.get('auth_flow_type')

        # The flows are mutually exclusive; at most one branch applies.
        if auth_flow_type == AuthFlowType.API_AUTH:
            app.append_context(tenant=Tenant(**context['tenant']))
        elif auth_flow_type == AuthFlowType.USER_AUTH:
            app.append_context(
                tenant=_tenant(
                    # NOTE(review): HTTP APIs lowercase header names; this
                    # relies on `headers` doing case-insensitive lookup —
                    # confirm for the installed Powertools version.
                    app.current_event.headers.get(self.header),
                    # Presumably placed in the context by the
                    # authentication middleware — verify ordering.
                    app.context.get('user'),  # type: ignore
                    collection=self.collection,
                )
            )

        return next_middleware(app)


class ForbiddenError(ServiceError):
    """HTTP 403 error whose response message is always 'Deny'.

    Any positional or keyword arguments are accepted and ignored, so the
    class can be raised bare (`raise ForbiddenError`) or with a message.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(HTTPStatus.FORBIDDEN, 'Deny')


def _tenant(
    tenant_id: str | None,
    user: User,
    /,
    collection: DynamoDBCollection,
) -> Tenant:
    """Get a Tenant instance based on the provided tenant_id and user's
    access permissions.

    A user may access a tenant when they hold an ACL entry for that tenant
    or for the root tenant (`'*'`).

    Parameters
    ----------
    tenant_id : str | None
        The identifier of the tenant. Must not be None or empty.
        `'*'` selects the root tenant.
    user : User
        The user attempting to access the tenant.
    collection : DynamoDBCollection
        The DynamoDB collection used to retrieve tenant information.

    Returns
    -------
    Tenant
        The Tenant instance corresponding to the provided tenant_id, or a
        synthetic `Tenant(id='*', name='default')` for the root tenant.

    Raises
    ------
    BadRequestError
        If tenant_id is not provided.
    ForbiddenError
        If the user lacks the necessary ACL permissions for the specified
        tenant_id.
    NotFoundError
        If tenant not found.
    """
    if not tenant_id:
        raise BadRequestError('Missing tenant')

    # Ensure the user has an ACL for the root tenant or the requested one.
    # When the requested tenant *is* the root, both keys would be the same
    # item; request it only once, since a transactional read may reject
    # multiple operations on one item.
    acl_key = TransactKey(user.id) + SortKey(ComposeKey('*', prefix='acls'))
    if tenant_id != '*':
        acl_key = acl_key + SortKey(ComposeKey(tenant_id, prefix='acls'))

    acls = collection.get_items(acl_key, flatten_top=False)
    if not acls:
        raise ForbiddenError

    # The root tenant has no stored record; provide the default Tenant.
    if tenant_id == '*':
        return Tenant(id=tenant_id, name='default')

    obj = collection.get_item(KeyPair(tenant_id, '0'), exc_cls=NotFoundError)
    return Tenant.model_validate(obj)