This commit is contained in:
2025-08-07 00:48:13 -03:00
parent ff25ade76e
commit 14312944fb
7 changed files with 180 additions and 57 deletions

View File

@@ -1,8 +1,10 @@
import os import os
import secrets
from collections import defaultdict from collections import defaultdict
from urllib.parse import parse_qs
from authlib.oauth2 import AuthorizationServer as _AuthorizationServer from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin from authlib.oauth2.rfc6749 import AuthorizationCodeMixin, ClientMixin, TokenMixin
from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest from authlib.oauth2.rfc6749.requests import JsonRequest as _JsonRequest
from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload from authlib.oauth2.rfc6749.requests import OAuth2Payload as _OAuth2Payload
from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request from authlib.oauth2.rfc6749.requests import OAuth2Request as _OAuth2Request
@@ -18,9 +20,16 @@ class OAuth2Payload(_OAuth2Payload):
def __init__(self, request: APIGatewayProxyEventV2): def __init__(self, request: APIGatewayProxyEventV2):
self._request = request self._request = request
@property
def decoded_body(self):
# TODO
body = parse_qs(self._request.decoded_body, keep_blank_values=True)
return {k: v[0] if len(v) == 1 else v for k, v in body.items()}
@property @property
def data(self): def data(self):
return self._request.query_string_parameters """Combines query string parameters and the request body"""
return self._request.query_string_parameters | self.decoded_body
@property @property
def datalist(self) -> dict[str, list]: def datalist(self) -> dict[str, list]:
@@ -51,19 +60,34 @@ class OAuth2Request(_OAuth2Request):
uri, uri,
request.headers, request.headers,
) )
self._request = request
self.payload = OAuth2Payload(request) self.payload = OAuth2Payload(request)
@property
def args(self):
return self._request.query_string_parameters
@property
def form(self) -> dict[str, str]:
return self.payload.decoded_body
class OAuth2Client(ClientMixin): class OAuth2Client(ClientMixin):
def __init__( def __init__(
self, self,
client_id: str, client_id: str,
client_secret: str,
redirect_uris: list, redirect_uris: list,
response_types: list, response_types: list,
grant_types: list,
token_endpoint_auth_method: str = 'client_secret_basic',
) -> None: ) -> None:
self.client_id = client_id self.client_id = client_id
self.client_secret = client_secret
self.redirect_uris = redirect_uris self.redirect_uris = redirect_uris
self.response_types = response_types self.response_types = response_types
self.grant_types = grant_types
self.token_endpoint_auth_method = token_endpoint_auth_method
def get_client_id(self): def get_client_id(self):
return self.client_id return self.client_id
@@ -78,10 +102,52 @@ class OAuth2Client(ClientMixin):
def check_redirect_uri(self, redirect_uri): def check_redirect_uri(self, redirect_uri):
return redirect_uri in self.redirect_uris return redirect_uri in self.redirect_uris
def check_endpoint_auth_method(self, method, endpoint):
if endpoint == 'token':
return self.token_endpoint_auth_method == method
return True
def check_grant_type(self, grant_type):
return grant_type in self.grant_types
def check_client_secret(self, client_secret):
return secrets.compare_digest(self.client_secret, client_secret)
class OAuth2Token(TokenMixin): ... class OAuth2Token(TokenMixin): ...
class AuthorizationCode(AuthorizationCodeMixin):
def __init__(
self,
user_id: str,
code: str,
client_id: str,
redirect_uri: str,
response_type: str,
scope: str,
code_challenge: str | None = None,
code_challenge_method: str | None = None,
nonce: str | None = None,
) -> None:
self.user_id = user_id
self.code = code
self.client_id = client_id
self.redirect_uri = redirect_uri
self.response_type = response_type
self.scope = scope
self.code_challenge = code_challenge
self.code_challenge_method = code_challenge_method
self.nonce = nonce
def get_redirect_uri(self):
return self.redirect_uri
def get_scope(self):
return self.scope
class AuthorizationServer(_AuthorizationServer): class AuthorizationServer(_AuthorizationServer):
def __init__(self, query_client, save_token) -> None: def __init__(self, query_client, save_token) -> None:
super().__init__( super().__init__(
@@ -109,3 +175,6 @@ class AuthorizationServer(_AuthorizationServer):
body=body, body=body,
headers=headers, headers=headers,
) )
def send_signal(self, name, *args, **kwargs):
pass

