Files
saladeaula.digital/id.saladeaula.digital/app/integrations/apigateway_oauth2/authorization_server.py
2025-09-26 14:45:27 -03:00

182 lines
5.6 KiB
Python

import os
from dataclasses import asdict
import authlib.oauth2 as oauth2
import authlib.oauth2.rfc6749.requests as requests
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
APIGatewayProxyEventV2,
)
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN
from .client import OAuth2Client
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
# Whitespace-separated scope names taken from the environment once, at import
# time; ``None`` when the variable is not set.
OAUTH2_SCOPES_SUPPORTED = os.environ.get('OAUTH2_SCOPES_SUPPORTED')

# Module-level Powertools logger named after this module.
logger = Logger(__name__)
class AuthorizationServer(oauth2.AuthorizationServer):
    """Authlib authorization server wired to API Gateway events and DynamoDB.

    Tokens and client registrations are read from / written to the injected
    ``DynamoDBPersistenceLayer``; incoming requests are API Gateway HTTP API
    (v2) proxy events and responses are Powertools ``Response`` objects.
    """

    def __init__(
        self,
        *,
        persistence_layer: DynamoDBPersistenceLayer,
    ) -> None:
        """Create the server.

        Args:
            persistence_layer: DynamoDB persistence layer used to store tokens
                and to look up OAuth2 clients.
        """
        self._persistence_layer = persistence_layer

        # The base class must ALWAYS be initialised: it creates the grant and
        # endpoint registries the server relies on. Only the list of supported
        # scopes is optional, so pass ``None`` when the env var is not set.
        scopes_supported = (
            set(OAUTH2_SCOPES_SUPPORTED.split())
            if OAUTH2_SCOPES_SUPPORTED
            else None
        )
        super().__init__(scopes_supported=scopes_supported)

    def save_token(
        self,
        token: dict,
        request: requests.OAuth2Request,
    ) -> None:
        """Persist a freshly generated token in a single write transaction.

        Two items are written for the access token (a per-user session entry
        and the token record itself) and, when the token carries a refresh
        token, two more for the refresh token.

        Args:
            token: Token payload; ``access_token``, ``token_type``, ``scope``
                and ``expires_in`` are required, ``refresh_token`` is optional.
            request: The OAuth2 request that produced the token.

        Raises:
            ValueError: If the request has no payload or no user.
            KeyError: If a required key is missing from ``token``.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        now_ = now()
        # Prefer the authenticated client; fall back to the id in the payload.
        client_id = (
            request.client.get_client_id()
            if request.client
            else request.payload.client_id
        )
        user_id = request.user.id
        # NOTE(review): assumes request.user is a dataclass instance.
        user = asdict(request.user)
        access_token = token['access_token']
        refresh_token = token.get('refresh_token')
        token_type = token['token_type']
        scope = token['scope']
        expires_in = int(token['expires_in'])
        issued_at = int(now_.timestamp())
        access_token_ttl = ttl(start_dt=now_, seconds=expires_in)
        refresh_token_ttl = ttl(
            start_dt=now_,
            seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
        )

        with self._persistence_layer.transact_writer() as transact:
            # Session entry under the user's partition for the access token.
            transact.put(
                item={
                    'id': user_id,
                    'sk': f'SESSION#ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'ttl': access_token_ttl,
                    'created_at': now_,
                }
            )
            # Access token record under the shared OAUTH2#TOKEN partition.
            transact.put(
                item={
                    'id': 'OAUTH2#TOKEN',
                    'sk': f'ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'token_type': token_type,
                    'refresh_token': refresh_token,
                    'scope': scope,
                    'user': user,
                    'expires_in': expires_in,
                    'issued_at': issued_at,
                    'ttl': access_token_ttl,
                },
            )

            if refresh_token:
                # The refresh-token session entry must live as long as the
                # refresh token itself, not just as long as the access token.
                transact.put(
                    item={
                        'id': user_id,
                        'sk': f'SESSION#REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        'ttl': refresh_token_ttl,
                        'created_at': now_,
                    }
                )
                transact.put(
                    item={
                        'id': 'OAUTH2#TOKEN',
                        'sk': f'REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        'token_type': token_type,
                        'access_token': access_token,
                        'scope': scope,
                        'user': user,
                        'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
                        'issued_at': issued_at,
                        'ttl': refresh_token_ttl,
                    },
                )

        return None

    def query_client(
        self,
        client_id: str,
    ):
        """Load the OAuth2 client registered under ``client_id``.

        The lookup passes ``ClientNotFoundError`` as ``exc_cls``; presumably
        the persistence layer raises it when the item does not exist.

        Args:
            client_id: Identifier of the client to load.

        Returns:
            An ``OAuth2Client`` built from the stored record;
            ``token_endpoint_auth_method`` defaults to ``'none'`` when the
            record does not define it.
        """
        client = self._persistence_layer.collection.get_item(
            KeyPair(
                pk='OAUTH2',
                sk=f'CLIENT_ID#{client_id}',
            ),
            exc_cls=ClientNotFoundError,
        )

        return OAuth2Client(
            client_id=client_id,
            client_secret=client['client_secret'],
            scope=client['scope'],
            redirect_uris=client['redirect_uris'],
            response_types=client['response_types'],
            grant_types=client['grant_types'],
            token_endpoint_auth_method=client.get(
                'token_endpoint_auth_method',
                'none',
            ),
        )

    def create_oauth2_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayOAuth2Request:
        """Wrap an API Gateway event in an ``APIGatewayOAuth2Request``."""
        return APIGatewayOAuth2Request(request)

    def create_json_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayJsonRequest:
        """Wrap an API Gateway event in an ``APIGatewayJsonRequest``."""
        return APIGatewayJsonRequest(request)

    def handle_response(
        self,
        status: int,
        body,
        headers,
    ):
        """Build a Powertools ``Response`` from status, body and headers."""
        return Response(
            status_code=status,
            body=body,
            headers=headers,
        )

    def send_signal(
        self,
        name: str,
        *args: object,
        **kwargs: object,
    ) -> None:
        """Signal hook; intentionally a no-op in this implementation.

        Signal names noted by the original author:
            after_authenticate_client: when client is authenticated
            after_revoke_token: when token is revoked
        """
        ...
class ClientNotFoundError(NotFoundError):
    """Not-found error carrying the fixed message ``'Client not found'``.

    Positional arguments supplied by the caller are accepted and discarded,
    so the message is the same however the exception is constructed.
    """

    def __init__(self, *_ignored: object) -> None:
        message = 'Client not found'
        super().__init__(message)