424 lines
13 KiB
Python
424 lines
13 KiB
Python
from authlib.common.security import generate_token
|
||
from authlib.common.urls import add_params_to_uri
|
||
from authlib.jose import JsonWebKey
|
||
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
|
||
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
|
||
from authlib.oauth2.rfc6749.hooks import hooked
|
||
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
|
||
from authlib.oidc.core import OpenIDCode as OpenIDCode_
|
||
from authlib.oidc.core import UserInfo
|
||
from aws_lambda_powertools import Logger
|
||
from aws_lambda_powertools.event_handler.api_gateway import Response
|
||
from layercake.dateutils import now, ttl
|
||
from layercake.dynamodb import (
|
||
DynamoDBPersistenceLayer,
|
||
KeyPair,
|
||
SortKey,
|
||
TransactKey,
|
||
)
|
||
from layercake.funcs import omit, pick
|
||
|
||
from boto3clients import dynamodb_client
|
||
from config import ISSUER, OAUTH2_SCOPES_SUPPORTED, OAUTH2_TABLE
|
||
from integrations.apigateway_oauth2.authorization_server import (
|
||
AuthorizationServer,
|
||
)
|
||
from integrations.apigateway_oauth2.resource_protector import ResourceProtector
|
||
from integrations.apigateway_oauth2.tokens import (
|
||
OAuth2AuthorizationCode,
|
||
OAuth2Token,
|
||
User,
|
||
)
|
||
from util import read_file_path
|
||
|
||
logger = Logger(__name__)
# Persistence layer bound to the OAuth2 DynamoDB table; every grant, endpoint
# and hook in this module reads and writes through it.
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
# PEM private key, read once at import time. The raw PEM is used to sign
# OpenID Connect ID tokens; its JWK form signs JWT access tokens.
private_key = read_file_path('private.pem')
private_jwk = JsonWebKey.import_key(private_key)