View File

@@ -1,15 +1,16 @@
from authlib.oauth2.rfc6749.grants import ( from authlib.oauth2.rfc6749 import TokenMixin, grants
AuthorizationCodeGrant as _AuthorizationCodeGrant, from authlib.oauth2.rfc7636 import CodeChallenge
)
from authlib.oidc.core import OpenIDCode as OpenIDCode_ from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
from aws_lambda_powertools.event_handler.exceptions import NotFoundError from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.funcs import omit, pick
from boto3clients import dynamodb_client from boto3clients import dynamodb_client
from config import DYNAMODB_SORT_KEY, OAUTH2_TABLE from config import DYNAMODB_SORT_KEY, OAUTH2_TABLE
from integrations.apigateway_oauth2 import ( from integrations.apigateway_oauth2 import (
AuthorizationCode,
AuthorizationServer, AuthorizationServer,
OAuth2Client, OAuth2Client,
OAuth2Token, OAuth2Token,
@@ -28,6 +29,7 @@ DUMMY_JWT_CONFIG = {
def create_save_token_func(persistence_layer: DynamoDBPersistenceLayer): def create_save_token_func(persistence_layer: DynamoDBPersistenceLayer):
def save_token(token, request) -> OAuth2Token: def save_token(token, request) -> OAuth2Token:
print('save_token')
return OAuth2Token() return OAuth2Token()
return save_token return save_token
@@ -40,7 +42,10 @@ def create_query_client_func(persistence_layer: DynamoDBPersistenceLayer):
def query_client(client_id) -> OAuth2Client: def query_client(client_id) -> OAuth2Client:
client = persistence_layer.collection.get_item( client = persistence_layer.collection.get_item(
KeyPair('OAUTH2_CLIENT', f'CLIENT_ID#{client_id}'), KeyPair(
pk='OAUTH2_CLIENT',
sk=f'CLIENT_ID#{client_id}',
),
exc_cls=ClientNotFoundError, exc_cls=ClientNotFoundError,
) )
@@ -48,46 +53,16 @@ def create_query_client_func(persistence_layer: DynamoDBPersistenceLayer):
return OAuth2Client( return OAuth2Client(
client_id=client_id, client_id=client_id,
client_secret=client['secret'],
redirect_uris=client['redirect_uris'], redirect_uris=client['redirect_uris'],
response_types=client['response_types'], response_types=client['response_types'],
grant_types=client['grant_types'],
token_endpoint_auth_method=client['token_endpoint_auth_method'],
) )
return query_client return query_client
def save_authorization_code(code, request):
data: dict = request.payload.data # type: ignore
user: dict = request.user # type: ignore
nonce: str | None = data.get('nonce')
now_ = now()
ttl_ = ttl(start_dt=now_, minutes=15)
with oauth2_layer.transact_writer() as transact:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{request.payload.client_id}',
'sk': f'CODE#{code}',
'redirect_uri': request.payload.redirect_uri, # type: ignore
'scope': request.payload.scope, # type: ignore
'user_id': user['id'],
'nonce': nonce,
'created_at': now_,
'ttl': ttl_,
},
)
if nonce:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{request.payload.client_id}',
'sk': f'NONCE#{nonce}',
'code': code,
'created_at': now_,
'ttl': ttl_,
},
)
class OpenIDCode(OpenIDCode_): class OpenIDCode(OpenIDCode_):
def exists_nonce(self, nonce, request): def exists_nonce(self, nonce, request):
nonce_ = oauth2_layer.get_item( nonce_ = oauth2_layer.get_item(
@@ -110,18 +85,93 @@ class OpenIDCode(OpenIDCode_):
).filter(scope) ).filter(scope)
class AuthorizationCodeGrant(_AuthorizationCodeGrant): class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
TOKEN_ENDPOINT_AUTH_METHODS = [ TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']
'client_secret_basic',
'client_secret_post',
]
def save_authorization_code(self, code: str, request): def save_authorization_code(self, code: str, request):
return save_authorization_code(code, request) client_id: str = request.payload.client_id # type: ignore
data: dict = request.payload.data # type: ignore
user: dict = request.user # type: ignore
nonce: str | None = data.get('nonce')
code_challenge: str | None = data.get('code_challenge')
code_challenge_method: str | None = data.get('code_challenge_method')
now_ = now()
ttl_ = ttl(start_dt=now_, minutes=15)
with oauth2_layer.transact_writer() as transact:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
'sk': f'CODE#{code}',
'redirect_uri': request.payload.redirect_uri, # type: ignore
'scope': request.payload.scope, # type: ignore
'user_id': user['id'],
'nonce': nonce,
'code_challenge': code_challenge,
'code_challenge_method': code_challenge_method,
'created_at': now_,
'ttl': ttl_,
},
)
if nonce:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
'sk': f'NONCE#{nonce}',
'code': code,
'created_at': now_,
'ttl': ttl_,
},
)
def query_authorization_code(self, code, client):
client_id = client.get_client_id()
auth_code = oauth2_layer.get_item(
KeyPair(
pk=f'OAUTH2_CODE#CLIENT_ID#{client_id}',
sk=f'CODE#{code}',
),
)
return AuthorizationCode(
client_id=client_id,
code=code,
**omit(('id', 'sk'), auth_code),
)
def delete_authorization_code(self, authorization_code):
print('authorization_code')
def authenticate_user(self, authorization_code):
user = oauth2_layer.get_item(
KeyPair(
pk=authorization_code.user_id,
sk='0',
),
)
return pick(('id', 'name', 'email'), user)
authorization = AuthorizationServer( class RefreshTokenGrant(grants.RefreshTokenGrant):
INCLUDE_NEW_REFRESH_TOKEN = True
def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin: ...
def authenticate_user(self, refresh_token): ...
def revoke_old_credential(self, refresh_token: TokenMixin) -> None: ...
server = AuthorizationServer(
query_client=create_query_client_func(oauth2_layer), query_client=create_query_client_func(oauth2_layer),
save_token=create_save_token_func(oauth2_layer), save_token=create_save_token_func(oauth2_layer),
) )
authorization.register_grant(AuthorizationCodeGrant, [OpenIDCode(require_nonce=False)]) server.register_grant(
AuthorizationCodeGrant,
[
OpenIDCode(require_nonce=False),
CodeChallenge(required=True),
],
)
server.register_grant(RefreshTokenGrant)

