"""OAuth2 / OpenID Connect authorization server wiring.

Defines the Authlib grant and extension subclasses used by this service
(authorization code + OIDC, refresh token, RFC 9207 issuer parameter),
backed by the ``OAUTH2_TABLE`` DynamoDB table, and registers them on the
module-level ``server`` instance.
"""

from authlib.common.urls import add_params_to_uri
from authlib.oauth2 import OAuth2Request, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.funcs import omit, pick

from boto3clients import dynamodb_client
from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE
from integrations.apigateway_oauth2.authorization_server import (
    AuthorizationServer,
)
from integrations.apigateway_oauth2.tokens import (
    OAuth2AuthorizationCode,
    OAuth2Token,
)

logger = Logger(__name__)
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)


class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension for the authorization code flow."""

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Return whether a ``NONCE#<nonce>`` record is currently stored.

        Raises:
            ValueError: if the request carries no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        # The lookup is keyed by nonce only; the client id stored on the
        # record is not compared here.
        nonce_ = oauth2_layer.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}'),
        )

        return bool(nonce_)

    def get_jwt_config(self, grant):
        """Return the signing configuration used for ID tokens."""
        # NOTE(review): the signing key is a hard-coded literal. It should
        # come from configuration / a secret store before this is used in
        # production -- confirm where the real key lives.
        return {
            'key': 'secret-key',
            'alg': JWT_ALGORITHM,
            'iss': ISSUER,
            'exp': 3600,
        }

    def generate_user_info(self, user: dict, scope: str) -> UserInfo:
        """Build the user claims from ``user`` and filter them by ``scope``.

        ``user`` must contain ``id``, ``name`` and ``email``;
        ``email_verified`` defaults to ``False`` when absent.
        """
        return UserInfo(
            sub=user['id'],
            name=user['name'],
            email=user['email'],
            email_verified=user.get('email_verified', False),
        ).filter(scope)


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant persisted in DynamoDB."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Persist ``code`` (and its nonce, if any) with a 10 minute TTL.

        The code record and the optional nonce record are written through
        the same transact writer.

        Raises:
            ValueError: if the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        data: dict = request.payload.data
        user: dict = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=10)

        with oauth2_layer.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': request.payload.scope,
                    'client_id': client_id,
                    'user_id': user['id'],
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            # The nonce record is what ``OpenIDCode.exists_nonce`` looks up.
            if nonce:
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Load the stored authorization code issued to ``client``.

        Returns:
            The stored code, or ``None`` when no record exists for ``code``
            or the record was issued to a different client. Authlib turns a
            falsy result into an ``invalid_grant`` error.
        """
        auth_code = oauth2_layer.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )

        # Unknown (or already removed) code: report "not found" instead of
        # failing on ``**auth_code`` or building an empty code object.
        if not auth_code:
            return None

        # RFC 6749 §4.1.3: the code must have been issued to the client
        # that is redeeming it.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete the ``CODE#<code>`` record for ``authorization_code``."""
        oauth2_layer.delete_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{authorization_code.code}'),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> dict:
        """Load the user the code was issued for.

        Reads the item keyed by ``authorization_code.user_id`` with sort key
        ``'0'`` and keeps only ``id``, ``name``, ``email`` and
        ``email_verified``.
        """
        user = oauth2_layer.get_item(
            KeyPair(pk=authorization_code.user_id, sk='0'),
        )

        return pick(('id', 'name', 'email', 'email_verified'), user)


class RefreshTokenNotFoundError(NotFoundError):
    """Raised when a refresh token lookup finds no record."""

    def __init__(self, *_):
        # Positional arguments are discarded; the message is fixed.
        super().__init__('Refresh token not found')


class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant persisted in DynamoDB."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
        """Load the stored token for ``refresh_token``.

        The lookup is given ``RefreshTokenNotFoundError`` as ``exc_cls``;
        presumably that is what is raised when no record exists -- confirm
        against the persistence layer.
        """
        token = oauth2_layer.collection.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            ),
            exc_cls=RefreshTokenNotFoundError,
        )

        # NOTE(review): this logs the whole stored token record, which may
        # include credential material -- consider logging an identifier only.
        logger.info('Refresh token retrieved', token=token)

        # ``expires_in`` and ``issued_at`` are coerced to ``int``; every other
        # stored attribute is forwarded unchanged.
        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            **omit(('expires_in', 'issued_at'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Return the user attached to ``refresh_token``."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """Delete the stored record of the refresh token being replaced."""
        # NOTE(review): the token object itself is passed to the logger;
        # confirm its serialized form does not expose the token value.
        logger.info('Revoking old refresh token', refresh_token=refresh_token)

        token = getattr(refresh_token, 'refresh_token', None)

        if token:
            oauth2_layer.delete_item(
                KeyPair(pk='OAUTH2#TOKEN', sk=f'REFRESH_TOKEN#{token}')
            )


class IssuerParameter(rfc9207.IssuerParameter):
    """RFC 9207 ``iss`` parameter support for Powertools responses."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Append ``iss`` to the ``Location`` header of ``response``.

        Does nothing when the response has no ``Location`` header or the
        issuer is empty.
        """
        location = response.headers.get('Location')

        if self.get_issuer() and location:
            # RFC9207 §2
            # In authorization responses to the client, including error responses,
            # an authorization server supporting this specification MUST indicate
            # its identity by including the iss parameter in the response.
            new_location = add_params_to_uri(location, {'iss': self.get_issuer()})  # type: ignore
            response.headers['Location'] = new_location

    def get_issuer(self) -> str:
        """Return the configured issuer identifier."""
        return ISSUER


server = AuthorizationServer(persistence_layer=oauth2_layer)

# Authorization code flow with OIDC (nonce optional) and PKCE required.
server.register_grant(
    AuthorizationCodeGrant,
    [
        OpenIDCode(require_nonce=False),
        CodeChallenge(required=True),
    ],
)
server.register_grant(RefreshTokenGrant)
server.register_extension(IssuerParameter())