Files
saladeaula.digital/id.saladeaula.digital/app/oauth2.py

424 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from authlib.common.security import generate_token
from authlib.common.urls import add_params_to_uri
from authlib.jose import JsonWebKey
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc6749.hooks import hooked
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from layercake.dateutils import now, ttl
from layercake.dynamodb import (
DynamoDBPersistenceLayer,
KeyPair,
SortKey,
TransactKey,
)
from layercake.funcs import omit, pick
from boto3clients import dynamodb_client
from config import ISSUER, OAUTH2_DEFAULT_SCOPES, OAUTH2_SCOPES_SUPPORTED, USER_TABLE
from integrations.apigateway_oauth2.authorization_server import (
AuthorizationServer,
)
from integrations.apigateway_oauth2.resource_protector import ResourceProtector
from integrations.apigateway_oauth2.tokens import (
OAuth2AuthorizationCode,
OAuth2Token,
User,
)
from util import read_file_path
logger = Logger(__name__)
dyn = DynamoDBPersistenceLayer(USER_TABLE, dynamodb_client)

# PEM private key used for RS256 signing: passed as the ID-token key in
# OpenIDCode.get_jwt_config and, imported as a JWK, returned by
# JWTBearerTokenGenerator.get_jwks.
private_key = read_file_path('private.pem')
private_jwk = JsonWebKey.import_key(private_key)

