import time
from dataclasses import dataclass

from authlib.common.security import generate_token
from authlib.common.urls import add_params_to_uri
from authlib.jose import jwt
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl
from layercake.dynamodb import (
    DynamoDBPersistenceLayer,
    KeyPair,
    SortKey,
    TransactKey,
)
from layercake.funcs import omit, pick

from boto3clients import dynamodb_client
from config import ISSUER, OAUTH2_TABLE
from integrations.apigateway_oauth2.authorization_server import (
    AuthorizationServer,
)
from integrations.apigateway_oauth2.tokens import (
    OAuth2AuthorizationCode,
    OAuth2Token,
)
from util import read_file_path

logger = Logger(__name__)
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)


@dataclass(frozen=True)
class User:
    """Resource owner loaded from the user record (sort key '0')."""

    id: str
    name: str
    email: str
    email_verified: bool = False

    def get_user_id(self):
        """Return the user identifier (used as the token 'sub' claim)."""
        return self.id


class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension for the authorization code grant."""

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Return True if a NONCE# record for ``nonce`` is already stored.

        ``AuthorizationCodeGrant.save_authorization_code`` writes that record
        whenever an authorization request carries a nonce.

        Raises:
            ValueError: if the request has no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_ = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}'),
        )
        return bool(nonce_)

    def get_jwt_config(self, grant):
        """JWT settings handed to authlib for signing ID tokens (RS256, 2h)."""
        return {
            'key': read_file_path('private.pem'),
            'alg': 'RS256',
            # NOTE(review): access tokens are signed with kid 'k1' (see
            # JWTBearerTokenGenerator below) while ID tokens use 'test' with
            # the same key file; confirm both match the published JWKS.
            'kid': 'test',
            'iss': ISSUER,
            'exp': 3600 * 2,
        }

    def generate_user_info(self, user: User, scope: str) -> UserInfo:
        """Build the UserInfo claims, filtered by the granted ``scope``."""
        return UserInfo(
            sub=user.id,
            name=user.name,
            email=user.email,
            email_verified=user.email_verified,
        ).filter(scope)


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant backed by the OAUTH2#CODE partition."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Save authorization_code for later use.

        Writes the CODE# record and, when a nonce was supplied, a NONCE#
        record in the same transaction. Both expire after 10 minutes.

        Raises:
            ValueError: if the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        scope: str = request.payload.scope
        data: dict = request.payload.data
        user_id: str = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=10)

        with dyn.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': scope,
                    'client_id': client_id,
                    'user_id': user_id,
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                # Kept separately so exists_nonce() can detect nonce replay.
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Get authorization_code from previously savings.

        Returns None (which makes authlib answer ``invalid_grant``) when the
        code is unknown, was issued to a different client, or has expired.
        """
        auth_code = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )
        if not auth_code:
            return None

        # RFC 6749 §4.1.3: the code must have been issued to this client.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        # DynamoDB TTL deletion is lazy, so an expired item can still be read.
        # Assumes 'ttl' holds epoch seconds (the format DynamoDB TTL
        # requires) -- confirm against layercake.dateutils.ttl.
        expires_at = auth_code.get('ttl')
        if expires_at is not None and int(expires_at) <= int(time.time()):
            return None

        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete authorization code from database or cache.

        The NONCE# record is intentionally left in place until its TTL.
        """
        dyn.delete_item(
            KeyPair(
                pk='OAUTH2#CODE',
                sk=f'CODE#{authorization_code.code}',
            ),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> User | None:
        """Authenticate the user related to this authorization_code.

        Returns None when the user record no longer exists, which authlib
        reports as ``invalid_grant`` instead of crashing.
        """
        user = dyn.get_item(
            KeyPair(
                pk=authorization_code.user_id,
                sk='0',
            ),
        )
        if not user:
            return None
        return User(**pick(('id', 'name', 'email', 'email_verified'), user))


class RefreshTokenNotFoundError(NotFoundError):
    """Raised when a refresh token record cannot be found."""

    def __init__(self, *_):
        super().__init__('Refresh token not found')


class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant backed by the OAUTH2#TOKEN partition."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    # The authorization server MAY issue a new refresh token
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
        """Get token information with refresh_token string.

        Raises:
            RefreshTokenNotFoundError: if no record exists for the token.
        """
        token = dyn.collection.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            ),
            exc_cls=RefreshTokenNotFoundError,
        )

        # The record holds token material; keep it out of INFO-level logs.
        logger.debug('Refresh token retrieved', token=token)

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            **omit(('expires_in', 'issued_at'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Authenticate the user related to this credential.

        NOTE(review): revoke_old_credential treats this value as a mapping
        (``.get('id')``), while JWTBearerTokenGenerator calls
        ``user.get_user_id()`` on it; confirm OAuth2Token.get_user() returns
        an object supporting both.
        """
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """The authorization server MAY revoke the old refresh token after
        issuing a new refresh token to the client.

        Deletes both the token record and the user's session pointer to it.
        """
        logger.debug('Revoking old refresh token', refresh_token=refresh_token)

        token = getattr(refresh_token, 'refresh_token', None)
        user = refresh_token.get_user()

        with dyn.transact_writer() as transact:
            transact.delete(
                key=KeyPair(
                    pk='OAUTH2#TOKEN',
                    sk=f'REFRESH_TOKEN#{token}',
                )
            )
            transact.delete(
                key=KeyPair(
                    pk=user.get('id'),
                    sk=f'SESSION#REFRESH_TOKEN#{token}',
                )
            )


