from http import HTTPStatus

from aws_lambda_powertools.event_handler.api_gateway import (
    APIGatewayHttpResolver,
    Response,
)
from aws_lambda_powertools.event_handler.exceptions import (
    BadRequestError,
    NotFoundError,
    ServiceError,
)
from aws_lambda_powertools.event_handler.middlewares import (
    BaseMiddlewareHandler,
    NextMiddleware,
)
from layercake.dynamodb import ComposeKey, DynamoDBCollection, KeyPair
from pydantic import UUID4, BaseModel

from auth import AuthFlowType

from .authentication_middleware import User


class Tenant(BaseModel):
    """Tenant attached to the request context as ``tenant``."""

    id: UUID4 | str
    name: str


class TenantMiddleware(BaseMiddlewareHandler):
    """Middleware that associates a Tenant instance with the request context
    based on the authentication flow.

    For API authentication (`AuthFlowType.API_AUTH`), it assigns tenant
    information directly from the authorizer context. For user authentication
    (`AuthFlowType.USER_AUTH`), it gets the Tenant ID from the specified
    request header (looked up case-insensitively). For any other flow type no
    tenant is attached.

    Parameters
    ----------
    collect : DynamoDBCollection
        The DynamoDB collection used to validate user access and retrieve
        tenant information.
    header : str, optional
        The request header name containing the tenant ID.
        Defaults to `'X-Tenant'`.
    """

    def __init__(
        self,
        collect: DynamoDBCollection,
        /,
        header: str = 'X-Tenant',
    ) -> None:
        self.collect = collect
        self.header = header

    def handler(
        self,
        app: APIGatewayHttpResolver,
        next_middleware: NextMiddleware,
    ) -> Response:
        """Attach ``tenant`` to ``app`` context, then call the next middleware.

        Raises
        ------
        BadRequestError, ForbiddenError, NotFoundError
            Propagated from `_tenant` for the user-auth flow.
        """
        context = app.current_event.request_context.authorizer.get_lambda
        auth_flow_type = context.get('auth_flow_type')

        if auth_flow_type == AuthFlowType.API_AUTH:
            # The authorizer has already resolved the tenant for API clients.
            app.append_context(tenant=Tenant(**context['tenant']))
        elif auth_flow_type == AuthFlowType.USER_AUTH:
            # API Gateway HTTP APIs (payload format 2.0) deliver lowercased
            # header names, so an exact-case lookup of e.g. 'X-Tenant' in the
            # raw headers mapping can miss. `get_header_value` is
            # case-insensitive by default.
            tenant_id = app.current_event.get_header_value(self.header)
            app.append_context(
                tenant=_tenant(
                    tenant_id,
                    app.context.get('user'),  # type: ignore
                    collect=self.collect,
                )
            )

        return next_middleware(app)


class ForbiddenError(ServiceError):
    """403 error used when the user has no ACL entry for the tenant.

    Any constructor arguments are ignored; the response is always
    ``403 'Deny'``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(HTTPStatus.FORBIDDEN, 'Deny')


def _tenant(
    tenant_id: str | None,
    user: User,
    /,
    collect: DynamoDBCollection,
) -> Tenant:
    """Get a Tenant instance based on the provided tenant_id and user's access
    permissions.

    Parameters
    ----------
    tenant_id : str | None
        The identifier of the tenant. Must not be None or empty.
        The value ``'*'`` denotes the root tenant.
    user : User
        The user attempting to access the tenant.
    collect : DynamoDBCollection
        The DynamoDB collection used to retrieve tenant information.

    Returns
    -------
    Tenant
        The Tenant instance corresponding to the provided tenant_id.

    Raises
    ------
    BadRequestError
        If tenant_id is not provided.
    ForbiddenError
        If the user lacks the necessary ACL permissions for the specified
        tenant_id.
    NotFoundError
        If tenant not found.
    """
    if not tenant_id:
        raise BadRequestError('Missing tenant')

    # Ensure user has ACL. The ACL check runs before the root-tenant shortcut,
    # so even '*' requires an explicit ACL item.
    collect.get_item(
        KeyPair(user.id, ComposeKey(tenant_id, prefix='acls')),
        exception_cls=ForbiddenError,
    )

    # For root tenant, provide the default Tenant
    if tenant_id == '*':
        return Tenant(id=tenant_id, name='default')

    obj = collect.get_item(KeyPair(tenant_id, '0'), exception_cls=NotFoundError)
    # NOTE(review): `parse_obj` is the pydantic v1 API (deprecated in v2 in
    # favour of `model_validate`) — confirm the pinned pydantic version.
    return Tenant.parse_obj(obj)