from functools import reduce

from authlib.common.security import generate_token
from authlib.common.urls import add_params_to_uri
from authlib.jose import JsonWebKey
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc6749.hooks import hooked
from authlib.oauth2.rfc6749.util import scope_to_list
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from layercake.dateutils import now, ttl
from layercake.dynamodb import (
    DynamoDBPersistenceLayer,
    KeyPair,
    SortKey,
    TransactKey,
)
from layercake.funcs import omit, pick

from boto3clients import dynamodb_client
from config import ISSUER, OAUTH2_DEFAULT_SCOPES, OAUTH2_SCOPES_SUPPORTED, USER_TABLE
from integrations.apigateway_oauth2.authorization_server import (
    AuthorizationServer,
)
from integrations.apigateway_oauth2.resource_protector import ResourceProtector
from integrations.apigateway_oauth2.tokens import (
    OAuth2AuthorizationCode,
    OAuth2Token,
    User,
)
from util import read_file_path

logger = Logger(__name__)
dyn = DynamoDBPersistenceLayer(USER_TABLE, dynamodb_client)
private_key = read_file_path('private.pem')
private_jwk = JsonWebKey.import_key(private_key)

# https://docs.authlib.org/en/v0.12/specs/rfc6750.html#authlib.oauth2.rfc6750.BearerToken.GRANT_TYPES_EXPIRES_IN
GRANT_TYPES_EXPIRES_IN = {
    'authorization_code': 3600,  # 1 hour
    'refresh_token': 3600,  # 1 hour
}


def get_user_scope(user_id: str) -> set:
    """Return the default scopes merged with the scopes stored for a user.

    Every item under the user's partition whose sort key matches ``SCOPE#``
    contributes its space-separated ``scope`` string.
    """
    items = dyn.collection.query(
        KeyPair(pk=user_id, sk='SCOPE#'),
    ).get('items', [])
    scope = reduce(lambda acc, cur: acc + scope_to_list(cur['scope']), items, [])
    return OAUTH2_DEFAULT_SCOPES | set(scope)


