184 lines
5.7 KiB
Python
184 lines
5.7 KiB
Python
import os
|
|
from dataclasses import asdict
|
|
|
|
import authlib.oauth2 as oauth2
|
|
import authlib.oauth2.rfc6749.requests as requests
|
|
from aws_lambda_powertools import Logger
|
|
from aws_lambda_powertools.event_handler.api_gateway import Response
|
|
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
|
|
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
|
|
APIGatewayProxyEventV2,
|
|
)
|
|
from layercake.dateutils import now, ttl
|
|
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
|
|
|
from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN
|
|
|
|
from .client import OAuth2Client
|
|
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
|
|
|
|
# Settings read from the process environment at import time; each is None
# when the variable is unset. DYNAMODB_SORT_KEY is not referenced in this
# module's visible code.
DYNAMODB_SORT_KEY = os.environ.get('DYNAMODB_SORT_KEY')
# Whitespace-separated scope names advertised by the authorization server;
# an unset or empty value means no explicit supported-scope set is passed.
OAUTH2_SCOPES_SUPPORTED = os.environ.get('OAUTH2_SCOPES_SUPPORTED')

logger = Logger(__name__)
|
|
|
|
|
|
class AuthorizationServer(oauth2.AuthorizationServer):
    """Authlib authorization server that stores clients and tokens in DynamoDB.

    Client records are read through, and issued tokens are written through,
    the injected ``DynamoDBPersistenceLayer``. Incoming requests are wrapped
    from ``APIGatewayProxyEventV2`` events, and responses are returned as
    powertools ``Response`` objects.
    """

    def __init__(
        self,
        *,
        persistence_layer: DynamoDBPersistenceLayer,
    ) -> None:
        """Create the server.

        Args:
            persistence_layer: DynamoDB access layer used to look up OAuth2
                clients and to persist issued tokens.
        """
        self._persistence_layer = persistence_layer

        # Always run the base initializer. authlib's AuthorizationServer
        # sets up its internal state there (registered grants, endpoints,
        # token generators). It must not be skipped just because no
        # supported-scope list is configured; in that case pass None.
        scopes_supported = (
            set(OAUTH2_SCOPES_SUPPORTED.split()) if OAUTH2_SCOPES_SUPPORTED else None
        )
        super().__init__(scopes_supported=scopes_supported)

    def save_token(
        self,
        token: dict,
        request: requests.OAuth2Request,
    ) -> None:
        """Persist a freshly issued token in a single DynamoDB transaction.

        Always writes two items for the access token, both expiring with it:
        a per-user ``SESSION#ACCESS_TOKEN#...`` item and an ``OAUTH2#TOKEN`` /
        ``ACCESS_TOKEN#...`` item.

        If a refresh token was issued, also writes two items for it, both
        expiring after ``OAUTH2_REFRESH_TOKEN_EXPIRES_IN`` seconds:
        ``SESSION#REFRESH_TOKEN#...`` and ``OAUTH2#TOKEN`` /
        ``REFRESH_TOKEN#...``.

        Args:
            token: Token mapping from the token generator. Must contain
                ``access_token``, ``token_type`` and ``expires_in``.
                ``refresh_token`` and ``scope`` may be absent.
            request: The OAuth2 request that produced the token. It must
                carry a payload and a user.

        Raises:
            ValueError: If the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        now_ = now()
        # Prefer the authenticated client; otherwise fall back to the
        # client_id carried in the request payload.
        client_id = (
            request.client.get_client_id()
            if request.client
            else request.payload.client_id
        )
        user_id = request.user.id
        # asdict() requires request.user to be a dataclass instance.
        user = asdict(request.user)
        access_token = token['access_token']
        refresh_token = token.get('refresh_token')
        token_type = token['token_type']
        # Token generators may omit 'scope' when no scope was granted;
        # store None in that case instead of failing with KeyError.
        scope = token.get('scope')
        expires_in = int(token['expires_in'])
        issued_at = int(now_.timestamp())
        access_token_ttl = ttl(start_dt=now_, seconds=expires_in)
        refresh_token_ttl = ttl(start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN)

        with self._persistence_layer.transact_writer() as transact:
            transact.put(
                item={
                    'id': user_id,
                    'sk': f'SESSION#ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'ttl': access_token_ttl,
                    'created_at': now_,
                }
            )
            transact.put(
                item={
                    'id': 'OAUTH2#TOKEN',
                    'sk': f'ACCESS_TOKEN#{access_token}',
                    'client_id': client_id,
                    'token_type': token_type,
                    'refresh_token': refresh_token,
                    'scope': scope,
                    'user': user,
                    'expires_in': expires_in,
                    'issued_at': issued_at,
                    'ttl': access_token_ttl,
                },
            )

            if refresh_token:
                transact.put(
                    item={
                        'id': user_id,
                        'sk': f'SESSION#REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        # The refresh-token session must expire together
                        # with the refresh-token record below, not with
                        # the access token.
                        'ttl': refresh_token_ttl,
                        'created_at': now_,
                    }
                )
                transact.put(
                    item={
                        'id': 'OAUTH2#TOKEN',
                        'sk': f'REFRESH_TOKEN#{refresh_token}',
                        'client_id': client_id,
                        'token_type': token_type,
                        'access_token': access_token,
                        'scope': scope,
                        'user': user,
                        'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
                        'issued_at': issued_at,
                        'ttl': refresh_token_ttl,
                    },
                )

    def query_client(
        self,
        client_id: str,
    ) -> OAuth2Client:
        """Load the OAuth2 client registered under ``client_id``.

        Reads the item keyed ``('OAUTH2', 'CLIENT_ID#<client_id>')``.
        ``token_endpoint_auth_method`` defaults to ``'none'`` when the item
        does not define it.

        Raises:
            ClientNotFoundError: If no such client item exists. It is
                passed to ``get_item`` as ``exc_cls``.
        """
        client = self._persistence_layer.collection.get_item(
            KeyPair(
                pk='OAUTH2',
                sk=f'CLIENT_ID#{client_id}',
            ),
            exc_cls=ClientNotFoundError,
        )

        return OAuth2Client(
            client_id=client_id,
            client_secret=client['client_secret'],
            scope=client['scope'],
            redirect_uris=client['redirect_uris'],
            response_types=client['response_types'],
            grant_types=client['grant_types'],
            token_endpoint_auth_method=client.get('token_endpoint_auth_method', 'none'),
        )

    def create_oauth2_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayOAuth2Request:
        """Wrap an API Gateway v2 event as an OAuth2 request."""
        return APIGatewayOAuth2Request(request)

    def create_json_request(
        self,
        request: APIGatewayProxyEventV2,
    ) -> APIGatewayJsonRequest:
        """Wrap an API Gateway v2 event as a JSON request."""
        return APIGatewayJsonRequest(request)

    def handle_response(
        self,
        status: int,
        body,
        headers,
    ) -> Response:
        """Build a powertools ``Response`` from authlib's status, body and headers."""
        return Response(
            status_code=status,
            body=body,
            headers=headers,
        )

    def send_signal(
        self,
        name: str,
        *args: object,
        **kwargs: object,
    ) -> None:
        """Ignore authlib signals; this server does not subscribe to any.

        Signal names known to be emitted include ``after_authenticate_client``
        (a client was authenticated) and ``after_revoke_token`` (a token was
        revoked).
        """
        ...
|
|
|
|
|
|
class ClientNotFoundError(NotFoundError):
    """Not-found error raised when an OAuth2 client record is missing.

    Any positional arguments supplied by the raiser are discarded; the
    error always carries the fixed message 'Client not found'.
    """

    def __init__(self, *_ignored_args):
        message = 'Client not found'
        super().__init__(message)
|