from authlib.oauth2.rfc6749 import TokenMixin, grants
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.funcs import omit, pick

from boto3clients import dynamodb_client
from config import DYNAMODB_SORT_KEY, OAUTH2_TABLE
from integrations.apigateway_oauth2 import (
    AuthorizationCode,
    AuthorizationServer,
    OAuth2Client,
    OAuth2Token,
)

# Persistence layer for the OAuth2 records used in this module (clients,
# authorization codes, nonces and users), all stored in OAUTH2_TABLE.
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)

# JWT settings returned by OpenIDCode.get_jwt_config (used to sign ID tokens).
# NOTE(review): the HS256 key is hard-coded in source; anyone who can read it
# can forge ID tokens. Load key/issuer from configuration or a secret store
# before this is used outside development.
DUMMY_JWT_CONFIG = {
    'key': 'secret-key',
    'alg': 'HS256',
    'iss': 'https://authlib.org',
    'exp': 3600,
}


def create_save_token_func(persistence_layer: DynamoDBPersistenceLayer):
    """Build the ``save_token`` callback handed to ``AuthorizationServer``.

    :param persistence_layer: storage the callback is intended to persist
        issued tokens into (not used yet -- see the TODO below).
    :returns: a ``save_token(token, request)`` callable.
    """

    def save_token(token, request) -> OAuth2Token:
        # TODO(review): stub -- the issued token is not persisted, so refresh
        # token lookup and revocation have nothing to work with yet.
        print('save_token')
        return OAuth2Token()

    return save_token


def create_query_client_func(persistence_layer: DynamoDBPersistenceLayer):
    """Build the ``query_client`` callback handed to ``AuthorizationServer``.

    :param persistence_layer: storage holding the ``OAUTH2_CLIENT`` records.
    :returns: a ``query_client(client_id)`` callable returning an
        ``OAuth2Client``; it raises ``ClientNotFoundError`` (a subclass of
        ``NotFoundError``) when no record exists for ``client_id``.
    """

    class ClientNotFoundError(NotFoundError):
        def __init__(self, *_):
            # Discard any constructor arguments; always use a fixed message.
            super().__init__('Client not found')

    def query_client(client_id) -> OAuth2Client:
        client = persistence_layer.collection.get_item(
            KeyPair(
                pk='OAUTH2_CLIENT',
                sk=f'CLIENT_ID#{client_id}',
            ),
            exc_cls=ClientNotFoundError,
        )

        # The sort key has the form 'CLIENT_ID#<id>'. partition() keeps ids
        # that themselves contain '#' intact and does not raise when the sort
        # key is missing (split() + 2-tuple unpacking raised ValueError in
        # both cases); fall back to the id we were asked for.
        _, _, stored_client_id = client.get(DYNAMODB_SORT_KEY, '').partition('#')

        return OAuth2Client(
            client_id=stored_client_id or client_id,
            client_secret=client['secret'],
            redirect_uris=client['redirect_uris'],
            response_types=client['response_types'],
            grant_types=client['grant_types'],
            token_endpoint_auth_method=client['token_endpoint_auth_method'],
        )

    return query_client


