import os

import authlib.oauth2 as oauth2
import authlib.oauth2.rfc6749.requests as requests
from authlib.common.security import generate_token
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
    APIGatewayProxyEventV2,
)
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair

from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN

from .client import OAuth2Client
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request

DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
# Whitespace-separated list of scopes the server accepts (optional).
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')

# Access-token lifetime in seconds, keyed by grant type. Grant types not
# listed here fall back to BearerTokenGenerator.DEFAULT_EXPIRES_IN.
GRANT_TYPES_EXPIRES_IN = {
    'refresh_token': 600,
    'authorization_code': 600,
}

logger = Logger(__name__)


class AuthorizationServer(oauth2.AuthorizationServer):
    """authlib authorization server backed by DynamoDB.

    Requests arrive as API Gateway HTTP API (v2) events. Clients and issued
    tokens are read from and written to the given DynamoDB persistence layer.
    """

    def __init__(
        self,
        *,
        persistence_layer: DynamoDBPersistenceLayer,
    ) -> None:
        """Configure supported scopes and register the bearer token generator.

        Args:
            persistence_layer: DynamoDB layer used to look up clients and to
                persist issued tokens.
        """
        self._persistence_layer = persistence_layer

        # An empty set is used when no scopes are configured. It is falsy,
        # like the previous empty list, and keeps the attribute a set in
        # both branches.
        super().__init__(
            scopes_supported=(
                set(OAUTH2_SCOPES_SUPPORTED.split())
                if OAUTH2_SCOPES_SUPPORTED
                else set()
            )
        )

        def expires_in(client, grant_type: str) -> int:
            # Used as BearerTokenGenerator's expires_generator. `client` is
            # part of that callback signature but is not used here.
            return GRANT_TYPES_EXPIRES_IN.get(
                grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
            )

        self.register_token_generator(
            'default',
            BearerTokenGenerator(
                access_token_generator=create_token_generator(42),
                refresh_token_generator=create_token_generator(48),
                expires_generator=expires_in,
            ),
        )

    def save_token(
        self,
        token: dict,
        request: requests.OAuth2Request,
    ) -> None:
        """Persist an issued access token and, if present, its refresh token.

        Both items are written in a single DynamoDB transaction under the
        ``OAUTH2#TOKEN`` partition. Each item carries a ``ttl`` derived from
        its own lifetime.

        Args:
            token: Token dict produced by the registered token generator. It
                must contain ``access_token``, ``token_type`` and
                ``expires_in``; ``refresh_token`` and ``scope`` are optional.
            request: The OAuth2 request that led to issuance. It must carry a
                payload and a user.

        Raises:
            ValueError: If the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        now_ = now()

        # With client_secret_basic, the client id usually arrives in the
        # Authorization header rather than the form payload. In that case,
        # fall back to the client authenticated by authlib.
        client_id = request.payload.client_id
        authenticated_client = getattr(request, 'client', None)
        if not client_id and authenticated_client:
            client_id = authenticated_client.get_client_id()

        access_token = token['access_token']
        refresh_token = token.get('refresh_token')
        token_type = token['token_type']
        # BearerTokenGenerator omits 'scope' when the granted scope is
        # empty, so it must not be indexed directly.
        scope = token.get('scope')
        expires_in = int(token['expires_in'])
        issued_at = int(now_.timestamp())

        with self._persistence_layer.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#TOKEN',
                    'sk': f'ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'token_type': token_type,
                    'refresh_token': refresh_token,
                    'scope': scope,
                    'user': request.user,
                    'expires_in': expires_in,
                    'issued_at': issued_at,
                    'ttl': ttl(start_dt=now_, seconds=expires_in),
                },
            )

            if refresh_token:
                # The refresh token links back to its access token and uses
                # its own, longer configured lifetime.
                transact.put(
                    item={
                        'id': 'OAUTH2#TOKEN',
                        'sk': f'REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        'token_type': token_type,
                        'access_token': access_token,
                        'scope': scope,
                        'user': request.user,
                        'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
                        'issued_at': issued_at,
                        'ttl': ttl(
                            start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN
                        ),
                    },
                )

        return None

    def query_client(self, client_id: str):
        """Load the registered client ``client_id`` from DynamoDB.

        ``ClientNotFoundError`` is passed to ``get_item`` as ``exc_cls``.
        Presumably it is raised when no client item exists (verify against
        layercake).

        Returns:
            An ``OAuth2Client`` built from the stored item. The
            ``token_endpoint_auth_method`` defaults to
            ``client_secret_basic`` when not stored.
        """
        client = self._persistence_layer.collection.get_item(
            KeyPair(pk='OAUTH2', sk=f'CLIENT_ID#{client_id}'),
            exc_cls=ClientNotFoundError,
        )

        return OAuth2Client(
            client_id=client_id,
            client_secret=client['client_secret'],
            scope=client['scope'],
            redirect_uris=client['redirect_uris'],
            response_types=client['response_types'],
            grant_types=client['grant_types'],
            token_endpoint_auth_method=client.get(
                'token_endpoint_auth_method', 'client_secret_basic'
            ),
        )

    def create_oauth2_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayOAuth2Request:
        """Wrap an API Gateway v2 event as an authlib OAuth2 request."""
        return APIGatewayOAuth2Request(request)

    def create_json_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayJsonRequest:
        """Wrap an API Gateway v2 event as an authlib JSON request."""
        return APIGatewayJsonRequest(request)

    def handle_response(
        self,
        status: int,
        body,
        headers,
    ):
        """Convert authlib's (status, body, headers) into a Powertools Response."""
        return Response(
            status_code=status,
            body=body,
            headers=headers,
        )

    def send_signal(
        self,
        name: str,
        *args: object,
        **kwargs: object,
    ) -> None:
        """Ignore authlib signals; this hook is intentionally a no-op.

        Signals that could be handled here include:
          - after_authenticate_client: when a client is authenticated
          - after_revoke_token: when a token is revoked
        """
        ...
class ClientNotFoundError(NotFoundError):
    """Not-found error raised when an OAuth2 client lookup has no match."""

    def __init__(self, *_ignored):
        # Positional arguments supplied by the raiser are discarded.
        # The message is always the same fixed text.
        super().__init__('Client not found')


def create_token_generator(length: int = 42):
    """Return a callable that produces a token of ``length`` characters.

    Each call delegates to authlib's ``generate_token(length)``. The callable
    accepts and ignores any positional or keyword arguments, so it can be
    plugged in wherever a token-generator callback is expected.
    """

    def _make_token(*_unused_args, **_unused_kwargs):
        return generate_token(length)

    return _make_token