View File

@@ -9,7 +9,7 @@ from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.api_gateway import Router from aws_lambda_powertools.event_handler.api_gateway import Router
from jose_ import verify_jwt from jose_ import verify_jwt
from oauth2 import authorization from oauth2 import server
router = Router() router = Router()
logger = Logger(__name__) logger = Logger(__name__)
@@ -51,7 +51,7 @@ def authorize():
) )
try: try:
grant = authorization.get_consent_grant( grant = server.get_consent_grant(
request=router.current_event, request=router.current_event,
end_user={'id': user['sub']}, end_user={'id': user['sub']},
) )
@@ -60,7 +60,7 @@ def authorize():
return dict(err.get_body()) return dict(err.get_body())
try: try:
return authorization.create_authorization_response( return server.create_authorization_response(
request=router.current_event, request=router.current_event,
grant_user={'id': user['sub']}, grant_user={'id': user['sub']},
grant=grant, grant=grant,

View File

@@ -1,8 +1,10 @@
from aws_lambda_powertools.event_handler.api_gateway import Router from aws_lambda_powertools.event_handler.api_gateway import Router
from oauth2 import server
router = Router() router = Router()
@router.get('/token') @router.post('/token')
def token(): def issue_token():
return {} return server.create_token_response(router.current_event)

View File

@@ -85,7 +85,7 @@ Resources:
Type: HttpApi Type: HttpApi
Properties: Properties:
Path: /token Path: /token
Method: GET Method: POST
ApiId: !Ref HttpApi ApiId: !Ref HttpApi
UserInfo: UserInfo:
Type: HttpApi Type: HttpApi

View File

@@ -52,4 +52,6 @@ def test_authorize(
':sk': 'CODE', ':sk': 'CODE',
}, },
) )
# One item was added from seeds
assert len(r['items']) == 2 assert len(r['items']) == 2