class RevocationEndpoint(rfc7009.RevocationEndpoint):
    """RFC 7009 token revocation endpoint."""

    def query_token(  # type: ignore
        self,
        token_string: str,
        token_type_hint: str | None = None,
    ):
        """Look up ``token_string`` as a refresh token and/or an access token.

        RFC 7009 §2.1: if the token cannot be found using the hinted type,
        the server MUST extend its search to the other supported types, so a
        mismatched hint falls back instead of raising KeyError.

        Returns:
            An OAuth2Token, or None when no matching record exists.
        """
        t = dyn.collection.get_items(
            TransactKey('OAUTH2#TOKEN')
            + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
            + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
            flatten_top=False,
        )

        if not t:
            return None

        logger.debug('Tokens retrieved', result=t)

        if not token_type_hint or token_type_hint not in t:
            token_type_hint = (
                'refresh_token' if 'refresh_token' in t else 'access_token'
            )

        token = t[token_type_hint]

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            **{token_type_hint: token_string},
            **omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
        )

    def revoke_token(
        self,
        token: OAuth2Token,
        request: OAuth2Request,
    ):
        """Revoke every session, access token and refresh token of the user
        that owns ``token``.

        Session sort keys end in ``<KIND>#<ID>``:
        - a KIND of SESSION deletes the matching ``SESSION`` item;
        - any other KIND deletes ``OAUTH2#TOKEN / <KIND>#<ID>``.
        """
        user_id = token.user['id']
        r = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))
        sessions = r['items']

        # Nothing to delete; avoid opening a transaction with no actions.
        if not sessions:
            return

        # NOTE(review): each session yields two deletes. If transact_writer
        # issues a single TransactWriteItems call, it is capped at 100
        # actions -- confirm how many sessions a user can accumulate.
        with dyn.transact_writer() as transact:
            for x in sessions:
                pk, sk = x['id'], x['sk']
                *_, kind, idx = sk.split('#')

                transact.delete(key=KeyPair(pk, sk))

                if kind == 'SESSION':
                    target = KeyPair(pk='SESSION', sk=idx)
                else:
                    target = KeyPair(pk='OAUTH2#TOKEN', sk=f'{kind}#{idx}')
                transact.delete(key=target)


class IssuerParameter(rfc9207.IssuerParameter):
    """RFC 9207 extension that adds ``iss`` to authorization redirects."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Append ``iss`` to the response's Location header, if present."""
        location = response.headers.get('Location')

        if self.get_issuer() and location:
            # RFC9207 §2
            # In authorization responses to the client, including error responses,
            # an authorization server supporting this specification MUST indicate
            # its identity by including the iss parameter in the response.
            new_location = add_params_to_uri(location, {'iss': self.get_issuer()})  # type: ignore
            response.headers['Location'] = new_location

    def get_issuer(self) -> str:
        """Return the configured issuer identifier."""
        return ISSUER