# Access-token lifetime in seconds, per grant type. Looked up by `expires_in`;
# grant types not listed fall back to BearerTokenGenerator.DEFAULT_EXPIRES_IN.
# Mirrors authlib's GRANT_TYPES_EXPIRES_IN setting:
# https://docs.authlib.org/en/v0.12/specs/rfc6750.html#authlib.oauth2.rfc6750.BearerToken.GRANT_TYPES_EXPIRES_IN
GRANT_TYPES_EXPIRES_IN = {
    'authorization_code': 60 * 3,  # 3 minutes
    'refresh_token': 60 * 3,  # 3 minutes
    # 'refresh_token': 3600,  # 1 hour
}
|
||
|
||
|
||
class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension for the authorization code grant.

    Provides the nonce lookup, the ID-token signing configuration and the
    user claims returned for the authenticated user.
    """

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Tell whether ``nonce`` has already been recorded for an auth code.

        Raises:
            ValueError: when the request carries no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_key = KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}')
        return bool(dyn.get_item(nonce_key))

    def get_jwt_config(self, grant):
        """Return the ID-token signing settings (RS256, one-hour lifetime)."""
        one_hour = 60 * 60
        return {
            'key': private_key,
            'alg': 'RS256',
            'iss': ISSUER,
            'exp': one_hour,
        }

    def generate_user_info(self, user: User, scope: str) -> UserInfo:
        """Build the user claims allowed by ``scope``.

        When the user has a scope string of their own, it is added as an extra
        ``scope`` claim after the standard claims have been filtered.
        """
        claims = UserInfo(
            sub=user.id,
            name=user.name,
            email=user.email,
            email_verified=user.email_verified,
        )
        user_info = claims.filter(scope)

        if user.scope:
            # Extra claim carrying the user's own scopes, used to express
            # permission granularity to relying parties.
            user_info['scope'] = user.scope

        return user_info
|
||
|
||
|
||
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """The authorization code grant type is used to obtain both access tokens and
    refresh tokens and is optimized for confidential clients.

    Since this is a redirection-based flow, the client must be capable of
    interacting with the resource owner's user-agent (typically a web browser)
    and capable of receiving incoming requests (via redirection) from the
    authorization server:

    - https://docs.authlib.org/en/v0.12/specs/rfc6749.html#authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant

    Codes, and their optional nonces, are stored in DynamoDB under the
    ``OAUTH2#CODE`` partition.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',  # public clients that have no client secret
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Save authorization_code for later use.

        Writes the ``CODE#`` record and, when the request carried a ``nonce``,
        a ``NONCE#`` record in the same transaction. The ``NONCE#`` record is
        what ``OpenIDCode.exists_nonce`` looks up. Both records get a ``ttl``
        set 10 minutes after creation.

        Raises:
            ValueError: if the request has no payload or no authenticated user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        scope: str = request.payload.scope
        data: dict = request.payload.data
        user_id: str = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=10)

        with dyn.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': scope,
                    'client_id': client_id,
                    'user_id': user_id,
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Get authorization_code from previously savings.

        Authlib delegates the check that the code was issued to the
        authenticated client (RFC 6749 §4.1.3) to this hook. It only tests the
        returned value for truthiness before redeeming it, and answers
        ``invalid_grant`` when this returns ``None``.

        Returns:
            The stored code, or ``None`` when the code is unknown or was
            issued to a different client.
        """
        auth_code = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )

        if not auth_code:
            return None

        # A code issued to another client must never be redeemable here.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        # NOTE(review): expiry relies on the DynamoDB ``ttl`` attribute, which
        # is removed lazily; confirm stale codes are rejected downstream.
        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete the ``CODE#`` record once the code has been exchanged.

        A ``NONCE#`` record written alongside the code is not touched here.
        It stays until its ``ttl`` removes it.
        """
        dyn.delete_item(
            KeyPair(
                pk='OAUTH2#CODE',
                sk=f'CODE#{authorization_code.code}',
            ),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> User:
        """Authenticate the user related to this authorization_code.

        Reads the user's ``'0'`` item and ``SCOPE`` item in one call. Stored
        scopes are joined into a space-separated string, or ``None`` if absent.
        """
        user = dyn.collection.get_items(
            TransactKey(authorization_code.user_id)
            + SortKey('0')
            + SortKey(
                sk='SCOPE',
                path_spec='scope',
                rename_key='scope',
            ),
        )

        return User(
            **pick(('id', 'name', 'email', 'email_verified'), user),
            scope=' '.join(user['scope']) if 'scope' in user else None,
        )
|
||
|
||
|
||
class TokenExchangeGrant(grants.BaseGrant):
    """Placeholder for the OAuth 2.0 Token Exchange grant (RFC 8693).

    Both token-endpoint hooks raise ``NotImplementedError`` for now.

    NOTE(review): this class derives from ``BaseGrant`` only. Confirm the
    server actually routes token-exchange requests to it.
    """

    GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange'

    # Only confidential clients (authenticated with a secret) are accepted.
    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
    ]

    @hooked
    def validate_token_request(self):
        """Not implemented yet."""
        raise NotImplementedError

    @hooked
    def create_token_response(self):
        """Not implemented yet."""
        raise NotImplementedError