class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension attached to the authorization-code grant."""

    def exists_nonce(self, nonce, request):
        """Return whether ``nonce`` is already recorded for the requesting client.

        Nonce records live under ``OAUTH2_CODE#CLIENT_ID#<client_id>`` /
        ``NONCE#<nonce>``; ``AuthorizationCodeGrant.save_authorization_code``
        writes them.
        """
        nonce_ = oauth2_layer.get_item(
            KeyPair(
                f'OAUTH2_CODE#CLIENT_ID#{request.payload.client_id}',  # type:ignore
                f'NONCE#{nonce}',
            )
        )
        return bool(nonce_)

    def get_jwt_config(self, grant):
        """Return the JWT signing settings for ID tokens (see DUMMY_JWT_CONFIG)."""
        return DUMMY_JWT_CONFIG

    def generate_user_info(self, user, scope):
        """Build the ID-token user claims for ``user``, filtered by ``scope``.

        :param user: the mapping produced by
            ``AuthorizationCodeGrant.authenticate_user`` (keys ``id``, ``name``,
            ``email``). It is indexed by key -- attribute access such as
            ``user.id`` raised AttributeError on a dict.
        :param scope: the granted scope string used to filter claims.
        """
        claims = {'sub': user['id']}
        # name/email are optional claims: leave them out instead of failing
        # when the user record does not carry them.
        for claim in ('name', 'email'):
            if user.get(claim) is not None:
                claims[claim] = user[claim]
        return UserInfo(**claims).filter(scope)


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant backed by the OAuth2 DynamoDB table."""

    TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']

    def save_authorization_code(self, code: str, request):
        """Persist ``code`` (and its nonce, if any) with a 15-minute TTL.

        Both records are written in one transaction under the partition key
        ``OAUTH2_CODE#CLIENT_ID#<client_id>``: the code itself as
        ``CODE#<code>`` and, when a nonce was supplied, ``NONCE#<nonce>``.
        """
        client_id: str = request.payload.client_id  # type: ignore
        data: dict = request.payload.data  # type: ignore
        user: dict = request.user  # type: ignore
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=15)

        with oauth2_layer.transact_writer() as transact:
            transact.put(
                item={
                    'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,  # type: ignore
                    'scope': request.payload.scope,  # type: ignore
                    'user_id': user['id'],
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                # Recorded so OpenIDCode.exists_nonce can see it was used.
                transact.put(
                    item={
                        'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
                        'sk': f'NONCE#{nonce}',
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(self, code, client):
        """Load the stored authorization code for ``client``.

        :returns: an ``AuthorizationCode`` built from the stored item (minus
            its ``id``/``sk`` keys), or ``None`` when no item was found --
            authlib treats a falsy result as an invalid code (invalid_grant).
        """
        client_id = client.get_client_id()
        auth_code = oauth2_layer.get_item(
            KeyPair(
                pk=f'OAUTH2_CODE#CLIENT_ID#{client_id}',
                sk=f'CODE#{code}',
            ),
        )

        if not auth_code:
            # Previously an empty/missing item was passed on to omit() and
            # AuthorizationCode(**...), which either fails or yields a code
            # object with no stored data.
            return None

        return AuthorizationCode(
            client_id=client_id,
            code=code,
            **omit(('id', 'sk'), auth_code),
        )

    def delete_authorization_code(self, authorization_code):
        # TODO(review): stub -- nothing is deleted, so a code stays
        # redeemable until its TTL removes it. RFC 6749 section 4.1.2
        # requires a code to be usable only once; delete the
        # OAUTH2_CODE#CLIENT_ID#<client_id> / CODE#<code> item here.
        print('authorization_code')

    def authenticate_user(self, authorization_code):
        """Load the user the code was issued to.

        Reads the item keyed ``(authorization_code.user_id, '0')`` and returns
        only its ``id``, ``name`` and ``email`` fields; OpenIDCode's
        ``generate_user_info`` indexes this mapping by those keys.
        """
        user = oauth2_layer.get_item(
            KeyPair(
                pk=authorization_code.user_id,
                sk='0',
            ),
        )
        return pick(('id', 'name', 'email'), user)


class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant; a new refresh token is issued on every refresh.

    TODO(review): all three hooks below are unimplemented stubs returning
    None, and tokens are not persisted by ``save_token`` yet.
    """

    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin: ...

    def authenticate_user(self, refresh_token): ...

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None: ...


# Module-level authorization server: authorization-code grant with OpenID
# Connect (nonce optional) and mandatory PKCE, plus the refresh-token grant.
server = AuthorizationServer(
    query_client=create_query_client_func(oauth2_layer),
    save_token=create_save_token_func(oauth2_layer),
)
server.register_grant(
    AuthorizationCodeGrant,
    [
        OpenIDCode(require_nonce=False),
        CodeChallenge(required=True),
    ],
)
server.register_grant(RefreshTokenGrant)