class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
    """RFC 9068 JWT access-token generator signing with ``private.pem``."""

    def get_jwks(self):
        """Return the signing key (PEM contents of ``private.pem``)."""
        return read_file_path('private.pem')

    def access_token_generator(self, client, grant_type, user, scope):
        """Build and sign an RFC 9068 JWT access token.

        Returns:
            The encoded JWT as ``str``.
        """
        issued_at = int(time.time())
        # Absolute expiry timestamp (the 'exp' claim), not a duration.
        expires_at = issued_at + self._get_expires_in(client, grant_type)
        token_data = {
            'iss': self.issuer,
            'exp': expires_at,
            'client_id': client.get_client_id(),
            'iat': issued_at,
            'jti': self.get_jti(client, grant_type, user, scope),
            'scope': scope,
        }

        # In cases of access tokens obtained through grants where a resource owner is
        # involved, such as the authorization code grant, the value of 'sub' SHOULD
        # correspond to the subject identifier of the resource owner.
        if user:
            token_data['sub'] = user.get_user_id()

        # In cases of access tokens obtained through grants where no resource owner is
        # involved, such as the client credentials grant, the value of 'sub' SHOULD
        # correspond to an identifier the authorization server uses to indicate the
        # client application.
        else:
            token_data['sub'] = client.get_client_id()

        # If the request includes a 'resource' parameter (as defined in [RFC8707]), the
        # resulting JWT access token 'aud' claim SHOULD have the same value as the
        # 'resource' parameter in the request.
        # TODO: Implement this with RFC8707
        #
        # If the request does not include a 'resource' parameter, the authorization
        # server MUST use a default resource indicator in the 'aud' claim. If a 'scope'
        # parameter is present in the request, the authorization server SHOULD use it to
        # infer the value of the default resource indicator to be used in the 'aud'
        # claim. The mechanism through which scopes are associated with default resource
        # indicator values is outside the scope of this specification.
        token_data['aud'] = self.get_audiences(client, user, scope)

        # If the values in the 'scope' parameter refer to different default resource
        # indicator values, the authorization server SHOULD reject the request with
        # 'invalid_scope' as described in Section 4.1.2.1 of [RFC6749].
        # TODO: Implement this with RFC8707

        if auth_time := self.get_auth_time(user):
            token_data['auth_time'] = auth_time

        # The meaning and processing of acr Claim Values is out of scope for this
        # specification.
        if acr := self.get_acr(user):
            token_data['acr'] = acr

        # The definition of particular values to be used in the amr Claim is beyond the
        # scope of this specification.
        if amr := self.get_amr(user):
            token_data['amr'] = amr

        # Authorization servers MAY return arbitrary attributes not defined in any
        # existing specification, as long as the corresponding claim names are collision
        # resistant or the access tokens are meant to be used only within a private
        # subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details.
        token_data.update(self.get_extra_claims(client, grant_type, user, scope))

        # This specification registers the 'application/at+jwt' media type, which can
        # be used to indicate that the content is a JWT access token. JWT access tokens
        # MUST include this media type in the 'typ' header parameter to explicitly
        # declare that the JWT represents an access token complying with this profile.
        # Per the definition of 'typ' in Section 4.1.9 of [RFC7515], it is RECOMMENDED
        # that the 'application/' prefix be omitted. Therefore, the 'typ' value used
        # SHOULD be 'at+jwt'.
        # NOTE(review): kid 'k1' differs from the 'test' kid used for ID tokens in
        # OpenIDCode.get_jwt_config although both sign with private.pem.
        header = {'alg': self.alg, 'typ': 'at+jwt', 'kid': 'k1'}

        access_token = jwt.encode(
            header,
            token_data,
            key=self.get_jwks(),
            check=False,
        )
        return access_token.decode()


# Access-token lifetime (seconds) per grant type; others use authlib's default.
GRANT_TYPES_EXPIRES_IN = {
    'refresh_token': 600,
}


def expires_in(client, grant_type: str) -> int:
    """Return the access-token lifetime in seconds for ``grant_type``."""
    return GRANT_TYPES_EXPIRES_IN.get(
        grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
    )


def create_token_generator(length: int = 42):
    """Return a callable that ignores its arguments and yields a random
    token of ``length`` characters."""

    def token_generator(*args, **kwargs):
        return generate_token(length)

    return token_generator


server = AuthorizationServer(persistence_layer=dyn)
server.register_grant(
    AuthorizationCodeGrant,
    [
        OpenIDCode(require_nonce=False),
        CodeChallenge(required=True),
    ],
)
server.register_token_generator(
    'default',
    JWTBearerTokenGenerator(
        issuer=ISSUER,
        refresh_token_generator=create_token_generator(48),
        expires_generator=expires_in,
    ),
)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
server.register_extension(IssuerParameter())