import os
import secrets
from collections import defaultdict
from urllib.parse import parse_qs

from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin, ClientMixin, TokenMixin
from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest
from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload
from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
    APIGatewayProxyEventV2,
)

# Raw value of the OAUTH2_SCOPES_SUPPORTED environment variable, or None when
# it is unset. This is a single string, not a list of scope names.
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')


def _request_uri(request: APIGatewayProxyEventV2) -> str:
    """Return the ``https://<domain>`` URI used for authlib request objects."""
    # NOTE(review): only scheme and host are included -- the request path and
    # query string are omitted. Confirm nothing downstream expects
    # ``request.uri`` to be the full request URL.
    return f'https://{request.request_context.domain_name}'


class OAuth2Payload(_OAuth2Payload):
    """authlib payload backed by an API Gateway HTTP API (v2) event.

    Parameters come from the query string and from the request body. The body
    is parsed as ``application/x-www-form-urlencoded``.
    """

    def __init__(self, request: APIGatewayProxyEventV2):
        self._request = request

    @property
    def decoded_body(self):
        """Return the form-encoded request body as a dict.

        A key that appears once maps to its single value. A repeated key maps
        to the list of all its values. Blank values are kept. A missing body
        yields an empty dict.
        """
        # TODO: the body is always parsed as form-urlencoded; the Content-Type
        # header is not checked.
        body = parse_qs(self._request.decoded_body or '', keep_blank_values=True)
        return {k: v[0] if len(v) == 1 else v for k, v in body.items()}

    @property
    def data(self):
        """Combine query string parameters and body parameters.

        When a key appears in both, the body value wins.
        """
        # Some powertools versions return None instead of {} when the request
        # has no query string; ``None | dict`` would raise TypeError.
        query = self._request.query_string_parameters or {}
        return query | self.decoded_body

    @property
    def datalist(self) -> dict[str, list]:
        """Map every parameter name to a flat list of all its values.

        A key repeated in the body contributes one entry per value, so a
        length greater than 1 means "supplied more than once". Previously the
        whole value list was nested as a single element.
        """
        # NOTE(review): with payload format 2.0, API Gateway comma-joins
        # repeated query-string parameters, so duplicates in the query string
        # are not visible as separate values here.
        values = defaultdict(list)
        for key, value in self.data.items():
            if isinstance(value, list):
                values[key].extend(value)
            else:
                values[key].append(value)
        return values


class JsonRequest(_JsonRequest):
    """authlib JsonRequest built from an API Gateway HTTP API (v2) event.

    NOTE(review): no payload object is attached here. Endpoints that read the
    JSON body through ``request.payload`` would need one.
    """

    def __init__(self, request: APIGatewayProxyEventV2):
        super().__init__(
            request.request_context.http.method,
            _request_uri(request),
            request.headers,
        )


class OAuth2Request(_OAuth2Request):
    """authlib OAuth2Request built from an API Gateway HTTP API (v2) event."""

    def __init__(self, request: APIGatewayProxyEventV2):
        # NOTE(review): API Gateway v2 lowercases header names. Confirm that
        # ``request.headers`` supports case-insensitive lookups (e.g. of
        # "Authorization") in the powertools version in use.
        super().__init__(
            request.request_context.http.method,
            _request_uri(request),
            request.headers,
        )
        self._request = request
        self.payload = OAuth2Payload(request)

    @property
    def args(self):
        """Query string parameters (an empty dict when there are none)."""
        return self._request.query_string_parameters or {}

    @property
    def form(self) -> dict[str, str]:
        """The parsed form-encoded request body."""
        return self.payload.decoded_body


class OAuth2Client(ClientMixin):
    """In-memory OAuth 2.0 client registration implementing authlib's ClientMixin."""

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_uris: list,
        response_types: list,
        grant_types: list,
        token_endpoint_auth_method: str = 'client_secret_basic',
    ) -> None:
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uris = redirect_uris
        self.response_types = response_types
        self.grant_types = grant_types
        self.token_endpoint_auth_method = token_endpoint_auth_method

    def get_client_id(self):
        """Return the client identifier."""
        return self.client_id

    def get_default_redirect_uri(self) -> str:  # type: ignore
        """Return the first registered redirect URI, or None if none are registered."""
        if self.redirect_uris:
            return self.redirect_uris[0]
        return None

    def check_response_type(self, response_type):
        """Return True if *response_type* is registered for this client."""
        return response_type in self.response_types

    def check_redirect_uri(self, redirect_uri):
        """Return True if *redirect_uri* exactly matches a registered URI."""
        return redirect_uri in self.redirect_uris

    def check_endpoint_auth_method(self, method, endpoint):
        """Check the client authentication *method* for *endpoint*.

        Only the 'token' endpoint is restricted, to the registered method.
        Every other endpoint accepts any method.
        """
        if endpoint == 'token':
            return self.token_endpoint_auth_method == method
        return True

    def check_grant_type(self, grant_type):
        """Return True if *grant_type* is registered for this client."""
        return grant_type in self.grant_types

    def check_client_secret(self, client_secret):
        """Compare *client_secret* with the stored secret in constant time.

        Both values are UTF-8 encoded before comparison. ``compare_digest``
        raises TypeError for ``str`` arguments containing non-ASCII
        characters, and the presented secret is untrusted request input.
        A missing secret on either side never matches.
        """
        if self.client_secret is None or client_secret is None:
            return False
        return secrets.compare_digest(
            self.client_secret.encode('utf-8'),
            client_secret.encode('utf-8'),
        )


