wip
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
|
||||
import authlib.oauth2 as oauth2
|
||||
import authlib.oauth2.rfc6749.requests as requests
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
||||
from aws_lambda_powertools.event_handler.api_gateway import Response
|
||||
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
|
||||
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
|
||||
APIGatewayProxyEventV2,
|
||||
)
|
||||
from layercake.dateutils import now, ttl
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
||||
|
||||
from config import ACCESS_TOKEN_EXP_SECONDS, REFRESH_TOKEN_EXP_SECONDS
|
||||
|
||||
from .client import OAuth2Client
|
||||
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
|
||||
|
||||
DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
|
||||
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
|
||||
|
||||
|
||||
class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
persistence_layer: DynamoDBPersistenceLayer,
|
||||
) -> None:
|
||||
self._persistence_layer = persistence_layer
|
||||
|
||||
super().__init__(
|
||||
scopes_supported=(
|
||||
set(OAUTH2_SCOPES_SUPPORTED.split()) if OAUTH2_SCOPES_SUPPORTED else []
|
||||
)
|
||||
)
|
||||
|
||||
self.register_token_generator(
|
||||
'default',
|
||||
BearerTokenGenerator(
|
||||
access_token_generator=create_token_generator(42),
|
||||
refresh_token_generator=create_token_generator(48),
|
||||
),
|
||||
)
|
||||
|
||||
def save_token(
|
||||
self,
|
||||
token: dict,
|
||||
request: requests.OAuth2Request,
|
||||
) -> None:
|
||||
if not request.payload:
|
||||
raise ValueError('Missing request payload')
|
||||
|
||||
if not request.user:
|
||||
raise ValueError('Missing request user')
|
||||
|
||||
now_ = now()
|
||||
client_id = request.payload.client_id
|
||||
access_token = token['access_token']
|
||||
refresh_token = token.get('refresh_token')
|
||||
token_type = token['token_type']
|
||||
scope = token['scope']
|
||||
issued_at = int(now_.timestamp())
|
||||
|
||||
with self._persistence_layer.transact_writer() as transact:
|
||||
transact.put(
|
||||
item={
|
||||
'id': 'OAUTH2#TOKEN',
|
||||
'sk': f'ACCESS_TOKEN#{access_token}',
|
||||
'client_id': client_id,
|
||||
'token_type': token_type,
|
||||
'refresh_token': refresh_token,
|
||||
'scope': scope,
|
||||
'user': request.user,
|
||||
'expires_in': ACCESS_TOKEN_EXP_SECONDS,
|
||||
'issued_at': issued_at,
|
||||
'ttl': ttl(start_dt=now_, seconds=ACCESS_TOKEN_EXP_SECONDS),
|
||||
},
|
||||
)
|
||||
|
||||
if refresh_token:
|
||||
transact.put(
|
||||
item={
|
||||
'id': 'OAUTH2#TOKEN',
|
||||
'sk': f'REFRESH_TOKEN#{refresh_token}',
|
||||
'client_id': client_id,
|
||||
'token_type': token_type,
|
||||
'access_token': access_token,
|
||||
'scope': scope,
|
||||
'user': request.user,
|
||||
'expires_in': REFRESH_TOKEN_EXP_SECONDS,
|
||||
'issued_at': issued_at,
|
||||
'ttl': ttl(start_dt=now_, seconds=REFRESH_TOKEN_EXP_SECONDS),
|
||||
},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def query_client(self, client_id: str):
|
||||
client = self._persistence_layer.collection.get_item(
|
||||
KeyPair(pk='OAUTH2', sk=f'CLIENT_ID#{client_id}'),
|
||||
exc_cls=ClientNotFoundError,
|
||||
)
|
||||
|
||||
_, client_id = client.get(DYNAMODB_SORT_KEY, '').split('#')
|
||||
|
||||
return OAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client['client_secret'],
|
||||
scope=client['scope'],
|
||||
redirect_uris=client['redirect_uris'],
|
||||
response_types=client['response_types'],
|
||||
grant_types=client['grant_types'],
|
||||
token_endpoint_auth_method=client.get(
|
||||
'token_endpoint_auth_method', 'client_secret_basic'
|
||||
),
|
||||
)
|
||||
|
||||
def create_oauth2_request(
|
||||
self, request: APIGatewayProxyEventV2
|
||||
) -> APIGatewayOAuth2Request:
|
||||
return APIGatewayOAuth2Request(request)
|
||||
|
||||
def create_json_request(
|
||||
self, request: APIGatewayProxyEventV2
|
||||
) -> APIGatewayJsonRequest:
|
||||
return APIGatewayJsonRequest(request)
|
||||
|
||||
def handle_response(self, status: int, body, headers):
|
||||
return Response(
|
||||
status_code=status,
|
||||
body=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def send_signal(self, name: str, *args, **kwargs) -> None:
|
||||
# after_authenticate_client
|
||||
# when client is authenticated
|
||||
|
||||
# after_revoke_token
|
||||
# when token is revoked
|
||||
...
|
||||
|
||||
|
||||
class ClientNotFoundError(NotFoundError):
|
||||
def __init__(self, *_):
|
||||
super().__init__('Client not found')
|
||||
|
||||
|
||||
def create_token_generator(length: int = 42):
|
||||
def token_generator(*args, **kwargs):
|
||||
return generate_token(length)
|
||||
|
||||
return token_generator
|
||||
Reference in New Issue
Block a user