"""Authlib OAuth 2.0 authorization-server adapter for AWS Lambda HTTP APIs.

Bridges Powertools' ``APIGatewayProxyEventV2`` events to Authlib's
framework-agnostic request objects and ``handle_response`` hook.
"""
import json
import os
from collections import defaultdict
from typing import Optional

from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin
from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest
from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload
from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
    APIGatewayProxyEventV2,
)

# Raw environment value (or None when unset). Kept as a string for callers
# that import it; AuthorizationServer parses it into scope tokens.
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')


def _parse_scopes(value: Optional[str]) -> Optional[list]:
    """Split a scope string into tokens; return None when nothing is configured.

    Tokens are separated by whitespace (the OAuth 2.0 scope syntax); commas
    are also accepted as separators. An unset or blank value yields None,
    which leaves scope validation disabled.
    """
    if not value:
        return None
    scopes = value.replace(',', ' ').split()
    return scopes or None


def _request_uri(request: APIGatewayProxyEventV2) -> str:
    """Rebuild the absolute URL of *request*: scheme, host, path and query.

    The scheme is fixed to ``https``. ``rawPath`` and ``rawQueryString`` are
    read with ``get`` so that events lacking them still produce a host-only
    URI.
    """
    uri = f'https://{request.request_context.domain_name}'
    uri += request.get('rawPath') or ''
    query = request.get('rawQueryString')
    if query:
        uri += f'?{query}'
    return uri


class OAuth2Payload(_OAuth2Payload):
    """OAuth 2.0 request parameters taken from the event's query string."""

    def __init__(self, request: APIGatewayProxyEventV2):
        self._request = request

    @property
    def data(self) -> dict:
        """Return the query-string parameters as a dict.

        The result is never None. When the event has no query string the
        parameters are None, so an empty dict is substituted; otherwise
        lookups like ``payload.client_id`` would raise AttributeError.
        """
        return self._request.query_string_parameters or {}

    @property
    def datalist(self) -> dict[str, list]:
        """Return each parameter name mapped to a one-element list of its value."""
        # NOTE(review): payload format 2.0 reportedly comma-joins repeated
        # query keys into a single value — confirm if multi-value params matter.
        values = defaultdict(list)
        for key, value in self.data.items():
            values[key].append(value)
        return values


class JsonRequest(_JsonRequest):
    """Authlib JSON request built from an API Gateway HTTP API event."""

    def __init__(self, request: APIGatewayProxyEventV2):
        super().__init__(
            request.request_context.http.method,
            _request_uri(request),
            request.headers,
        )


class OAuth2Request(_OAuth2Request):
    """Authlib OAuth 2.0 request whose payload comes from the query string."""

    def __init__(self, request: APIGatewayProxyEventV2):
        super().__init__(
            request.request_context.http.method,
            _request_uri(request),
            request.headers,
        )
        self.payload = OAuth2Payload(request)


class OAuth2Client(ClientMixin):
    """In-memory client with fixed redirect URIs and allowed response types."""

    def __init__(
        self,
        client_id: str,
        redirect_uris: list,
        response_types: list,
    ) -> None:
        self.client_id = client_id
        self.redirect_uris = redirect_uris
        self.response_types = response_types

    def get_client_id(self):
        """Return the client identifier."""
        return self.client_id

    def get_default_redirect_uri(self) -> Optional[str]:
        """Return the first registered redirect URI, or None if none are registered."""
        if self.redirect_uris:
            return self.redirect_uris[0]
        return None

    def check_response_type(self, response_type):
        """Return True if *response_type* is one this client may use."""
        return response_type in self.response_types

    def check_redirect_uri(self, redirect_uri):
        """Return True if *redirect_uri* exactly matches a registered URI."""
        return redirect_uri in self.redirect_uris


class OAuth2Token(TokenMixin):
    """Placeholder token model; implements none of TokenMixin's methods yet."""


class AuthorizationServer(_AuthorizationServer):
    """Authlib authorization server wired to injected storage callbacks.

    :param query_client: callable taking a client_id and returning a client.
    :param save_token: callable taking ``(token, request)`` that persists it.
    """

    def __init__(self, query_client, save_token) -> None:
        # Authlib expects a collection of scope strings here. Passing the raw
        # env string would make set(scopes_supported) a set of characters,
        # so every multi-character scope would be rejected as invalid.
        super().__init__(
            scopes_supported=_parse_scopes(OAUTH2_SCOPES_SUPPORTED),
        )
        self._query_client = query_client
        self._save_token = save_token

    def save_token(self, token, request):
        """Delegate token persistence to the injected callback."""
        return self._save_token(token, request)

    def query_client(self, client_id: str):
        """Delegate client lookup to the injected callback."""
        return self._query_client(client_id)

    def create_oauth2_request(self, request: APIGatewayProxyEventV2) -> OAuth2Request:
        """Wrap a Lambda HTTP API event as an Authlib OAuth2Request."""
        return OAuth2Request(request)

    def create_json_request(self, request: APIGatewayProxyEventV2) -> JsonRequest:
        """Wrap a Lambda HTTP API event as an Authlib JsonRequest."""
        return JsonRequest(request)

    def handle_response(self, status, body, headers):
        """Convert Authlib's ``(status, body, headers)`` into a Powertools Response.

        Authlib passes JSON bodies as dicts and headers as a sequence of
        ``(name, value)`` pairs. Powertools needs a serialisable body and a
        headers dict, so both are normalised here.
        """
        if isinstance(body, (dict, list)):
            body = json.dumps(body)
        return Response(
            status_code=status,
            body=body,
            headers=dict(headers or {}),
        )