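"""OAuth 2.0 / OpenID Connect authorization server wiring.

Defines the Authlib grants (authorization code, refresh token), the token
revocation endpoint, the RFC 9207 issuer parameter extension and the JWT
access-token generator, all persisted through a DynamoDB persistence layer.
"""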
from authlib.common.security import generate_token
|
|
from authlib.common.urls import add_params_to_uri
|
|
from authlib.jose import JsonWebKey
|
|
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
|
|
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
|
|
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
|
from authlib.oauth2.rfc7636 import CodeChallenge
|
|
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
|
|
from authlib.oidc.core import OpenIDCode as OpenIDCode_
|
|
from authlib.oidc.core import UserInfo
|
|
from aws_lambda_powertools import Logger
|
|
from aws_lambda_powertools.event_handler.api_gateway import Response
|
|
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
|
|
from layercake.dateutils import now, ttl
|
|
from layercake.dynamodb import (
|
|
DynamoDBPersistenceLayer,
|
|
KeyPair,
|
|
SortKey,
|
|
TransactKey,
|
|
)
|
|
from layercake.funcs import omit, pick
|
|
|
|
from boto3clients import dynamodb_client
|
|
from config import ISSUER, OAUTH2_TABLE
|
|
from integrations.apigateway_oauth2.authorization_server import (
|
|
AuthorizationServer,
|
|
)
|
|
from integrations.apigateway_oauth2.tokens import (
|
|
OAuth2AuthorizationCode,
|
|
OAuth2Token,
|
|
User,
|
|
)
|
|
from util import read_file_path
|
|
|
|
# Structured logger shared by every grant and endpoint in this module.
logger = Logger(__name__)

# Persistence layer bound to the OAuth2 table.
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)

# Signing key material: `private_key` is whatever `read_file_path` returns for
# 'private.pem'; `private_jwk` is the same key imported as a JWK.
private_key = read_file_path('private.pem')
private_jwk = JsonWebKey.import_key(private_key)
|
|
|
|
|
|
# Token lifetime in seconds, keyed by grant type. Grant types missing from
# this mapping fall back to a default chosen by the caller doing the lookup.
GRANT_TYPES_EXPIRES_IN = {
    'authorization_code': 3 * 60,  # 3 minutes
    'refresh_token': 60 * 60,  # 1 hour
}
|
|
|
|
|
|
class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension for the authorization-code grant."""

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Tell whether ``nonce`` has already been stored.

        Args:
            nonce: Nonce value sent by the client.
            request: Current OAuth2 request; it must carry a payload.

        Returns:
            ``True`` when an item exists under ``NONCE#<nonce>``.

        Raises:
            ValueError: If the request has no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_key = KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}')
        stored = dyn.get_item(nonce_key)

        return True if stored else False

    def get_jwt_config(self, grant):
        """Return the signing configuration used for ID tokens."""
        config = dict(
            key=private_key,
            alg='RS256',
            iss=ISSUER,
            exp=3600,  # 1 hour
        )
        return config

    def generate_user_info(self, user: User, scope: str) -> UserInfo:
        """Build the user claims and keep only those allowed by ``scope``."""
        claims = UserInfo(
            sub=user.id,
            name=user.name,
            email=user.email,
            email_verified=user.email_verified,
        )
        return claims.filter(scope)
|
|
|
|
|
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant persisted in DynamoDB."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Save authorization_code for later use.

        Writes the code item and, when the request carries a nonce, a
        companion nonce item in the same transaction. Both items get a
        ``ttl`` computed 10 minutes from the creation time.

        Raises:
            ValueError: If the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        scope: str = request.payload.scope
        data: dict = request.payload.data
        user_id: str = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=10)

        with dyn.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': scope,
                    'client_id': client_id,
                    'user_id': user_id,
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                # Recorded separately so the nonce can be looked up by value.
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Get the previously saved authorization code for ``client``.

        Returns:
            The stored authorization code, or ``None`` when no such code
            exists or it was issued to a different client. Authlib turns a
            ``None`` result into an ``invalid_grant`` error.
        """
        auth_code = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )

        if not auth_code:
            return None

        # The code must only be redeemable by the client it was issued to;
        # Authlib delegates this check to this method.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete authorization code from database or cache."""
        dyn.delete_item(
            KeyPair(
                pk='OAUTH2#CODE',
                sk=f'CODE#{authorization_code.code}',
            ),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> User | None:
        """Authenticate the user related to this authorization_code.

        Returns:
            The user the code was issued for, or ``None`` when the user
            record cannot be found (Authlib then rejects the grant).
        """
        user = dyn.get_item(
            KeyPair(
                pk=authorization_code.user_id,
                sk='0',
            ),
        )

        if not user:
            return None

        return User(**pick(('id', 'name', 'email', 'email_verified'), user))
|
|
|
|
|
|
class RefreshTokenNotFoundError(NotFoundError):
    """Not-found error with a fixed message for refresh-token lookups."""

    def __init__(self, *_ignored):
        # Any positional arguments are discarded; the message is always the same.
        message = 'Refresh token not found'
        super().__init__(message)
|
|
|
|
|
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant backed by DynamoDB."""

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    # The authorization server MAY issue a new refresh token
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
        """Load the token record stored for ``refresh_token``.

        The lookup is given ``RefreshTokenNotFoundError`` as its ``exc_cls``.
        """
        lookup_key = KeyPair(
            pk='OAUTH2#TOKEN',
            sk=f'REFRESH_TOKEN#{refresh_token}',
        )
        record = dyn.collection.get_item(
            lookup_key,
            exc_cls=RefreshTokenNotFoundError,
        )

        # NOTE(review): this logs the whole stored record at info level;
        # confirm it holds nothing sensitive.
        logger.info('Refresh token retrieved', token=record)

        # Stored numeric fields are coerced to int before building the token.
        lifetime = int(record['expires_in'])
        issued = int(record['issued_at'])
        remaining = omit(('expires_in', 'issued_at'), record)

        return OAuth2Token(
            expires_in=lifetime,
            issued_at=issued,
            refresh_token=refresh_token,
            **remaining,
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Return the user attached to the given token."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """Delete the old refresh token and its per-user session entry.

        The authorization server MAY revoke the old refresh token after
        issuing a new refresh token to the client.
        """
        logger.debug('Revoking old refresh token', refresh_token=refresh_token)

        token_value = getattr(refresh_token, 'refresh_token', None)
        owner = refresh_token.get_user()

        with dyn.transact_writer() as transact:
            token_key = KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{token_value}',
            )
            transact.delete(key=token_key)

            session_key = KeyPair(
                pk=owner.get_user_id(),
                sk=f'SESSION#REFRESH_TOKEN#{token_value}',
            )
            transact.delete(key=session_key)
|
|
|
|
|
|
class RevocationEndpoint(rfc7009.RevocationEndpoint):
    """RFC 7009 token revocation endpoint backed by DynamoDB."""

    def query_token(  # type: ignore
        self,
        token_string: str,
        token_type_hint: str | None = None,
    ):
        """Find the stored token matching ``token_string``.

        Both the refresh-token and access-token keys are fetched in one
        call. ``token_type_hint`` is only a hint: when it is absent or does
        not match what was found, the lookup falls back to whichever token
        type exists, as RFC 7009 §2.1 requires.

        Returns:
            An ``OAuth2Token`` for the matching record, or ``None`` when
            nothing is stored under either key.
        """
        # presumably the result is keyed by each SortKey's `rename_key` and
        # omits keys that were not found — verify against layercake.
        found = dyn.collection.get_items(
            TransactKey('OAUTH2#TOKEN')
            + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
            + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
            flatten_top=False,
        )

        if not found:
            return None

        logger.debug('Tokens retrieved', result=found)

        if token_type_hint not in found:
            # Missing or wrong hint: widen the search instead of failing
            # with a KeyError on the hinted type.
            token_type_hint = next(
                (
                    kind
                    for kind in ('refresh_token', 'access_token')
                    if kind in found
                ),
                None,
            )

            if token_type_hint is None:
                return None

        token = found[token_type_hint]

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            **{token_type_hint: token_string},
            **omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
        )

    def revoke_token(
        self,
        token: OAuth2Token,
        request: OAuth2Request,
    ):
        """Revoke every session and token of the user owning ``token``.

        All of the user's ``SESSION`` items are queried, and each one is
        deleted together with the record it points to, in one transaction.
        """
        user_id = token.user['id']
        r = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))

        with dyn.transact_writer() as transact:
            # Revoke all sessions, access tokens, and refresh tokens
            for x in r['items']:
                pk, sk = x['id'], x['sk']
                # The last two '#'-separated parts of the sort key are the
                # entry kind and its identifier.
                *_, kind, idx = sk.split('#')

                if kind == 'SESSION':
                    target = KeyPair(
                        pk='SESSION',
                        sk=idx,
                    )
                else:
                    target = KeyPair(
                        pk='OAUTH2#TOKEN',
                        sk=f'{kind}#{idx}',
                    )

                transact.delete(key=KeyPair(pk, sk))
                transact.delete(key=target)
|
|
|
|
|
|
class IssuerParameter(rfc9207.IssuerParameter):
    """RFC 9207 issuer parameter applied to the response's redirect header."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Append ``iss`` to the ``Location`` header when one is present."""
        location = response.headers.get('Location')

        # Nothing to do without both an issuer and a redirect target.
        if not (self.get_issuer() and location):
            return

        # RFC9207 §2
        # In authorization responses to the client, including error responses,
        # an authorization server supporting this specification MUST indicate
        # its identity by including the iss parameter in the response.
        issuer_params = {'iss': self.get_issuer()}
        response.headers['Location'] = add_params_to_uri(location, issuer_params)  # type: ignore

    def get_issuer(self) -> str:
        """Return the configured issuer identifier."""
        return ISSUER
|
|
|
|
|
|
def expires_in(client, grant_type: str) -> int:
    """Return the token lifetime in seconds for ``grant_type``.

    ``client`` is accepted but not consulted. Grant types without an entry
    in ``GRANT_TYPES_EXPIRES_IN`` use ``BearerTokenGenerator.DEFAULT_EXPIRES_IN``.
    """
    try:
        return GRANT_TYPES_EXPIRES_IN[grant_type]
    except KeyError:
        return BearerTokenGenerator.DEFAULT_EXPIRES_IN
|
|
|
|
|
|
def create_token_generator(length: int = 42):
    """Build a callable producing random tokens of ``length`` characters.

    The returned callable accepts and ignores any arguments, so it can be
    plugged in wherever a generator is called with extra context.
    """

    def token_generator(*_args, **_kwargs):
        new_token = generate_token(length)
        return new_token

    return token_generator
|
|
|
|
|
|
class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
    """JWT access-token generator signing with the module's private JWK."""

    def get_jwks(self) -> dict[str, list]:  # type: ignore
        """Return the JWKs that will be used to sign the JWT access token."""
        # The private form of the key is exported because it is used for signing.
        signing_key = private_jwk.as_dict(is_private=True)
        return {'keys': [signing_key]}
|
|
|
|
|
|
# Authorization server instance; everything below registers onto it.
server = AuthorizationServer(persistence_layer=dyn)

# Authorization-code grant with OpenID Connect support (nonce optional) and
# a required PKCE code challenge.
server.register_grant(
    AuthorizationCodeGrant,
    [OpenIDCode(require_nonce=False), CodeChallenge(required=True)],
)

# Default token generator: JWT access tokens, 48-character refresh tokens,
# lifetimes resolved per grant type by `expires_in`.
server.register_token_generator(
    'default',
    JWTBearerTokenGenerator(
        issuer=ISSUER,
        refresh_token_generator=create_token_generator(48),
        expires_generator=expires_in,
    ),
)

server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
server.register_extension(IssuerParameter())
|