View File

@@ -1,6 +1,6 @@
// OAuth2 // OAuth2
{"id": "OAUTH2_CLIENT", "sk": "CLIENT_ID#d72d4005-1fa7-4430-9754-80d5e2487bb6", "secret": "1nFD8alDbGHgc3g1RLY960xyRJVee0SlMoIB0MUlSuiJy28W", "name": "pytest", "scope": "openid profile", "redirect_uris": ["https://localhost/callback"], "response_types": ["code"], "grant_types": ["authorization_code"]} {"id": "OAUTH2_CLIENT", "sk": "CLIENT_ID#d72d4005-1fa7-4430-9754-80d5e2487bb6", "secret": "1nFD8alDbGHgc3g1RLY960xyRJVee0SlMoIB0MUlSuiJy28W", "name": "pytest", "scope": "openid profile", "redirect_uris": ["https://localhost/callback"], "response_types": ["code"], "grant_types": ["authorization_code", "refresh_token"], "token_endpoint_auth_method": "none"}
{"id": "OAUTH2_CODE#CLIENT_ID#d72d4005-1fa7-4430-9754-80d5e2487bb6", "sk": "CODE#kyqp3oSuRFTfuBaCmq3XOgGWg67l42Kt3D6xPEj7Yd3MLdi9", "redirect_uri": "https://localhost/callback", "user_id": "0cb0ce87-9df6-40c1-9fa7-7dfdafd7910e", "nonce": "123", "scope": "openid profile email"} {"id": "OAUTH2_CODE#CLIENT_ID#d72d4005-1fa7-4430-9754-80d5e2487bb6", "sk": "CODE#kyqp3oSuRFTfuBaCmq3XOgGWg67l42Kt3D6xPEj7Yd3MLdi9", "redirect_uri": "https://localhost/callback", "user_id": "357db1c5-7442-4075-98a3-fbe5c938a419", "nonce": null, "scope": "openid profile email", "response_type": "code", "code_challenge": "ejYEIGKQUgMnNh4eV0sftb0hXdLwkvKm6OHXRYvC--I", "code_challenge_method": "S256"}
// Post-migration: uncomment the following line // Post-migration: uncomment the following line
// {"id": "EMAIL", "sk": "sergio@somosbeta.com.br", "user_id": "357db1c5-7442-4075-98a3-fbe5c938a419"} // {"id": "EMAIL", "sk": "sergio@somosbeta.com.br", "user_id": "357db1c5-7442-4075-98a3-fbe5c938a419"}
@@ -8,4 +8,4 @@
// User data // User data
{"id": "357db1c5-7442-4075-98a3-fbe5c938a419", "sk": "0", "name": "Sérgio R Siqueira", "email": "sergio@somosbeta.com.br"} {"id": "357db1c5-7442-4075-98a3-fbe5c938a419", "sk": "0", "name": "Sérgio R Siqueira", "email": "sergio@somosbeta.com.br"}
{"id": "357db1c5-7442-4075-98a3-fbe5c938a419", "sk": "PASSWORD", "hash": "$pbkdf2-sha256$29000$IuTcm7M2BiAEgPB.b.3dGw$d8xVCbx8zxg7MeQBrOvCOgniiilsIHEMHzoH/OXftLQ"} {"id": "357db1c5-7442-4075-98a3-fbe5c938a419", "sk": "PASSWORD", "hash": "$pbkdf2-sha256$29000$IuTcm7M2BiAEgPB.b.3dGw$d8xVCbx8zxg7MeQBrOvCOgniiilsIHEMHzoH/OXftLQ"}