This commit is contained in:
2025-10-31 15:15:31 -03:00
parent 460748492a
commit 26a87b8445
5 changed files with 69 additions and 28 deletions

View File

@@ -1,5 +1,3 @@
from os import rename
from authlib.common.security import generate_token from authlib.common.security import generate_token
from authlib.common.urls import add_params_to_uri from authlib.common.urls import add_params_to_uri
from authlib.jose import JsonWebKey from authlib.jose import JsonWebKey
@@ -13,7 +11,6 @@ from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl from layercake.dateutils import now, ttl
from layercake.dynamodb import ( from layercake.dynamodb import (
DynamoDBPersistenceLayer, DynamoDBPersistenceLayer,
@@ -42,9 +39,11 @@ private_key = read_file_path('private.pem')
private_jwk = JsonWebKey.import_key(private_key) private_jwk = JsonWebKey.import_key(private_key)
# https://docs.authlib.org/en/v0.12/specs/rfc6750.html#authlib.oauth2.rfc6750.BearerToken.GRANT_TYPES_EXPIRES_IN
GRANT_TYPES_EXPIRES_IN = { GRANT_TYPES_EXPIRES_IN = {
'authorization_code': 60 * 3, # 3 minutes 'authorization_code': 60 * 3, # 3 minutes
'refresh_token': 3600, # 1 hour 'refresh_token': 60 * 3, # 3 minutes
# 'refresh_token': 3600, # 1 hour
} }
@@ -83,6 +82,17 @@ class OpenIDCode(OpenIDCode_):
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
"""The authorization code grant type is used to obtain both access tokens and
refresh tokens and is optimized for confidential clients.
Since this is a redirection-based flow, the client must be capable of
interacting with the resource owners user-agent (typically a web browser)
and capable of receiving incoming requests (via redirection) from the
authorization server:
- https://docs.authlib.org/en/v0.12/specs/rfc6749.html#authlib.oauth2.rfc6749.grants.AuthorizationCodeGrant
"""
TOKEN_ENDPOINT_AUTH_METHODS = [ TOKEN_ENDPOINT_AUTH_METHODS = [
'client_secret_basic', 'client_secret_basic',
'client_secret_post', 'client_secret_post',
@@ -206,11 +216,6 @@ class TokenExchangeGrant(grants.BaseGrant):
raise NotImplementedError() raise NotImplementedError()
class RefreshTokenNotFoundError(NotFoundError):
def __init__(self, *_):
super().__init__('Refresh token not found')
class RefreshTokenGrant(grants.RefreshTokenGrant): class RefreshTokenGrant(grants.RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = [ TOKEN_ENDPOINT_AUTH_METHODS = [
'client_secret_basic', 'client_secret_basic',
@@ -221,17 +226,22 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
# The authorization server MAY issue a new refresh token # The authorization server MAY issue a new refresh token
INCLUDE_NEW_REFRESH_TOKEN = True INCLUDE_NEW_REFRESH_TOKEN = True
def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin: def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin:
"""Get token information with refresh_token string.""" """Get token information with refresh_token string."""
token = dyn.collection.get_item( token = dyn.collection.get_item(
KeyPair( KeyPair(
pk='OAUTH2#TOKEN', pk='OAUTH2#TOKEN',
sk=f'REFRESH_TOKEN#{refresh_token}', sk=f'REFRESH_TOKEN#{refresh_token}',
), ),
exc_cls=RefreshTokenNotFoundError, raise_on_error=False,
default=False,
) )
logger.info('Refresh token retrieved', token=token) if not token:
logger.info('Refresh token not found ', refresh_token=token)
return None # type: ignore
logger.info('Refresh token retrieved', refresh_token=token)
return OAuth2Token( return OAuth2Token(
expires_in=int(token['expires_in']), expires_in=int(token['expires_in']),
@@ -249,13 +259,14 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
issuing a new refresh token to the client.""" issuing a new refresh token to the client."""
token = getattr(refresh_token, 'refresh_token', None) token = getattr(refresh_token, 'refresh_token', None)
logger.debug('Revoking old refresh token', refresh_token=token)
user = refresh_token.get_user() user = refresh_token.get_user()
if not token: if not token:
return None return None
with dyn.transact_writer() as transact: with dyn.transact_writer() as transact:
logger.debug('Revoking old refresh token', refresh_token=token)
transact.delete( transact.delete(
key=KeyPair( key=KeyPair(
pk='OAUTH2#TOKEN', pk='OAUTH2#TOKEN',
@@ -276,24 +287,24 @@ class RevocationEndpoint(rfc7009.RevocationEndpoint):
token_string: str, token_string: str,
token_type_hint: str | None = None, token_type_hint: str | None = None,
): ):
t = dyn.collection.get_items( result = dyn.collection.get_items(
TransactKey('OAUTH2#TOKEN') TransactKey('OAUTH2#TOKEN')
+ SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token') + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
+ SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'), + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
flatten_top=False, flatten_top=False,
) )
if not t: if not result:
return None return None
logger.debug('Tokens retrieved', result=t) logger.debug('Tokens retrieved', result=result)
if not token_type_hint: if not token_type_hint:
token_type_hint = ( token_type_hint = (
'refresh_token' if 'refresh_token' in t else 'access_token' 'refresh_token' if 'refresh_token' in result else 'access_token'
) )
token = t[token_type_hint] token = result[token_type_hint]
return OAuth2Token( return OAuth2Token(
expires_in=int(token['expires_in']), expires_in=int(token['expires_in']),
@@ -319,24 +330,25 @@ class RevocationEndpoint(rfc7009.RevocationEndpoint):
- https://docs.authlib.org/en/latest/specs/rfc7009.html#authlib.oauth2.rfc7009.RevocationEndpoint.revoke_token - https://docs.authlib.org/en/latest/specs/rfc7009.html#authlib.oauth2.rfc7009.RevocationEndpoint.revoke_token
""" """
user_id = token.user['id'] user_id = token.user['id']
r = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION')) result = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))
with dyn.transact_writer() as transact: with dyn.transact_writer() as transact:
# Revoke all sessions, access tokens, and refresh tokens # Revoke all sessions, access tokens, and refresh tokens
for x in r['items']: for item in result['items']:
pk, sk = x['id'], x['sk'] pk, sk = item['id'], item['sk']
*_, kind, idx = sk.split('#') # Key pattern: SESSION#{}#{}
*_, kind, token_ = sk.split('#')
transact.delete(key=KeyPair(pk, sk)) transact.delete(key=KeyPair(pk, sk))
transact.delete( transact.delete(
key=KeyPair( key=KeyPair(
pk='SESSION', pk='SESSION',
sk=idx, sk=token_,
) )
if kind == 'SESSION' if kind == 'SESSION'
else KeyPair( else KeyPair(
pk='OAUTH2#TOKEN', pk='OAUTH2#TOKEN',
sk=f'{kind}#{idx}', sk=f'{kind}#{token_}',
) )
) )

View File

@@ -9,6 +9,4 @@ router = Router()
@router.post('/token') @router.post('/token')
def issue_token(): def issue_token():
response = server.create_token_response(router.current_event) return server.create_token_response(router.current_event)
logger.debug(getattr(response, 'body', None))
return response

View File

@@ -1,4 +1,5 @@
import json import json
import pprint
from base64 import b64encode from base64 import b64encode
from http import HTTPMethod, HTTPStatus from http import HTTPMethod, HTTPStatus
from urllib.parse import urlencode from urllib.parse import urlencode

View File

@@ -1,4 +1,5 @@
import json import json
import pprint
from base64 import b64encode from base64 import b64encode
from http import HTTPMethod, HTTPStatus from http import HTTPMethod, HTTPStatus
from urllib.parse import urlencode from urllib.parse import urlencode
@@ -125,3 +126,32 @@ def test_refresh_token(
) )
assert len(r['items']) == 2 assert len(r['items']) == 2
def test_refresh_token_invalid_grant(
app,
seeds,
dynamodb_persistence_layer: DynamoDBPersistenceLayer,
http_api_proxy: HttpApiProxy,
lambda_context: LambdaContext,
):
r = app.lambda_handler(
http_api_proxy(
raw_path='/token',
method=HTTPMethod.POST,
headers={
'Authorization': f'Basic {AUTH}',
'Content-Type': 'application/x-www-form-urlencoded',
},
body=urlencode(
{
'grant_type': 'refresh_token',
'refresh_token': '1234',
'client_id': CLIENT_ID,
}
),
),
lambda_context,
)
assert r['statusCode'] == 400

View File

@@ -507,7 +507,7 @@ wheels = [
[[package]] [[package]]
name = "layercake" name = "layercake"
version = "0.11.0" version = "0.11.1"
source = { directory = "../layercake" } source = { directory = "../layercake" }
dependencies = [ dependencies = [
{ name = "arnparse" }, { name = "arnparse" },