wip
This commit is contained in:
@@ -1,180 +0,0 @@
|
||||
import os
|
||||
import secrets
|
||||
from collections import defaultdict
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
|
||||
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin, ClientMixin, TokenMixin
|
||||
from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest
|
||||
from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload
|
||||
from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request
|
||||
from aws_lambda_powertools.event_handler.api_gateway import Response
|
||||
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
|
||||
APIGatewayProxyEventV2,
|
||||
)
|
||||
|
||||
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
|
||||
|
||||
|
||||
class OAuth2Payload(_OAuth2Payload):
|
||||
def __init__(self, request: APIGatewayProxyEventV2):
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def decoded_body(self):
|
||||
# TODO
|
||||
body = parse_qs(self._request.decoded_body, keep_blank_values=True)
|
||||
return {k: v[0] if len(v) == 1 else v for k, v in body.items()}
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""Combines query string parameters and the request body"""
|
||||
return self._request.query_string_parameters | self.decoded_body
|
||||
|
||||
@property
|
||||
def datalist(self) -> dict[str, list]:
|
||||
values = defaultdict(list)
|
||||
|
||||
for k, v in self.data.items():
|
||||
values[k].extend([v])
|
||||
return values
|
||||
|
||||
|
||||
class JsonRequest(_JsonRequest):
|
||||
def __init__(self, request: APIGatewayProxyEventV2):
|
||||
uri = f'https://{request.request_context.domain_name}'
|
||||
|
||||
super().__init__(
|
||||
request.request_context.http.method,
|
||||
uri,
|
||||
request.headers,
|
||||
)
|
||||
|
||||
|
||||
class OAuth2Request(_OAuth2Request):
|
||||
def __init__(self, request: APIGatewayProxyEventV2):
|
||||
uri = f'https://{request.request_context.domain_name}'
|
||||
|
||||
super().__init__(
|
||||
request.request_context.http.method,
|
||||
uri,
|
||||
request.headers,
|
||||
)
|
||||
self._request = request
|
||||
self.payload = OAuth2Payload(request)
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._request.query_string_parameters
|
||||
|
||||
@property
|
||||
def form(self) -> dict[str, str]:
|
||||
return self.payload.decoded_body
|
||||
|
||||
|
||||
class OAuth2Client(ClientMixin):
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
redirect_uris: list,
|
||||
response_types: list,
|
||||
grant_types: list,
|
||||
token_endpoint_auth_method: str = 'client_secret_basic',
|
||||
) -> None:
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uris = redirect_uris
|
||||
self.response_types = response_types
|
||||
self.grant_types = grant_types
|
||||
self.token_endpoint_auth_method = token_endpoint_auth_method
|
||||
|
||||
def get_client_id(self):
|
||||
return self.client_id
|
||||
|
||||
def get_default_redirect_uri(self) -> str: # type: ignore
|
||||
if self.redirect_uris:
|
||||
return self.redirect_uris[0]
|
||||
|
||||
def check_response_type(self, response_type):
|
||||
return response_type in self.response_types
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def check_endpoint_auth_method(self, method, endpoint):
|
||||
if endpoint == 'token':
|
||||
return self.token_endpoint_auth_method == method
|
||||
|
||||
return True
|
||||
|
||||
def check_grant_type(self, grant_type):
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def check_client_secret(self, client_secret):
|
||||
return secrets.compare_digest(self.client_secret, client_secret)
|
||||
|
||||
|
||||
class OAuth2Token(TokenMixin): ...
|
||||
|
||||
|
||||
class AuthorizationCode(AuthorizationCodeMixin):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
response_type: str,
|
||||
scope: str,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
nonce: str | None = None,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.code = code
|
||||
self.client_id = client_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.response_type = response_type
|
||||
self.scope = scope
|
||||
self.code_challenge = code_challenge
|
||||
self.code_challenge_method = code_challenge_method
|
||||
self.nonce = nonce
|
||||
|
||||
def get_redirect_uri(self):
|
||||
return self.redirect_uri
|
||||
|
||||
def get_scope(self):
|
||||
return self.scope
|
||||
|
||||
|
||||
class AuthorizationServer(_AuthorizationServer):
|
||||
def __init__(self, query_client, save_token) -> None:
|
||||
super().__init__(
|
||||
scopes_supported=OAUTH2_SCOPES_SUPPORTED,
|
||||
)
|
||||
|
||||
self._query_client = query_client
|
||||
self._save_token = save_token
|
||||
|
||||
def save_token(self, token, request):
|
||||
return self._save_token(token, request)
|
||||
|
||||
def query_client(self, client_id: str):
|
||||
return self._query_client(client_id)
|
||||
|
||||
def create_oauth2_request(self, request: APIGatewayProxyEventV2) -> OAuth2Request:
|
||||
return OAuth2Request(request)
|
||||
|
||||
def create_json_request(self, request: APIGatewayProxyEventV2) -> JsonRequest:
|
||||
return JsonRequest(request)
|
||||
|
||||
def handle_response(self, status, body, headers):
|
||||
return Response(
|
||||
status_code=status,
|
||||
body=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def send_signal(self, name, *args, **kwargs):
|
||||
pass
|
||||
@@ -0,0 +1,154 @@
|
||||
import os
|
||||
|
||||
import authlib.oauth2 as oauth2
|
||||
import authlib.oauth2.rfc6749.requests as requests
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
||||
from aws_lambda_powertools.event_handler.api_gateway import Response
|
||||
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
|
||||
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
|
||||
APIGatewayProxyEventV2,
|
||||
)
|
||||
from layercake.dateutils import now, ttl
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
||||
|
||||
from config import ACCESS_TOKEN_EXP_SECONDS, REFRESH_TOKEN_EXP_SECONDS
|
||||
|
||||
from .client import OAuth2Client
|
||||
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
|
||||
|
||||
DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
|
||||
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
|
||||
|
||||
|
||||
class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
persistence_layer: DynamoDBPersistenceLayer,
|
||||
) -> None:
|
||||
self._persistence_layer = persistence_layer
|
||||
|
||||
super().__init__(
|
||||
scopes_supported=(
|
||||
set(OAUTH2_SCOPES_SUPPORTED.split()) if OAUTH2_SCOPES_SUPPORTED else []
|
||||
)
|
||||
)
|
||||
|
||||
self.register_token_generator(
|
||||
'default',
|
||||
BearerTokenGenerator(
|
||||
access_token_generator=create_token_generator(42),
|
||||
refresh_token_generator=create_token_generator(48),
|
||||
),
|
||||
)
|
||||
|
||||
def save_token(
|
||||
self,
|
||||
token: dict,
|
||||
request: requests.OAuth2Request,
|
||||
) -> None:
|
||||
if not request.payload:
|
||||
raise ValueError('Missing request payload')
|
||||
|
||||
if not request.user:
|
||||
raise ValueError('Missing request user')
|
||||
|
||||
now_ = now()
|
||||
client_id = request.payload.client_id
|
||||
access_token = token['access_token']
|
||||
refresh_token = token.get('refresh_token')
|
||||
token_type = token['token_type']
|
||||
scope = token['scope']
|
||||
issued_at = int(now_.timestamp())
|
||||
|
||||
with self._persistence_layer.transact_writer() as transact:
|
||||
transact.put(
|
||||
item={
|
||||
'id': 'OAUTH2#TOKEN',
|
||||
'sk': f'ACCESS_TOKEN#{access_token}',
|
||||
'client_id': client_id,
|
||||
'token_type': token_type,
|
||||
'refresh_token': refresh_token,
|
||||
'scope': scope,
|
||||
'user': request.user,
|
||||
'expires_in': ACCESS_TOKEN_EXP_SECONDS,
|
||||
'issued_at': issued_at,
|
||||
'ttl': ttl(start_dt=now_, seconds=ACCESS_TOKEN_EXP_SECONDS),
|
||||
},
|
||||
)
|
||||
|
||||
if refresh_token:
|
||||
transact.put(
|
||||
item={
|
||||
'id': 'OAUTH2#TOKEN',
|
||||
'sk': f'REFRESH_TOKEN#{refresh_token}',
|
||||
'client_id': client_id,
|
||||
'token_type': token_type,
|
||||
'access_token': access_token,
|
||||
'scope': scope,
|
||||
'user': request.user,
|
||||
'expires_in': REFRESH_TOKEN_EXP_SECONDS,
|
||||
'issued_at': issued_at,
|
||||
'ttl': ttl(start_dt=now_, seconds=REFRESH_TOKEN_EXP_SECONDS),
|
||||
},
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def query_client(self, client_id: str):
|
||||
client = self._persistence_layer.collection.get_item(
|
||||
KeyPair(pk='OAUTH2', sk=f'CLIENT_ID#{client_id}'),
|
||||
exc_cls=ClientNotFoundError,
|
||||
)
|
||||
|
||||
_, client_id = client.get(DYNAMODB_SORT_KEY, '').split('#')
|
||||
|
||||
return OAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client['client_secret'],
|
||||
scope=client['scope'],
|
||||
redirect_uris=client['redirect_uris'],
|
||||
response_types=client['response_types'],
|
||||
grant_types=client['grant_types'],
|
||||
token_endpoint_auth_method=client.get(
|
||||
'token_endpoint_auth_method', 'client_secret_basic'
|
||||
),
|
||||
)
|
||||
|
||||
def create_oauth2_request(
|
||||
self, request: APIGatewayProxyEventV2
|
||||
) -> APIGatewayOAuth2Request:
|
||||
return APIGatewayOAuth2Request(request)
|
||||
|
||||
def create_json_request(
|
||||
self, request: APIGatewayProxyEventV2
|
||||
) -> APIGatewayJsonRequest:
|
||||
return APIGatewayJsonRequest(request)
|
||||
|
||||
def handle_response(self, status: int, body, headers):
|
||||
return Response(
|
||||
status_code=status,
|
||||
body=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def send_signal(self, name: str, *args, **kwargs) -> None:
|
||||
# after_authenticate_client
|
||||
# when client is authenticated
|
||||
|
||||
# after_revoke_token
|
||||
# when token is revoked
|
||||
...
|
||||
|
||||
|
||||
class ClientNotFoundError(NotFoundError):
|
||||
def __init__(self, *_):
|
||||
super().__init__('Client not found')
|
||||
|
||||
|
||||
def create_token_generator(length: int = 42):
|
||||
def token_generator(*args, **kwargs):
|
||||
return generate_token(length)
|
||||
|
||||
return token_generator
|
||||
@@ -0,0 +1,60 @@
|
||||
import secrets
|
||||
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
ClientMixin,
|
||||
list_to_scope,
|
||||
scope_to_list,
|
||||
)
|
||||
|
||||
|
||||
class OAuth2Client(ClientMixin):
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
scope: str,
|
||||
redirect_uris: list,
|
||||
response_types: list,
|
||||
grant_types: list,
|
||||
token_endpoint_auth_method: str = 'client_secret_basic',
|
||||
) -> None:
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.scope = scope
|
||||
self.redirect_uris = redirect_uris
|
||||
self.response_types = response_types
|
||||
self.grant_types = grant_types
|
||||
self.token_endpoint_auth_method = token_endpoint_auth_method
|
||||
|
||||
def get_client_id(self):
|
||||
return self.client_id
|
||||
|
||||
def get_allowed_scope(self, scope) -> str:
|
||||
if not scope:
|
||||
return ''
|
||||
|
||||
allowed = set(self.scope.split())
|
||||
scopes = scope_to_list(scope)
|
||||
return list_to_scope([s for s in scopes if s in allowed])
|
||||
|
||||
def get_default_redirect_uri(self) -> str: # type: ignore
|
||||
if self.redirect_uris:
|
||||
return self.redirect_uris[0]
|
||||
|
||||
def check_response_type(self, response_type):
|
||||
return response_type in self.response_types
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def check_endpoint_auth_method(self, method, endpoint):
|
||||
if endpoint == 'token':
|
||||
return self.token_endpoint_auth_method == method
|
||||
|
||||
return True
|
||||
|
||||
def check_grant_type(self, grant_type):
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def check_client_secret(self, client_secret):
|
||||
return secrets.compare_digest(self.client_secret, client_secret)
|
||||
@@ -0,0 +1,65 @@
|
||||
from collections import defaultdict
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import authlib.oauth2.rfc6749.requests as requests
|
||||
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
|
||||
APIGatewayProxyEventV2,
|
||||
)
|
||||
|
||||
|
||||
class APIGatewayOAuth2Payload(requests.OAuth2Payload):
|
||||
def __init__(self, request: APIGatewayProxyEventV2):
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def decoded_body(self):
|
||||
# TODO
|
||||
body = parse_qs(self._request.decoded_body, keep_blank_values=True)
|
||||
return {k: v[0] if len(v) == 1 else v for k, v in body.items()}
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
"""Combines query string parameters and the request body"""
|
||||
return self._request.query_string_parameters | self.decoded_body
|
||||
|
||||
@property
|
||||
def datalist(self) -> dict[str, list]:
|
||||
values = defaultdict(list)
|
||||
|
||||
for k, v in self.data.items():
|
||||
values[k].extend([v])
|
||||
return values
|
||||
|
||||
|
||||
class APIGatewayJsonRequest(requests.JsonRequest):
|
||||
def __init__(self, request: APIGatewayProxyEventV2):
|
||||
uri = f'https://{request.request_context.domain_name}'
|
||||
|
||||
super().__init__(
|
||||
request.request_context.http.method,
|
||||
uri,
|
||||
request.headers,
|
||||
)
|
||||
|
||||
|
||||
class APIGatewayOAuth2Request(requests.OAuth2Request):
|
||||
def __init__(self, request: APIGatewayProxyEventV2):
|
||||
uri = f'https://{request.request_context.domain_name}'
|
||||
|
||||
super().__init__(
|
||||
request.request_context.http.method,
|
||||
uri,
|
||||
request.headers,
|
||||
)
|
||||
self._request = request
|
||||
self.payload = APIGatewayOAuth2Payload(request)
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
# @TODO
|
||||
return self._request.query_string_parameters
|
||||
|
||||
@property
|
||||
def form(self) -> dict[str, str]:
|
||||
# @TODO
|
||||
return self.payload.decoded_body
|
||||
@@ -0,0 +1,97 @@
|
||||
import time
|
||||
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
AuthorizationCodeMixin,
|
||||
ClientMixin,
|
||||
TokenMixin,
|
||||
)
|
||||
from layercake.dateutils import fromisoformat
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(AuthorizationCodeMixin):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
code: str,
|
||||
client_id: str,
|
||||
redirect_uri: str,
|
||||
response_type: str,
|
||||
scope: str,
|
||||
code_challenge: str | None = None,
|
||||
code_challenge_method: str | None = None,
|
||||
nonce: str | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.code = code
|
||||
self.client_id = client_id
|
||||
self.redirect_uri = redirect_uri
|
||||
self.response_type = response_type
|
||||
self.scope = scope
|
||||
self.code_challenge = code_challenge
|
||||
self.code_challenge_method = code_challenge_method
|
||||
self.nonce = nonce
|
||||
|
||||
auth_time = fromisoformat(kwargs.get('created_at', '')) or now()
|
||||
self.auth_time = int(auth_time.timestamp())
|
||||
|
||||
def get_redirect_uri(self):
|
||||
return self.redirect_uri
|
||||
|
||||
def get_scope(self):
|
||||
return self.scope
|
||||
|
||||
def get_nonce(self):
|
||||
return self.nonce
|
||||
|
||||
def get_auth_time(self):
|
||||
return self.auth_time
|
||||
|
||||
def get_acr(self):
|
||||
return '0'
|
||||
|
||||
def get_amr(self):
|
||||
return []
|
||||
|
||||
|
||||
class OAuth2Token(TokenMixin):
|
||||
def __init__(
|
||||
self,
|
||||
user: dict,
|
||||
client_id: str,
|
||||
scope: str,
|
||||
expires_in: int,
|
||||
issued_at: int,
|
||||
access_token: str | None = None,
|
||||
refresh_token: str | None = None,
|
||||
**_,
|
||||
) -> None:
|
||||
self.user = user
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
self.expires_in = expires_in
|
||||
self.issued_at = issued_at
|
||||
self.access_token = access_token
|
||||
self.refresh_token = refresh_token
|
||||
|
||||
def get_user(self) -> dict:
|
||||
return self.user
|
||||
|
||||
def check_client(self, client: ClientMixin):
|
||||
return self.client_id == client.get_client_id()
|
||||
|
||||
def get_scope(self) -> str:
|
||||
return self.scope
|
||||
|
||||
def get_expires_in(self) -> int:
|
||||
return self.expires_in
|
||||
|
||||
def is_revoked(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
if not self.expires_in:
|
||||
return False
|
||||
|
||||
expires_at = self.issued_at + self.expires_in
|
||||
return expires_at < time.time()
|
||||
Reference in New Issue
Block a user