# Access-token lifetime in seconds per grant type; grant types missing here
# fall back to BearerTokenGenerator.DEFAULT_EXPIRES_IN (see expires_in).
# https://docs.authlib.org/en/v0.12/specs/rfc6750.html#authlib.oauth2.rfc6750.BearerToken.GRANT_TYPES_EXPIRES_IN
GRANT_TYPES_EXPIRES_IN = dict(
    authorization_code=60 * 60,  # 1 hour
    refresh_token=60 * 60,  # 1 hour
)
class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension for the authorization code grant.

    Supplies the nonce lookup, the ID-token signing settings and the userinfo
    claims that authlib's ``OpenIDCode`` extension asks for.
    """

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Tell whether a ``NONCE#<nonce>`` item exists under ``OAUTH2#CODE``.

        Raises:
            ValueError: If the request carries no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_key = KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}')
        return bool(dyn.get_item(nonce_key))

    def get_jwt_config(self, grant):
        """Return the ID-token signing settings: RS256, this issuer, 1 hour."""
        return dict(
            key=private_key,
            alg='RS256',
            iss=ISSUER,
            exp=60 * 60,  # 1 hour
        )

    def generate_user_info(self, user: User, scope: str) -> UserInfo:
        """Build the userinfo claims for ``user``, filtered down to ``scope``.

        When the user has its own ``scope`` value it is attached as an extra
        ``scope`` claim.
        """
        claims = UserInfo(
            sub=user.id,
            name=user.name,
            email=user.email,
            email_verified=user.email_verified,
        ).filter(scope)

        if not user.scope:
            return claims

        # Custom claim used to define permission granularity.
        claims['scope'] = user.scope
        return claims
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization code grant backed by DynamoDB.

    The authorization code grant type is used to obtain both access tokens and
    refresh tokens and is optimized for confidential clients. Since this is a
    redirection-based flow, the client must be capable of interacting with the
    resource owner's user-agent (typically a web browser) and capable of
    receiving incoming requests (via redirection) from the authorization
    server:

    - https://docs.authlib.org/en/v0.12/specs/rfc6749.html#authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant

    Codes are stored under ``pk='OAUTH2#CODE'`` / ``sk='CODE#<code>'``. When
    the request carries a ``nonce``, a companion ``NONCE#<nonce>`` item is
    written in the same transaction.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Save authorization_code for later use.

        Args:
            code: The authorization code generated by authlib.
            request: The authorization request. ``request.payload`` carries the
                client parameters and ``request.user`` the user id.

        Raises:
            ValueError: If the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        scope: str = request.payload.scope
        data: dict = request.payload.data
        user_id: str = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        # Code and nonce items share the same 10-minute expiry.
        ttl_ = ttl(start_dt=now_, minutes=10)

        with dyn.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': scope,
                    'client_id': client_id,
                    'user_id': user_id,
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                # Recorded so that OpenIDCode.exists_nonce can find it.
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Get a previously saved authorization_code issued to ``client``.

        Returns ``None``, which authlib reports as ``invalid_grant``, when the
        code does not exist or was issued to a different client. RFC 6749
        §4.1.3 requires the code to be bound to the client redeeming it.
        """
        auth_code = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )

        # Unknown (or already consumed) code: let authlib reject it instead
        # of crashing while building the model from an empty record.
        if not auth_code:
            return None

        # A code may only be redeemed by the client it was issued to.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        # NOTE(review): expiry relies on the item's 'ttl' attribute. If that is
        # DynamoDB TTL, deletion is lazy and an expired code may still be read
        # here. Consider comparing the stored 'ttl' with the current time.
        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete authorization code from database or cache."""
        dyn.delete_item(
            KeyPair(
                pk='OAUTH2#CODE',
                sk=f'CODE#{authorization_code.code}',
            ),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> User:
        """Authenticate the user related to this authorization_code.

        The user's stored scopes (``SCOPE`` item) are merged with
        ``OAUTH2_DEFAULT_SCOPES``.
        """
        user = dyn.collection.get_items(
            TransactKey(authorization_code.user_id)
            + SortKey('0')
            + SortKey(
                sk='SCOPE',
                path_spec='scope',
                rename_key='scope',
            ),
        )
        scope = set(user.get('scope', [])) | OAUTH2_DEFAULT_SCOPES

        return User(
            **pick(('id', 'name', 'email', 'email_verified'), user),
            # Sorted so the space-delimited scope string is deterministic
            # (set iteration order of strings varies between processes).
            scope=' '.join(sorted(scope)),
        )
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh token grant backed by DynamoDB.

    Refresh tokens are stored under ``pk='OAUTH2#TOKEN'`` /
    ``sk='REFRESH_TOKEN#<token>'``. A per-user item
    ``pk=<user id>`` / ``sk='SESSION#REFRESH_TOKEN#<token>'`` is removed
    together with the token when it is rotated.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    # The authorization server MAY issue a new refresh token
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin | None:
        """Get token information with refresh_token string.

        Returns ``None`` when no record exists for the token (for example
        because it was already rotated or revoked). authlib reports that as
        ``invalid_grant``.
        """
        token = dyn.collection.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            ),
            raise_on_error=False,
            default=False,
        )

        # Never log the token string or its record: both are bearer credentials.
        if not token:
            logger.info('Refresh token not found')
            return None

        logger.info('Refresh token retrieved')
        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            **omit(('expires_in', 'issued_at'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Authenticate the user related to this credential."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """The authorization server MAY revoke the old refresh token after
        issuing a new refresh token to the client.

        Deletes the token record and the user's session item for it in a
        single transaction. Does nothing when the credential carries no
        ``refresh_token`` value.
        """
        token = getattr(refresh_token, 'refresh_token', None)

        if not token:
            return None

        # Only resolve the owner once we know there is something to revoke.
        user = refresh_token.get_user()

        with dyn.transact_writer() as transact:
            logger.debug('Revoking old refresh token')
            transact.delete(
                key=KeyPair(
                    pk='OAUTH2#TOKEN',
                    sk=f'REFRESH_TOKEN#{token}',
                )
            )
            transact.delete(
                key=KeyPair(
                    pk=user.get_user_id(),
                    sk=f'SESSION#REFRESH_TOKEN#{token}',
                )
            )
class TokenExchangeGrant(grants.BaseGrant):
    """Placeholder for the OAuth 2.0 Token Exchange grant (RFC 8693).

    Only the grant type and the accepted client authentication methods are
    declared. Both request handlers raise ``NotImplementedError``.
    """

    GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange'

    # Unlike the other grants in this module, 'none' is not accepted.
    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
    ]

    @hooked
    def validate_token_request(self):
        """Not implemented yet."""
        raise NotImplementedError

    @hooked
    def create_token_response(self):
        """Not implemented yet."""
        raise NotImplementedError
class RevocationEndpoint(rfc7009.RevocationEndpoint):
    """Token revocation endpoint (RFC 7009) backed by DynamoDB."""

    def query_token(  # type: ignore
        self,
        token_string: str,
        token_type_hint: str | None = None,
    ):
        """Find the stored token record matching ``token_string``.

        Both the refresh-token and access-token records are fetched. The
        ``token_type_hint`` is only a hint: per RFC 7009 §2.1, when the token
        is not found under the hinted type, the other type is used. Without a
        usable hint, the refresh token is preferred.

        Returns:
            An ``OAuth2Token``, or ``None`` when no record matches.
        """
        result = dyn.collection.get_items(
            TransactKey('OAUTH2#TOKEN')
            + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
            + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
            flatten_top=False,
        )

        if not result:
            return None

        # Log which token types matched, not the records (they hold credentials).
        logger.debug(
            'Tokens retrieved',
            found=[kind for kind in ('refresh_token', 'access_token') if kind in result],
        )

        if token_type_hint in result:
            token_type = token_type_hint
        else:
            # No hint, or the hinted type was not found: extend the search.
            token_type = (
                'refresh_token' if 'refresh_token' in result else 'access_token'
            )

        token = result[token_type]
        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            **{token_type: token_string},
            **omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
        )

    def revoke_token(
        self,
        token: OAuth2Token,
        request: OAuth2Request,
    ):
        """Revoke every session and token belonging to the token's owner.

        Queries the owner's partition with sort key ``SESSION`` and, for each
        item (sort-key pattern ``SESSION#<kind>#<id>``), deletes the item and
        the record it points to:

        - ``pk='SESSION'`` / ``sk=<id>`` when ``kind`` is ``SESSION``;
        - ``pk='OAUTH2#TOKEN'`` / ``sk='<kind>#<id>'`` otherwise.

        All deletions run in one transaction.

        The records are deleted, not flagged as revoked. authlib's
        documentation recommends flagging so that a revoked token string can
        never be re-issued to another user. Deletion is only safe while token
        strings are effectively unique:
        https://docs.authlib.org/en/latest/specs/rfc7009.html#authlib.oauth2.rfc7009.RevocationEndpoint.revoke_token

        NOTE(review): DynamoDB caps a transaction at 100 actions and this
        issues two per session item. Also confirm whether
        ``dyn.collection.query`` returns all pages. Confirm both hold for
        users with many sessions.
        """
        user_id = token.user['id']
        result = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))

        with dyn.transact_writer() as transact:
            # Revoke all sessions, access tokens, and refresh tokens
            for item in result['items']:
                pk, sk = item['id'], item['sk']
                # Key pattern: SESSION#{kind}#{token}
                *_, kind, token_ = sk.split('#')

                transact.delete(key=KeyPair(pk, sk))
                transact.delete(
                    key=KeyPair(
                        pk='SESSION',
                        sk=token_,
                    )
                    if kind == 'SESSION'
                    else KeyPair(
                        pk='OAUTH2#TOKEN',
                        sk=f'{kind}#{token_}',
                    )
                )
class IssuerParameter(rfc9207.IssuerParameter):
    """RFC 9207 issuer identification for Powertools ``Response`` objects."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Append ``iss`` to the ``Location`` header of an authorization response.

        Responses without a ``Location`` header, or an empty issuer, are left
        untouched.
        """
        issuer = self.get_issuer()
        location = response.headers.get('Location')

        if not (issuer and location):
            return

        # RFC9207 §2
        # In authorization responses to the client, including error responses,
        # an authorization server supporting this specification MUST indicate
        # its identity by including the iss parameter in the response.
        response.headers['Location'] = add_params_to_uri(location, {'iss': issuer})  # type: ignore

    def get_issuer(self) -> str:
        """Return the configured issuer identifier."""
        return ISSUER
def expires_in(client, grant_type: str) -> int:
    """Return the access-token lifetime in seconds for ``grant_type``.

    Grant types without an entry in ``GRANT_TYPES_EXPIRES_IN`` get
    ``BearerTokenGenerator.DEFAULT_EXPIRES_IN``. ``client`` is accepted to
    match the expires-generator call signature but is not used.
    """
    fallback = BearerTokenGenerator.DEFAULT_EXPIRES_IN
    return GRANT_TYPES_EXPIRES_IN.get(grant_type, fallback)
def create_token_generator(length: int = 42):
    """Build a token generator that returns ``generate_token(length)``.

    The returned callable accepts and ignores any positional/keyword
    arguments, so it can be plugged in wherever authlib passes
    ``(client, grant_type, user, scope)``-style arguments.
    """
    return lambda *args, **kwargs: generate_token(length)
class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
    """RFC 9068 JWT access-token generator signing with the module's private key."""

    def get_jwks(self) -> dict[str, list]:  # type: ignore
        """Return the JWKs that will be used to sign the JWT access token.

        The key set holds a single entry: the private JWK, private parts included.
        """
        signing_key = private_jwk.as_dict(is_private=True)
        return {'keys': [signing_key]}
# Module-level authorization server, persisting through the DynamoDB layer
# above and advertising OAUTH2_SCOPES_SUPPORTED.
server = AuthorizationServer(
    persistence_layer=dyn,
    scopes_supported=OAUTH2_SCOPES_SUPPORTED,
)
# Authorization code grant with two extensions: OpenID Connect (a nonce is not
# required) and PKCE via CodeChallenge(required=True).
server.register_grant(
    AuthorizationCodeGrant,
    [
        OpenIDCode(require_nonce=False),
        CodeChallenge(required=True),
    ],
)
# Default token generator: JWT access tokens for ISSUER, lifetimes from
# expires_in(), refresh tokens from create_token_generator(48).
server.register_token_generator(
    'default',
    JWTBearerTokenGenerator(
        issuer=ISSUER,
        refresh_token_generator=create_token_generator(48),
        expires_generator=expires_in,
    ),
)
# NOTE(review): TokenExchangeGrant's handlers raise NotImplementedError, so
# token-exchange requests will fail. Confirm this registration is intended.
server.register_grant(TokenExchangeGrant)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
# Adds the RFC 9207 `iss` parameter to authorization redirect responses.
server.register_extension(IssuerParameter())
# Resource protector for guarding API routes with bearer tokens.
require_oauth = ResourceProtector()