181 lines
5.3 KiB
Python
181 lines
5.3 KiB
Python
import json
import os
import secrets
from collections import defaultdict
from urllib.parse import parse_qs

from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin, ClientMixin, TokenMixin
from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest
from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload
from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import (
    APIGatewayProxyEventV2,
)
|
|
|
|
# Raw value of the OAUTH2_SCOPES_SUPPORTED environment variable: a str, or
# None when the variable is not set. It is not parsed here.
OAUTH2_SCOPES_SUPPORTED = os.environ.get('OAUTH2_SCOPES_SUPPORTED')
|
|
|
|
|
|
class OAuth2Payload(_OAuth2Payload):
    """Expose an API Gateway HTTP API (v2) event as an authlib OAuth2 payload.

    Parameters are read from the event's query string and from its body,
    which is parsed as ``application/x-www-form-urlencoded`` data.
    """

    def __init__(self, request: APIGatewayProxyEventV2):
        self._request = request

    def _query_params(self) -> dict[str, str]:
        """Return the query string parameters, or ``{}`` when there are none."""
        # The event may carry no query string; normalise so callers can merge
        # and iterate without a None check.
        return self._request.query_string_parameters or {}

    def _parsed_body(self) -> dict[str, list[str]]:
        """Parse the decoded body as form data, keeping every value per key."""
        # TODO: the Content-Type header is not checked; every body is parsed
        # as application/x-www-form-urlencoded.
        # A missing body is treated as empty rather than handed to parse_qs.
        return parse_qs(self._request.decoded_body or '', keep_blank_values=True)

    @property
    def decoded_body(self) -> dict[str, str | list[str]]:
        """Return the form body as a dict.

        A key that appears once maps to its single value; a repeated key maps
        to the list of all its values. Blank values are kept as ``''``.
        """
        return {
            key: values[0] if len(values) == 1 else values
            for key, values in self._parsed_body().items()
        }

    @property
    def data(self):
        """Combines query string parameters and the request body.

        On a key present in both, the body value wins.
        """
        return self._query_params() | self.decoded_body

    @property
    def datalist(self) -> dict[str, list]:
        """Return every value received for each parameter, as a flat list.

        A key present in both the query string and the body yields one entry
        per occurrence. Presumably authlib uses this to reject repeated
        parameters — confirm against the pinned authlib version.
        """
        values = defaultdict(list)
        for key, value in self._query_params().items():
            # NOTE(review): per AWS docs, HTTP API (v2) events comma-join
            # repeated query keys into one string, so query-string duplicates
            # cannot be detected here — confirm.
            values[key].append(value)
        for key, body_values in self._parsed_body().items():
            # parse_qs already returns a list per key; extend so the result
            # stays flat (never a list nested inside the list).
            values[key].extend(body_values)
        return values
|
|
|
|
|
|
class JsonRequest(_JsonRequest):
    """authlib JSON request adapter for an API Gateway HTTP API (v2) event.

    The request URI is built from the event's domain name only (scheme +
    host); the path and query string are not included.
    """

    def __init__(self, request: APIGatewayProxyEventV2):
        context = request.request_context
        base_uri = f'https://{context.domain_name}'

        # NOTE(review): headers are passed as the third positional argument;
        # confirm the pinned authlib JsonRequest.__init__ takes ``headers``
        # in that position. Also, no ``payload`` is attached to this request
        # (OAuth2Request in this module sets one) — TODO confirm whether the
        # endpoints using create_json_request need it.
        super().__init__(context.http.method, base_uri, request.headers)
|
|
|
|
|
|
class OAuth2Request(_OAuth2Request):
    """authlib OAuth2 request adapter for an API Gateway HTTP API (v2) event.

    Query parameters are read straight from the event. Form data and the
    combined parameter views are provided by ``OAuth2Payload`` via
    ``self.payload``.
    """

    def __init__(self, request: APIGatewayProxyEventV2):
        # Only scheme + host are used; the request path and query string are
        # not part of ``uri``.
        uri = f'https://{request.request_context.domain_name}'

        # ``headers`` is passed by keyword. NOTE(review): authlib's
        # OAuth2Request.__init__ is (method, uri, body=None, headers=None) in
        # the releases reviewed, so a positional third argument would be stored
        # as the body and leave ``self.headers`` empty, likely defeating
        # header-based client auth such as the 'client_secret_basic' default.
        # The keyword form is correct whether or not ``body`` is present.
        super().__init__(
            request.request_context.http.method,
            uri,
            headers=request.headers,
        )
        self._request = request
        self.payload = OAuth2Payload(request)

    @property
    def args(self):
        """Query string parameters, or ``{}`` when the event has none."""
        return self._request.query_string_parameters or {}

    @property
    def form(self) -> dict[str, str | list[str]]:
        """Form-encoded body parameters as parsed by the payload.

        Repeated keys map to a list of values.
        """
        return self.payload.decoded_body
