This commit is contained in:
2025-08-07 00:48:13 -03:00
parent ff25ade76e
commit 14312944fb
7 changed files with 180 additions and 57 deletions

View File

@@ -1,15 +1,16 @@
from authlib.oauth2.rfc6749.grants import (
AuthorizationCodeGrant as _AuthorizationCodeGrant,
)
from authlib.oauth2.rfc6749 import TokenMixin, grants
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.funcs import omit, pick
from boto3clients import dynamodb_client
from config import DYNAMODB_SORT_KEY, OAUTH2_TABLE
from integrations.apigateway_oauth2 import (
AuthorizationCode,
AuthorizationServer,
OAuth2Client,
OAuth2Token,
@@ -28,6 +29,7 @@ DUMMY_JWT_CONFIG = {
def create_save_token_func(persistence_layer: DynamoDBPersistenceLayer):
def save_token(token, request) -> OAuth2Token:
print('save_token')
return OAuth2Token()
return save_token
@@ -40,7 +42,10 @@ def create_query_client_func(persistence_layer: DynamoDBPersistenceLayer):
def query_client(client_id) -> OAuth2Client:
client = persistence_layer.collection.get_item(
KeyPair('OAUTH2_CLIENT', f'CLIENT_ID#{client_id}'),
KeyPair(
pk='OAUTH2_CLIENT',
sk=f'CLIENT_ID#{client_id}',
),
exc_cls=ClientNotFoundError,
)
@@ -48,46 +53,16 @@ def create_query_client_func(persistence_layer: DynamoDBPersistenceLayer):
return OAuth2Client(
client_id=client_id,
client_secret=client['secret'],
redirect_uris=client['redirect_uris'],
response_types=client['response_types'],
grant_types=client['grant_types'],
token_endpoint_auth_method=client['token_endpoint_auth_method'],
)
return query_client
def save_authorization_code(code, request):
data: dict = request.payload.data # type: ignore
user: dict = request.user # type: ignore
nonce: str | None = data.get('nonce')
now_ = now()
ttl_ = ttl(start_dt=now_, minutes=15)
with oauth2_layer.transact_writer() as transact:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{request.payload.client_id}',
'sk': f'CODE#{code}',
'redirect_uri': request.payload.redirect_uri, # type: ignore
'scope': request.payload.scope, # type: ignore
'user_id': user['id'],
'nonce': nonce,
'created_at': now_,
'ttl': ttl_,
},
)
if nonce:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{request.payload.client_id}',
'sk': f'NONCE#{nonce}',
'code': code,
'created_at': now_,
'ttl': ttl_,
},
)
class OpenIDCode(OpenIDCode_):
def exists_nonce(self, nonce, request):
nonce_ = oauth2_layer.get_item(
@@ -110,18 +85,93 @@ class OpenIDCode(OpenIDCode_):
).filter(scope)
class AuthorizationCodeGrant(_AuthorizationCodeGrant):
TOKEN_ENDPOINT_AUTH_METHODS = [
'client_secret_basic',
'client_secret_post',
]
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']
def save_authorization_code(self, code: str, request):
return save_authorization_code(code, request)
client_id: str = request.payload.client_id # type: ignore
data: dict = request.payload.data # type: ignore
user: dict = request.user # type: ignore
nonce: str | None = data.get('nonce')
code_challenge: str | None = data.get('code_challenge')
code_challenge_method: str | None = data.get('code_challenge_method')
now_ = now()
ttl_ = ttl(start_dt=now_, minutes=15)
with oauth2_layer.transact_writer() as transact:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
'sk': f'CODE#{code}',
'redirect_uri': request.payload.redirect_uri, # type: ignore
'scope': request.payload.scope, # type: ignore
'user_id': user['id'],
'nonce': nonce,
'code_challenge': code_challenge,
'code_challenge_method': code_challenge_method,
'created_at': now_,
'ttl': ttl_,
},
)
if nonce:
transact.put(
item={
'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
'sk': f'NONCE#{nonce}',
'code': code,
'created_at': now_,
'ttl': ttl_,
},
)
def query_authorization_code(self, code, client):
client_id = client.get_client_id()
auth_code = oauth2_layer.get_item(
KeyPair(
pk=f'OAUTH2_CODE#CLIENT_ID#{client_id}',
sk=f'CODE#{code}',
),
)
return AuthorizationCode(
client_id=client_id,
code=code,
**omit(('id', 'sk'), auth_code),
)
def delete_authorization_code(self, authorization_code):
print('authorization_code')
def authenticate_user(self, authorization_code):
user = oauth2_layer.get_item(
KeyPair(
pk=authorization_code.user_id,
sk='0',
),
)
return pick(('id', 'name', 'email'), user)
authorization = AuthorizationServer(
class RefreshTokenGrant(grants.RefreshTokenGrant):
INCLUDE_NEW_REFRESH_TOKEN = True
def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin: ...
def authenticate_user(self, refresh_token): ...
def revoke_old_credential(self, refresh_token: TokenMixin) -> None: ...
server = AuthorizationServer(
query_client=create_query_client_func(oauth2_layer),
save_token=create_save_token_func(oauth2_layer),
)
authorization.register_grant(AuthorizationCodeGrant, [OpenIDCode(require_nonce=False)])
server.register_grant(
AuthorizationCodeGrant,
[
OpenIDCode(require_nonce=False),
CodeChallenge(required=True),
],
)
server.register_grant(RefreshTokenGrant)