This commit is contained in:
2025-08-07 00:48:13 -03:00
parent ff25ade76e
commit 14312944fb
7 changed files with 180 additions and 57 deletions

View File

@@ -1,8 +1,10 @@
import os
import secrets
from collections import defaultdict
from urllib.parse import parse_qs
from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin, ClientMixin, TokenMixin
from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest
from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload
from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request
@@ -18,9 +20,16 @@ class OAuth2Payload(_OAuth2Payload):
def __init__(self, request: APIGatewayProxyEventV2):
self._request = request
@property
def decoded_body(self):
# TODO
body = parse_qs(self._request.decoded_body, keep_blank_values=True)
return {k: v[0] if len(v) == 1 else v for k, v in body.items()}
@property
def data(self):
return self._request.query_string_parameters
"""Combines query string parameters and the request body"""
return self._request.query_string_parameters | self.decoded_body
@property
def datalist(self) -> dict[str, list]:
@@ -51,19 +60,34 @@ class OAuth2Request(_OAuth2Request):
uri,
request.headers,
)
self._request = request
self.payload = OAuth2Payload(request)
@property
def args(self):
return self._request.query_string_parameters
@property
def form(self) -> dict[str, str]:
return self.payload.decoded_body
class OAuth2Client(ClientMixin):
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uris: list,
response_types: list,
grant_types: list,
token_endpoint_auth_method: str = 'client_secret_basic',
) -> None:
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uris = redirect_uris
self.response_types = response_types
self.grant_types = grant_types
self.token_endpoint_auth_method = token_endpoint_auth_method
def get_client_id(self):
return self.client_id
@@ -78,10 +102,52 @@ class OAuth2Client(ClientMixin):
def check_redirect_uri(self, redirect_uri):
return redirect_uri in self.redirect_uris
def check_endpoint_auth_method(self, method, endpoint):
if endpoint == 'token':
return self.token_endpoint_auth_method == method
return True
def check_grant_type(self, grant_type):
return grant_type in self.grant_types
def check_client_secret(self, client_secret):
return secrets.compare_digest(self.client_secret, client_secret)
class OAuth2Token(TokenMixin): ...
class AuthorizationCode(AuthorizationCodeMixin):
def __init__(
self,
user_id: str,
code: str,
client_id: str,
redirect_uri: str,
response_type: str,
scope: str,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
nonce: str | None = None,
) -> None:
self.user_id = user_id
self.code = code
self.client_id = client_id
self.redirect_uri = redirect_uri
self.response_type = response_type
self.scope = scope
self.code_challenge = code_challenge
self.code_challenge_method = code_challenge_method
self.nonce = nonce
def get_redirect_uri(self):
return self.redirect_uri
def get_scope(self):
return self.scope
class AuthorizationServer(_AuthorizationServer):
def __init__(self, query_client, save_token) -> None:
super().__init__(
@@ -109,3 +175,6 @@ class AuthorizationServer(_AuthorizationServer):
body=body,
headers=headers,
)
def send_signal(self, name, *args, **kwargs):
pass