Files
saladeaula.digital/id.saladeaula.digital/app/oauth2.py
2025-09-17 16:51:35 -03:00

313 lines
9.8 KiB
Python

from authlib.common.urls import add_params_to_uri
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl
from layercake.dynamodb import (
DynamoDBPersistenceLayer,
KeyPair,
SortKey,
TransactKey,
)
from layercake.funcs import omit, pick
from boto3clients import dynamodb_client
from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE
from integrations.apigateway_oauth2.authorization_server import (
AuthorizationServer,
)
from integrations.apigateway_oauth2.tokens import (
OAuth2AuthorizationCode,
OAuth2Token,
)
# Persistence layer bound to OAUTH2_TABLE; every DynamoDB read and write in
# this module (codes, nonces, tokens, users, sessions) goes through it.
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
# Structured logger for this module.
logger = Logger(__name__)
class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension attached to the authorization-code grant.

    Provides the nonce-reuse lookup, the ID-token signing settings and the
    userinfo claims derived from the authenticated user.
    """

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Report whether ``nonce`` has already been recorded.

        Looks up the ``NONCE#<nonce>`` item stored under ``OAUTH2#CODE``.

        Raises:
            ValueError: If ``request`` carries no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_key = KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}')
        stored = dyn.get_item(nonce_key)
        return bool(stored)

    def get_jwt_config(self, grant):
        """Return the key, algorithm, issuer and lifetime used for ID tokens."""
        # NOTE(review): the signing key is a hard-coded literal. If
        # JWT_ALGORITHM is an HMAC algorithm, anyone who knows this string can
        # forge ID tokens. Presumably it should come from config or a secret
        # store; confirm.
        return dict(
            key='secret-key',
            alg=JWT_ALGORITHM,
            iss=ISSUER,
            exp=3600,
        )

    def generate_user_info(self, user: dict, scope: str) -> UserInfo:
        """Build the userinfo claims for ``user``, restricted to ``scope``."""
        claims = UserInfo(
            sub=user['id'],
            name=user['name'],
            email=user['email'],
            email_verified=user.get('email_verified', False),
        )
        return claims.filter(scope)
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant backed by the OAuth2 DynamoDB table.

    Codes are stored under ``OAUTH2#CODE`` / ``CODE#<code>`` with a 10-minute
    TTL. When the authorization request carries a ``nonce``, a companion
    ``NONCE#<nonce>`` item is written in the same transaction.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Persist a freshly issued authorization code.

        Args:
            code: The generated authorization code.
            request: The authorization request. Its payload supplies the
                client id, scope, redirect URI, response type and the
                nonce/PKCE parameters; ``request.user`` is stored as
                ``user_id``.

        Raises:
            ValueError: If the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        scope: str = request.payload.scope
        data: dict = request.payload.data
        user_id: str = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        ttl_ = ttl(start_dt=now_, minutes=10)

        # The code item and (optional) nonce item share creation time and TTL
        # and are written atomically.
        with dyn.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': scope,
                    'client_id': client_id,
                    'user_id': user_id,
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            if nonce:
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> OAuth2AuthorizationCode | None:
        """Load a previously saved authorization code for ``client``.

        Returns ``None`` (which authlib turns into an ``invalid_grant``
        error) when:

        * the code is unknown;
        * the code was issued to a different client; or
        * the code's ``ttl`` has passed but DynamoDB has not purged it yet.
        """
        auth_code = dyn.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )
        if not auth_code:
            return None

        # RFC 6749 §4.1.3: the code must have been issued to the client that
        # is redeeming it. authlib delegates this check to this method.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        if self._is_expired(auth_code):
            return None

        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    @staticmethod
    def _is_expired(item: dict) -> bool:
        """Return True if ``item['ttl']`` (epoch seconds) is in the past.

        DynamoDB removes TTL-expired items asynchronously, so reads can still
        return them for a while; this enforces expiry at read time. Items
        whose ``ttl`` is missing or not numeric are treated as still valid.
        """
        import time

        try:
            return int(item['ttl']) <= time.time()
        except (KeyError, TypeError, ValueError):
            return False

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Delete the code item once it has been exchanged.

        Only the ``CODE#`` item is removed; any ``NONCE#`` item is left to
        expire through its TTL.
        """
        dyn.delete_item(
            KeyPair(
                pk='OAUTH2#CODE',
                sk=f'CODE#{authorization_code.code}',
            ),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> dict | None:
        """Return the user the code was issued to, or ``None`` if missing.

        A falsy return makes authlib reject the request with
        ``invalid_grant``.
        """
        user = dyn.get_item(
            KeyPair(
                pk=authorization_code.user_id,
                sk='0',
            ),
        )
        if not user:
            return None

        return pick(('id', 'name', 'email', 'email_verified'), user)
class RefreshTokenNotFoundError(NotFoundError):
    """Raised when no stored record matches a presented refresh token."""

    def __init__(self, *args):
        # Any positional arguments supplied by the raiser are discarded so
        # the error always carries the same fixed message.
        del args
        super().__init__('Refresh token not found')
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant backed by the OAuth2 DynamoDB table.

    A new refresh token is issued on each refresh, and the presented one is
    deleted afterwards in ``revoke_old_credential``.
    """

    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    # The authorization server MAY issue a new refresh token
    INCLUDE_NEW_REFRESH_TOKEN = True

    def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
        """Load the token record stored for ``refresh_token``.

        Raises:
            RefreshTokenNotFoundError: If no ``REFRESH_TOKEN#`` item exists.
        """
        token = dyn.collection.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            ),
            exc_cls=RefreshTokenNotFoundError,
        )
        # The record may carry credentials, so it is logged only at DEBUG
        # (matching the other token logs in this module).
        logger.debug('Refresh token retrieved', token=token)

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            # Drop any stored ``refresh_token`` attribute: it would collide
            # with the explicit keyword above (TypeError: multiple values).
            **omit(('expires_in', 'issued_at', 'refresh_token'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Authenticate the user related to this credential."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """The authorization server MAY revoke the old refresh token after
        issuing a new refresh token to the client.

        Deletes, in one transaction, the ``REFRESH_TOKEN#`` item and the
        user's ``SESSION#REFRESH_TOKEN#`` item for the old token.
        """
        logger.debug('Revoking old refresh token', refresh_token=refresh_token)
        token = getattr(refresh_token, 'refresh_token', None)
        user = refresh_token.get_user()

        with dyn.transact_writer() as transact:
            transact.delete(
                key=KeyPair(
                    pk='OAUTH2#TOKEN',
                    sk=f'REFRESH_TOKEN#{token}',
                )
            )
            transact.delete(
                key=KeyPair(
                    pk=user.get('id'),
                    sk=f'SESSION#REFRESH_TOKEN#{token}',
                )
            )
class RevocationEndpoint(rfc7009.RevocationEndpoint):
    """RFC 7009 token revocation backed by the OAuth2 DynamoDB table."""

    def query_token(  # type: ignore
        self,
        token_string: str,
        token_type_hint: str | None = None,
    ):
        """Look up an access or refresh token by its string value.

        Both token types are fetched in a single call. ``token_type_hint``
        selects which one to use. Per RFC 7009 §2.1, if nothing of the
        hinted type is found, the search falls back to the other type rather
        than failing. Without a usable hint, a refresh token wins over an
        access token.

        Returns:
            An ``OAuth2Token`` with ``token_string`` stored under the matched
            type's key, or ``None`` when neither type matches.
        """
        result = dyn.collection.get_items(
            TransactKey('OAUTH2#TOKEN')
            + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
            + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
            flatten_top=False,
        )

        if not result:
            return None

        logger.debug('Tokens retrieved', result=result)

        # A missing or non-matching hint falls back to whichever type exists.
        if token_type_hint not in result:
            token_type_hint = (
                'refresh_token' if 'refresh_token' in result else 'access_token'
            )

        token = result.get(token_type_hint)
        if not token:
            return None

        return OAuth2Token(
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            **{token_type_hint: token_string},
            **omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
        )

    def revoke_token(
        self,
        token: OAuth2Token,
        request: OAuth2Request,
    ):
        """Revoke every session and token belonging to the token's user.

        For each item returned by the user's ``SESSION`` query, deletes the
        item itself plus the record it points at: ``SESSION``/<id> for
        session entries, otherwise ``OAUTH2#TOKEN``/``<KIND>#<id>``.
        """
        user_id = token.user['id']
        # NOTE(review): only the items in this single query response are
        # processed; if the query is paginated, later pages are skipped.
        r = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))

        # NOTE(review): two deletes per session item; DynamoDB transactions
        # cap at 100 actions, so more than 50 items would fail unless
        # transact_writer chunks. Confirm.
        with dyn.transact_writer() as transact:
            # Revoke all sessions, access tokens, and refresh tokens
            for x in r['items']:
                pk, sk = x['id'], x['sk']
                # e.g. 'SESSION#REFRESH_TOKEN#<id>' -> kind='REFRESH_TOKEN'
                *_, kind, idx = sk.split('#')
                transact.delete(key=KeyPair(pk, sk))
                transact.delete(
                    key=KeyPair(
                        pk='SESSION',
                        sk=idx,
                    )
                    if kind == 'SESSION'
                    else KeyPair(
                        pk='OAUTH2#TOKEN',
                        sk=f'{kind}#{idx}',
                    )
                )
class IssuerParameter(rfc9207.IssuerParameter):
    """RFC 9207 extension: add ``iss`` to redirect-based responses."""

    def add_issuer_parameter(
        self,
        authorization_server: AuthorizationServer,
        response: Response,
    ):
        """Append ``iss=<issuer>`` to the response's ``Location`` header.

        Responses without a ``Location`` header are left untouched.
        """
        issuer = self.get_issuer()
        location = response.headers.get('Location')
        if not (issuer and location):
            return

        # RFC9207 §2
        # In authorization responses to the client, including error responses,
        # an authorization server supporting this specification MUST indicate
        # its identity by including the iss parameter in the response.
        response.headers['Location'] = add_params_to_uri(location, {'iss': issuer})  # type: ignore

    def get_issuer(self) -> str:
        """Return the configured issuer identifier."""
        return ISSUER
# Module-level authorization server persisting through ``dyn``. It exposes:
# - the authorization-code grant, with OpenID Connect and PKCE;
# - the refresh-token grant;
# - RFC 7009 revocation;
# - RFC 9207 issuer identification.
server = AuthorizationServer(persistence_layer=dyn)

_code_grant_extensions = [
    # OIDC support; a nonce is not mandatory on this server.
    OpenIDCode(require_nonce=False),
    # PKCE: every authorization-code request must send a code_challenge.
    CodeChallenge(required=True),
]
server.register_grant(AuthorizationCodeGrant, _code_grant_extensions)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
server.register_extension(IssuerParameter())