|
|
|
|
|
|
class OAuth2Client(ClientMixin):
    """Registered OAuth2 client exposed through authlib's ClientMixin hooks.

    NOTE(review): ``get_allowed_scope`` is not overridden, so the ClientMixin
    default applies — confirm the token generator in use does not call it.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        redirect_uris: list,
        response_types: list,
        grant_types: list,
        token_endpoint_auth_method: str = 'client_secret_basic',
    ) -> None:
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uris = redirect_uris
        self.response_types = response_types
        self.grant_types = grant_types
        self.token_endpoint_auth_method = token_endpoint_auth_method

    def get_client_id(self):
        """Return the client identifier."""
        return self.client_id

    def get_default_redirect_uri(self) -> str:  # type: ignore
        """Return the first registered redirect URI, or None if none is registered."""
        if self.redirect_uris:
            return self.redirect_uris[0]
        return None

    def check_response_type(self, response_type):
        """Return True if *response_type* is registered for this client."""
        return response_type in self.response_types

    def check_redirect_uri(self, redirect_uri):
        """Return True if *redirect_uri* equals one of the registered URIs.

        Matching is exact equality against each list entry; no prefix or
        wildcard matching is done.
        """
        return redirect_uri in self.redirect_uris

    def check_endpoint_auth_method(self, method, endpoint):
        """Return True if *method* may be used to authenticate at *endpoint*.

        Only the ``'token'`` endpoint is restricted, to the single configured
        method; any other endpoint accepts every method.
        """
        if endpoint == 'token':
            return self.token_endpoint_auth_method == method
        return True

    def check_grant_type(self, grant_type):
        """Return True if *grant_type* is registered for this client."""
        return grant_type in self.grant_types

    def check_client_secret(self, client_secret):
        """Return True if *client_secret* matches this client's secret.

        The comparison is constant-time. Both sides are compared as UTF-8
        bytes, because ``compare_digest`` raises TypeError for non-ASCII str
        input, and the presented secret is caller-supplied. A missing secret
        on either side never matches (instead of raising).
        """
        if self.client_secret is None or client_secret is None:
            return False
        return secrets.compare_digest(
            self.client_secret.encode('utf-8'),
            client_secret.encode('utf-8'),
        )
|
|
|
|
|
|
class OAuth2Token(TokenMixin):
    """Token model that inherits every TokenMixin hook unchanged.

    NOTE(review): none of the TokenMixin methods are overridden here —
    confirm nothing calls them on instances of this class.
    """
|
|
|
|
|
|
class AuthorizationCode(AuthorizationCodeMixin):
    """Authorization code record exposed through AuthorizationCodeMixin.

    Stores the code, who it was issued to, the request it answers, and the
    optional PKCE (RFC 7636) challenge fields and nonce as plain attributes.
    """

    def __init__(
        self,
        user_id: str,
        code: str,
        client_id: str,
        redirect_uri: str,
        response_type: str,
        scope: str,
        code_challenge: str | None = None,
        code_challenge_method: str | None = None,
        nonce: str | None = None,
    ) -> None:
        # Ownership: the user, the code value and the issuing client.
        self.user_id = user_id
        self.code = code
        self.client_id = client_id

        # The authorization request this code was issued for.
        self.redirect_uri = redirect_uri
        self.response_type = response_type
        self.scope = scope

        # Optional extras; None when the request did not supply them.
        self.code_challenge = code_challenge
        self.code_challenge_method = code_challenge_method
        self.nonce = nonce

    def get_redirect_uri(self):
        """Return the redirect URI stored with this code."""
        return self.redirect_uri

    def get_scope(self):
        """Return the scope string stored with this code."""
        return self.scope
|
|
|
|
|
|
class AuthorizationServer(_AuthorizationServer):
    """authlib authorization server that emits Powertools ``Response`` objects.

    Client lookup and token persistence are delegated to callables supplied
    at construction time. Requests are adapted from API Gateway HTTP API (v2)
    events.
    """

    def __init__(self, query_client, save_token) -> None:
        """Create the server.

        :param query_client: callable taking a client_id; used by ``query_client``.
        :param save_token: callable taking ``(token, request)``; used by
            ``save_token``.
        """
        # OAUTH2_SCOPES_SUPPORTED is a raw env-var string. authlib's
        # validate_requested_scope does ``set(self.scopes_supported)``; on a str
        # that yields single characters and rejects real scopes, so split it
        # into scope tokens first. Assumes a whitespace-separated value (the
        # OAuth2 "scope" syntax) — TODO confirm against deployment config.
        # Unset or blank -> None, which presumably leaves scopes unrestricted.
        scopes_supported = (OAUTH2_SCOPES_SUPPORTED or '').split() or None
        super().__init__(scopes_supported=scopes_supported)

        self._query_client = query_client
        self._save_token = save_token

    def save_token(self, token, request):
        """Persist *token* via the ``save_token`` callable given at construction."""
        return self._save_token(token, request)

    def query_client(self, client_id: str):
        """Look up a client via the ``query_client`` callable given at construction."""
        return self._query_client(client_id)

    def create_oauth2_request(self, request: APIGatewayProxyEventV2) -> OAuth2Request:
        """Wrap an API Gateway v2 event as an authlib OAuth2 request."""
        return OAuth2Request(request)

    def create_json_request(self, request: APIGatewayProxyEventV2) -> JsonRequest:
        """Wrap an API Gateway v2 event as an authlib JSON request."""
        return JsonRequest(request)

    def handle_response(self, status, body, headers):
        """Convert authlib's ``(status, body, headers)`` into a Powertools Response.

        Dict bodies are JSON-encoded, as authlib's own framework integrations
        do. Headers are converted to a dict, because authlib supplies them as
        a list of ``(name, value)`` pairs while the Powertools ``Response``
        signature declares a mapping (confirm for the pinned Powertools
        version).
        """
        if isinstance(body, dict):
            body = json.dumps(body)
        return Response(
            status_code=status,
            body=body,
            headers=dict(headers) if headers else None,
        )

    def send_signal(self, name, *args, **kwargs):
        """Deliberately a no-op: signals sent by the server are discarded."""
        pass
|