class OAuth2Token(TokenMixin):
    """Placeholder token model; none of TokenMixin's methods are overridden here."""
class AuthorizationCode(AuthorizationCodeMixin):
    """Authorization code record implementing authlib's AuthorizationCodeMixin.

    Stores the code together with the client, redirect URI, scope and
    optional PKCE (code_challenge / code_challenge_method) and nonce values.
    """

    def __init__(
        self,
        user_id: str,
        code: str,
        client_id: str,
        redirect_uri: str,
        response_type: str,
        scope: str,
        code_challenge: str | None = None,
        code_challenge_method: str | None = None,
        nonce: str | None = None,
    ) -> None:
        self.user_id = user_id
        self.code = code
        self.client_id = client_id
        self.redirect_uri = redirect_uri
        self.response_type = response_type
        self.scope = scope
        self.code_challenge = code_challenge
        self.code_challenge_method = code_challenge_method
        self.nonce = nonce

    def get_redirect_uri(self):
        """Return the redirect URI the code was issued for."""
        return self.redirect_uri

    def get_scope(self):
        """Return the scope string granted with this code."""
        return self.scope


def _to_header_map(headers) -> dict:
    """Normalize headers into a dict suitable for a powertools ``Response``.

    Accepts a mapping or an iterable of ``(name, value)`` pairs, which is the
    shape authlib produces. A name that occurs more than once maps to a list
    of its values. ``None`` or empty input yields ``{}``.
    """
    if not headers:
        return {}
    pairs = headers.items() if hasattr(headers, 'items') else headers
    merged: dict = {}
    for name, value in pairs:
        if name in merged:
            previous = merged[name]
            # Build new lists rather than appending, so caller-owned lists are never mutated.
            merged[name] = [*previous, value] if isinstance(previous, list) else [previous, value]
        else:
            merged[name] = value
    return merged


class AuthorizationServer(_AuthorizationServer):
    """authlib AuthorizationServer adapted to API Gateway v2 events and powertools.

    Client lookup and token persistence are delegated to the ``query_client``
    and ``save_token`` callables supplied at construction.
    """

    def __init__(self, query_client, save_token, scopes_supported=None) -> None:
        """Create the server.

        :param query_client: callable ``(client_id) -> client``.
        :param save_token: callable ``(token, request)`` that persists a token.
        :param scopes_supported: supported scope names, either as a collection
            or as a whitespace-delimited string. Defaults to the
            OAUTH2_SCOPES_SUPPORTED environment value.
        """
        if scopes_supported is None:
            scopes_supported = OAUTH2_SCOPES_SUPPORTED
        if isinstance(scopes_supported, str):
            # authlib treats scopes_supported as a collection of scope names;
            # a raw string would be iterated character by character.
            # Assumes the env value is space-delimited like an OAuth scope
            # string -- confirm against deployment config.
            scopes_supported = scopes_supported.split()
        super().__init__(
            scopes_supported=scopes_supported,
        )
        self._query_client = query_client
        self._save_token = save_token

    def save_token(self, token, request):
        """Persist *token* by delegating to the injected ``save_token`` callable."""
        return self._save_token(token, request)

    def query_client(self, client_id: str):
        """Look up a client by delegating to the injected ``query_client`` callable."""
        return self._query_client(client_id)

    def create_oauth2_request(self, request: APIGatewayProxyEventV2) -> OAuth2Request:
        """Wrap the API Gateway event in an OAuth2Request."""
        return OAuth2Request(request)

    def create_json_request(self, request: APIGatewayProxyEventV2) -> JsonRequest:
        """Wrap the API Gateway event in a JsonRequest."""
        return JsonRequest(request)

    def handle_response(self, status, body, headers):
        """Convert authlib's ``(status, body, headers)`` into a powertools Response.

        authlib supplies headers as ``(name, value)`` pairs and JSON bodies as
        dicts. The headers are normalized to a mapping, and dict/list bodies
        are JSON-encoded. String bodies (e.g. ``""`` for redirects) pass
        through unchanged.
        """
        import json

        if isinstance(body, (dict, list)):
            body = json.dumps(body)
        return Response(
            status_code=status,
            body=body,
            headers=_to_header_map(headers),
        )

    def send_signal(self, name, *args, **kwargs):
        """Ignore authlib signals (deliberate no-op)."""
        pass