class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension for the authorization code grant."""

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Return True if a ``NONCE#`` item was already stored for *nonce*.

        Raises:
            ValueError: if the request carries no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_ = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}'),
        )
        return bool(nonce_)

    def get_jwt_config(self, grant):
        """Return the key, algorithm, issuer and lifetime used for ID tokens."""
        return {
            'key': private_key,
            'alg': 'RS256',
            'iss': ISSUER,
            'exp': 3600,  # 1 hour
        }

    def generate_user_info(self, user: User, scope: str) -> UserInfo:
        """Build the user-info claims, filtered by the granted *scope*.

        The user's own ``scope`` string, when set, is added as a ``scope`` claim.
        """
        user_info = UserInfo(
            sub=user.id,
            name=user.name,
            email=user.email,
            email_verified=user.email_verified,
        ).filter(scope)

        if user.scope:
            # Used to define permission granularity
            user_info['scope'] = user.scope

        return user_info


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """The authorization code grant type is used to obtain both access
    tokens and refresh tokens and is optimized for confidential clients.
    Since this is a redirection-based flow, the client must be capable of
    interacting with the resource owner’s user-agent (typically a web browser)
    and capable of receiving incoming requests (via redirection) from the
    authorization server:

    - https://docs.authlib.org/en/v0.12/specs/rfc6749.html#authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Persist an issued authorization code for later exchange.

        The code item and, when a nonce is present, a separate ``NONCE#`` item
        are written in one transaction, both with a 10-minute ``ttl``.

        Raises:
            ValueError: if the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        scope: str = request.payload.scope
        data: dict = request.payload.data
        user_id: str = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=10)

        with dyn.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': scope,
                    'client_id': client_id,
                    'user_id': user_id,
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                # Separate item so that OpenIDCode.exists_nonce can find it.
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Load a previously saved authorization code issued to *client*.

        Returns ``None`` when the code is unknown or was issued to a different
        client (RFC 6749 §4.1.3), so the grant can reject the request instead
        of failing while building the code object.
        """
        auth_code = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )

        # A miss yields a falsy value (the same assumption exists_nonce uses).
        if not auth_code:
            return None

        # The code must be redeemed by the same client it was issued to.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        # NOTE(review): expiry relies on the item's ``ttl`` attribute; DynamoDB
        # TTL deletion is not immediate — confirm expired codes are rejected.
        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete the authorization code item once it has been exchanged."""
        dyn.delete_item(
            KeyPair(
                pk='OAUTH2#CODE',
                sk=f'CODE#{authorization_code.code}',
            ),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> User:
        """Load the user the authorization code was issued for.

        The returned user carries its resolved scopes as a space-separated string.
        """
        user = dyn.collection.get_item(
            KeyPair(pk=authorization_code.user_id, sk='0'),
        )
        scope = get_user_scope(authorization_code.user_id)

        return User(
            **pick(('id', 'name', 'email', 'email_verified'), user),
            scope=' '.join(scope),
        )


class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant that rotates refresh tokens on every use."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    # The authorization server MAY issue a new refresh token
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin | None:
        """Load the token record stored for *refresh_token*.

        Returns ``None`` when no record exists.
        """
        token = dyn.collection.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            ),
            raise_on_error=False,
            default=False,
        )

        if not token:
            # The lookup result is falsy here, so there is nothing useful to
            # attach, and the submitted refresh token is a credential.
            logger.info('Refresh token not found ')
            return None

        # NOTE(review): this logs the whole stored token record; confirm it
        # holds nothing that should stay out of logs.
        logger.info('Refresh token retrieved', refresh_token=token)

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            **omit(('expires_in', 'issued_at'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Authenticate the user related to this credential."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """Delete the old refresh token after a new one has been issued.

        Removes both the ``OAUTH2#TOKEN`` item and the user's matching
        ``SESSION#REFRESH_TOKEN#`` item in a single transaction.
        """
        token = getattr(refresh_token, 'refresh_token', None)

        # Nothing to revoke; skip loading the user entirely.
        if not token:
            return None

        user = refresh_token.get_user()

        with dyn.transact_writer() as transact:
            logger.debug('Revoking old refresh token', refresh_token=token)
            transact.delete(
                key=KeyPair(
                    pk='OAUTH2#TOKEN',
                    sk=f'REFRESH_TOKEN#{token}',
                )
            )
            transact.delete(
                key=KeyPair(
                    pk=user.get_user_id(),
                    sk=f'SESSION#REFRESH_TOKEN#{token}',
                )
            )


class TokenExchangeGrant(grants.BaseGrant):
    """Placeholder for the token-exchange grant; not implemented yet."""

    GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange'
    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
    ]

    @hooked
    def validate_token_request(self):
        raise NotImplementedError()

    @hooked
    def create_token_response(self):
        raise NotImplementedError()


class RevocationEndpoint(rfc7009.RevocationEndpoint):
    """RFC 7009 token revocation backed by DynamoDB."""

    def query_token(  # type: ignore
        self,
        token_string: str,
        token_type_hint: str | None = None,
    ):
        """Find *token_string* as a refresh token or an access token.

        Both candidates are fetched in one call. The hint only chooses which
        one to prefer: if it is absent or the token is not stored under the
        hinted type, the other type is used (RFC 7009 §2.1).
        """
        result = dyn.collection.get_items(
            TransactKey('OAUTH2#TOKEN')
            + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
            + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
            flatten_top=False,
        )

        if not result:
            return None

        logger.debug('Tokens retrieved', result=result)

        if token_type_hint not in result:
            # Missing or unusable hint: the server MUST extend its search
            # across all supported token types.
            token_type_hint = (
                'refresh_token' if 'refresh_token' in result else 'access_token'
            )

        token = result[token_type_hint]

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            **{token_type_hint: token_string},
            **omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
        )

    def revoke_token(
        self,
        token: OAuth2Token,
        request: OAuth2Request,
    ):
        """Revoke every session belonging to the token's owner.

        Rather than revoking only *token*, this deletes all of the user's
        ``SESSION`` items, and the session or token item each one points at,
        in a single transaction.

        Authlib's guidance is to mark tokens as revoked instead of deleting
        them, since a deleted token string could later be re-issued to another
        user (Jane revokes XYZ, Bob is issued XYZ, Jane reuses XYZ):

        - https://docs.authlib.org/en/latest/specs/rfc7009.html#authlib.oauth2.rfc7009.RevocationEndpoint.revoke_token
        """
        user_id = token.user['id']
        result = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))
        items = result.get('items', [])

        # No sessions to delete; do not open an empty transaction.
        if not items:
            return

        with dyn.transact_writer() as transact:
            # Revoke all sessions, access tokens, and refresh tokens
            for item in items:
                pk, sk = item['id'], item['sk']
                # Key pattern: SESSION#{}#{}
                *_, kind, token_ = sk.split('#')

                transact.delete(key=KeyPair(pk, sk))
                transact.delete(
                    key=KeyPair(
                        pk='SESSION',
                        sk=token_,
                    )
                    if kind == 'SESSION'
                    else KeyPair(
                        pk='OAUTH2#TOKEN',
                        sk=f'{kind}#{token_}',
                    )
                )


class IssuerParameter(rfc9207.IssuerParameter):
    """Adds the ``iss`` parameter to authorization redirects (RFC 9207)."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Append ``iss`` to the response's ``Location`` header, if present."""
        location = response.headers.get('Location')

        if self.get_issuer() and location:
            # RFC9207 §2
            # In authorization responses to the client, including error responses,
            # an authorization server supporting this specification MUST indicate
            # its identity by including the iss parameter in the response.
            new_location = add_params_to_uri(location, {'iss': self.get_issuer()})  # type: ignore
            response.headers['Location'] = new_location

    def get_issuer(self) -> str:
        return ISSUER


def expires_in(client, grant_type: str) -> int:
    """Return the access-token lifetime in seconds for *grant_type*."""
    return GRANT_TYPES_EXPIRES_IN.get(
        grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
    )


def create_token_generator(length: int = 42):
    """Return a generator callable that yields random tokens of *length*."""

    def token_generator(*args, **kwargs):
        return generate_token(length)

    return token_generator


class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
    """RFC 9068 JWT access-token generator signed with the server key."""

    def get_jwks(self) -> dict[str, list]:  # type: ignore
        """Return the JWKs that will be used to sign the JWT access token."""
        return {
            'keys': [
                private_jwk.as_dict(is_private=True),
            ]
        }

    def get_extra_claims(self, client, grant_type, user, scope):
        """Add the user's profile claims to the access token."""
        return {
            'name': user.name,
            'email': user.email,
            'email_verified': user.email_verified,
        }


server = AuthorizationServer(
    persistence_layer=dyn,
    scopes_supported=OAUTH2_SCOPES_SUPPORTED,
)
server.register_grant(
    AuthorizationCodeGrant,
    [
        OpenIDCode(require_nonce=False),
        CodeChallenge(required=True),
    ],
)
server.register_token_generator(
    'default',
    JWTBearerTokenGenerator(
        issuer=ISSUER,
        refresh_token_generator=create_token_generator(48),
        expires_generator=expires_in,
    ),
)
server.register_grant(TokenExchangeGrant)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
server.register_extension(IssuerParameter())

require_oauth = ResourceProtector()