import os
from dataclasses import asdict

import authlib.oauth2 as oauth2
import authlib.oauth2.rfc6749.requests as requests
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
    APIGatewayProxyEventV2,
)
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair

from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN

from .client import OAuth2Client
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request

DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
# Whitespace-separated list of scopes this server accepts; unset means
# "no restriction" (passed to authlib as ``scopes_supported=None``).
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')

logger = Logger(__name__)


class AuthorizationServer(oauth2.AuthorizationServer):
    """authlib authorization server backed by a DynamoDB persistence layer.

    Adapts authlib's framework-agnostic ``AuthorizationServer`` to API Gateway
    (HTTP API, payload v2) events and AWS Lambda Powertools ``Response``
    objects, persisting clients and tokens through ``persistence_layer``.
    """

    def __init__(
        self,
        *,
        persistence_layer: DynamoDBPersistenceLayer,
    ) -> None:
        """Create the server.

        Args:
            persistence_layer: DynamoDB layer used to read OAuth2 clients and
                to write issued tokens/sessions.
        """
        self._persistence_layer = persistence_layer
        scopes_supported = (
            set(OAUTH2_SCOPES_SUPPORTED.split())
            if OAUTH2_SCOPES_SUPPORTED
            else None
        )
        # The base initialiser must ALWAYS run: it creates authlib's internal
        # grant/endpoint registries. Skipping it when no scopes are configured
        # left the instance half-initialised.
        super().__init__(scopes_supported=scopes_supported)

    def save_token(
        self,
        token: dict,
        request: requests.OAuth2Request,
    ) -> None:
        """Persist a freshly generated token in a single DynamoDB transaction.

        For the access token (and, when present, the refresh token) two items
        are written:

        * a per-user session item (``id=<user id>``,
          ``sk=SESSION#<KIND>#<token>``), and
        * a token item under ``id='OAUTH2#TOKEN'`` (``sk=<KIND>#<token>``)
          holding the token metadata and a snapshot of the user.

        Args:
            token: Token dict produced by authlib's token generator. Must
                contain ``access_token``, ``token_type`` and ``expires_in``;
                ``refresh_token`` and ``scope`` are optional.
            request: The OAuth2 request that produced the token.

        Raises:
            ValueError: If the request has no payload or no authenticated user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        now_ = now()
        # Prefer the authenticated client; fall back to the client_id the
        # request carried (e.g. public clients).
        client_id = (
            request.client.get_client_id()
            if request.client
            else request.payload.client_id
        )
        user_id = request.user.id
        user = asdict(request.user)
        access_token = token['access_token']
        refresh_token = token.get('refresh_token')
        token_type = token['token_type']
        # authlib only includes ``scope`` in the token when one was granted.
        scope = token.get('scope')
        expires_in = int(token['expires_in'])
        issued_at = int(now_.timestamp())
        access_token_ttl = ttl(start_dt=now_, seconds=expires_in)
        refresh_token_ttl = ttl(
            start_dt=now_,
            seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
        )

        with self._persistence_layer.transact_writer() as transact:
            transact.put(
                item={
                    'id': user_id,
                    'sk': f'SESSION#ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'ttl': access_token_ttl,
                    'created_at': now_,
                }
            )
            transact.put(
                item={
                    'id': 'OAUTH2#TOKEN',
                    'sk': f'ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'token_type': token_type,
                    'refresh_token': refresh_token,
                    'scope': scope,
                    'user': user,
                    'expires_in': expires_in,
                    'issued_at': issued_at,
                    'ttl': access_token_ttl,
                },
            )

            if refresh_token:
                transact.put(
                    item={
                        'id': user_id,
                        'sk': f'SESSION#REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        # The session index must live as long as the refresh
                        # token it points to (was wrongly access_token_ttl).
                        'ttl': refresh_token_ttl,
                        'created_at': now_,
                    }
                )
                transact.put(
                    item={
                        'id': 'OAUTH2#TOKEN',
                        'sk': f'REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        'token_type': token_type,
                        'access_token': access_token,
                        'scope': scope,
                        'user': user,
                        'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
                        'issued_at': issued_at,
                        'ttl': refresh_token_ttl,
                    },
                )

        return None

    def query_client(
        self,
        client_id: str,
    ):
        """Load the OAuth2 client registered under ``client_id``.

        Reads the item ``pk='OAUTH2'``, ``sk='CLIENT_ID#<client_id>'``.
        ``ClientNotFoundError`` is passed as ``exc_cls`` — presumably raised by
        ``get_item`` when the item is missing (verify in layercake).

        Returns:
            An ``OAuth2Client``; ``token_endpoint_auth_method`` defaults to
            ``'none'`` when the stored item lacks it.
        """
        client = self._persistence_layer.collection.get_item(
            KeyPair(
                pk='OAUTH2',
                sk=f'CLIENT_ID#{client_id}',
            ),
            exc_cls=ClientNotFoundError,
        )

        return OAuth2Client(
            client_id=client_id,
            client_secret=client['client_secret'],
            scope=client['scope'],
            redirect_uris=client['redirect_uris'],
            response_types=client['response_types'],
            grant_types=client['grant_types'],
            token_endpoint_auth_method=client.get(
                'token_endpoint_auth_method', 'none'
            ),
        )

    def create_oauth2_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayOAuth2Request:
        """Wrap an API Gateway event as an authlib OAuth2 request."""
        return APIGatewayOAuth2Request(request)

    def create_json_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayJsonRequest:
        """Wrap an API Gateway event as an authlib JSON request."""
        return APIGatewayJsonRequest(request)

    def handle_response(
        self,
        status: int,
        body,
        headers,
    ):
        """Convert authlib's (status, body, headers) into a Powertools Response."""
        return Response(
            status_code=status,
            body=body,
            headers=headers,
        )

    def send_signal(
        self,
        name: str,
        *args: object,
        **kwargs: object,
    ) -> None:
        """Hook for authlib signals; currently a no-op.

        Signals that could be handled here:
        ``after_authenticate_client`` (a client was authenticated) and
        ``after_revoke_token`` (a token was revoked).
        """
        # after_authenticate_client
        # when client is authenticated

        # after_revoke_token
        # when token is revoked
        ...
class ClientNotFoundError(NotFoundError):
    """Not-found error raised when an OAuth2 client lookup misses.

    Any positional arguments supplied by the raiser are discarded; the
    message is always fixed.
    """

    def __init__(self, *_ignored_args):
        message = 'Client not found'
        super().__init__(message)