|
||
|
||
|
||
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Exchange a stored refresh token for new tokens (RFC 6749 §6).

    Refresh tokens live under ``OAUTH2#TOKEN`` / ``REFRESH_TOKEN#<token>``.
    A per-user ``SESSION#REFRESH_TOKEN#<token>`` index record is removed
    together with the token when the token is rotated.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    # The authorization server MAY issue a new refresh token
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin | None:
        """Get token information with refresh_token string.

        Returns:
            An ``OAuth2Token`` for the stored record, or ``None`` when the
            refresh token is unknown. Authlib then answers ``invalid_grant``.
        """
        token = dyn.collection.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            ),
            raise_on_error=False,
            default=False,
        )

        if not token:
            # `token` is always falsy here, so there is nothing useful to
            # attach. The presented credential itself is not logged.
            logger.info('Refresh token not found ')
            return None

        # The stored record may contain credential material. Keep it at DEBUG,
        # like the other token logs in this module.
        logger.debug('Refresh token retrieved', refresh_token=token)

        # NOTE(review): expiry relies on DynamoDB ``ttl`` (lazy deletion);
        # confirm expired refresh tokens are rejected somewhere.
        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            **omit(('expires_in', 'issued_at'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Authenticate the user related to this credential."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """The authorization server MAY revoke the old refresh token after
        issuing a new refresh token to the client.

        Deletes the refresh-token record and the user's matching session
        record in a single transaction. Does nothing when the credential
        carries no refresh token.
        """
        token = getattr(refresh_token, 'refresh_token', None)

        if not token:
            return None

        # Resolve the owner only when there is something to revoke.
        user = refresh_token.get_user()

        with dyn.transact_writer() as transact:
            logger.debug('Revoking old refresh token', refresh_token=token)

            transact.delete(
                key=KeyPair(
                    pk='OAUTH2#TOKEN',
                    sk=f'REFRESH_TOKEN#{token}',
                )
            )
            transact.delete(
                key=KeyPair(
                    pk=user.get_user_id(),
                    sk=f'SESSION#REFRESH_TOKEN#{token}',
                )
            )
|
||
|
||
|
||
class RevocationEndpoint(rfc7009.RevocationEndpoint):
    """RFC 7009 token revocation backed by the OAuth2 DynamoDB table."""

    # Token kinds that query_token can resolve, in order of preference when no
    # usable hint is given.
    _TOKEN_KINDS = ('refresh_token', 'access_token')

    def query_token(  # type: ignore
        self,
        token_string: str,
        token_type_hint: str | None = None,
    ):
        """Look up ``token_string`` as a refresh token or an access token.

        ``token_type_hint`` is only a hint. If the token is not found under
        the hinted type, the search extends to the other supported type
        (RFC 7009 §2.1) instead of failing.

        Returns:
            An ``OAuth2Token`` with the matching credential field set, or
            ``None`` when the token is unknown.
        """
        result = dyn.collection.get_items(
            TransactKey('OAUTH2#TOKEN')
            + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
            + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
            flatten_top=False,
        )

        if not result:
            return None

        logger.debug('Tokens retrieved', result=result)

        # Honour the hint only if it matches what was found. Otherwise fall
        # back to whichever kind exists, preferring the refresh token.
        if token_type_hint not in result:
            token_type_hint = next(
                (kind for kind in self._TOKEN_KINDS if kind in result),
                None,
            )
            if token_type_hint is None:
                return None

        token = result[token_type_hint]

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            **{token_type_hint: token_string},
            **omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
        )

    def revoke_token(
        self,
        token: OAuth2Token,
        request: OAuth2Request,
    ):
        """Revoke every session and token belonging to the token's owner.

        Every ``SESSION#...`` record of the user is deleted in one
        transaction, along with the record it points to: a ``SESSION`` item
        for plain sessions, or the ``OAUTH2#TOKEN`` item for access and
        refresh tokens.

        Authlib's docs recommend *marking* tokens as revoked rather than
        deleting them, because a deleted token string could in theory be
        re-issued to another user. This implementation deletes the records.

        - https://docs.authlib.org/en/latest/specs/rfc7009.html#authlib.oauth2.rfc7009.RevocationEndpoint.revoke_token

        NOTE(review): only the first page of ``query`` results is processed,
        and DynamoDB caps the number of operations per transaction. Confirm
        that users cannot accumulate enough sessions to hit either limit.
        """
        user_id = token.user['id']
        result = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))

        with dyn.transact_writer() as transact:
            # Revoke all sessions, access tokens, and refresh tokens
            for item in result['items']:
                pk, sk = item['id'], item['sk']
                # Key pattern: SESSION#{}#{}
                *_, kind, token_ = sk.split('#')

                transact.delete(key=KeyPair(pk, sk))
                transact.delete(
                    key=KeyPair(
                        pk='SESSION',
                        sk=token_,
                    )
                    if kind == 'SESSION'
                    else KeyPair(
                        pk='OAUTH2#TOKEN',
                        sk=f'{kind}#{token_}',
                    )
                )
|
||
|
||
|
||
class IssuerParameter(rfc9207.IssuerParameter):
    """RFC 9207 extension: append ``iss`` to redirecting authorization responses."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Add ``iss=<issuer>`` to the response's ``Location`` header, if any.

        Responses without a ``Location`` header, and a falsy issuer, leave the
        response untouched.
        """
        issuer = self.get_issuer()
        location = response.headers.get('Location')

        if not (issuer and location):
            return

        # RFC9207 §2: in authorization responses to the client, including
        # error responses, an authorization server supporting this
        # specification MUST indicate its identity by including the iss
        # parameter in the response.
        response.headers['Location'] = add_params_to_uri(location, {'iss': issuer})  # type: ignore

    def get_issuer(self) -> str:
        """Return the configured issuer identifier."""
        return ISSUER
|
||
|
||
|
||
def expires_in(client, grant_type: str) -> int:
    """Return the access-token lifetime, in seconds, for ``grant_type``.

    Grant types missing from ``GRANT_TYPES_EXPIRES_IN`` fall back to
    ``BearerTokenGenerator.DEFAULT_EXPIRES_IN``. ``client`` is accepted to
    match authlib's ``expires_generator(client, grant_type)`` signature and is
    not used.
    """
    if grant_type in GRANT_TYPES_EXPIRES_IN:
        return GRANT_TYPES_EXPIRES_IN[grant_type]
    return BearerTokenGenerator.DEFAULT_EXPIRES_IN
|
||
|
||
|
||
def create_token_generator(length: int = 42):
    """Build a callable that returns a fresh ``generate_token(length)`` string.

    The returned callable accepts and ignores any arguments, so it can be
    plugged in wherever authlib calls a token generator with its own
    positional or keyword arguments.
    """

    def _generate(*_args, **_kwargs):
        return generate_token(length)

    return _generate
|
||
|
||
|
||
class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
    """RFC 9068 JWT access-token generator signing with the module's key."""

    def get_jwks(self) -> dict[str, list]:  # type: ignore
        """Return the JWK set used to sign JWT access tokens.

        The set holds the single key loaded at import time, serialized
        including its private parameters.
        """
        signing_key = private_jwk.as_dict(is_private=True)
        return {'keys': [signing_key]}
|
||
|
||
|
||
# Module-level authorization server: grants, token generator, revocation
# endpoint and RFC 9207 extension are registered once at import time.
server = AuthorizationServer(
    persistence_layer=dyn,
    scopes_supported=OAUTH2_SCOPES_SUPPORTED,
)
server.register_grant(
    AuthorizationCodeGrant,
    [
        # OpenID Connect support; a nonce is optional in authorization requests.
        OpenIDCode(require_nonce=False),
        # PKCE (RFC 7636) is mandatory for every authorization-code request.
        CodeChallenge(required=True),
    ],
)
# 'default' generator: JWT access tokens (RFC 9068) plus 48-character random
# refresh tokens, with lifetimes chosen per grant type by `expires_in`.
server.register_token_generator(
    'default',
    JWTBearerTokenGenerator(
        issuer=ISSUER,
        refresh_token_generator=create_token_generator(48),
        expires_generator=expires_in,
    ),
)
# NOTE(review): TokenExchangeGrant's hooks raise NotImplementedError; confirm
# it is intended to be registered before it is implemented.
server.register_grant(TokenExchangeGrant)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
# Adds `iss` to authorization redirects (RFC 9207).
server.register_extension(IssuerParameter())

# Resource protector for validating access tokens on protected routes
# (presumably used by route handlers elsewhere — confirm).
require_oauth = ResourceProtector()
|