add revoke

This commit is contained in:
2025-09-17 16:51:35 -03:00
parent 207231cff6
commit b2303fc60a
18 changed files with 411 additions and 140 deletions

View File

@@ -10,6 +10,7 @@ from aws_lambda_powertools.utilities.typing import LambdaContext
from routes.authorize import router as authorize
from routes.jwks import router as jwks
from routes.openid_configuration import router as openid_configuration
from routes.revoke import router as revoke
from routes.session import router as session
from routes.token import router as token
from routes.userinfo import router as userinfo
@@ -22,6 +23,7 @@ app.include_router(authorize)
app.include_router(jwks)
app.include_router(token)
app.include_router(userinfo)
app.include_router(revoke)
app.include_router(openid_configuration)

View File

@@ -68,15 +68,30 @@ class AuthorizationServer(oauth2.AuthorizationServer):
raise ValueError('Missing request user')
now_ = now()
client_id = request.payload.client_id
client_id = (
request.client.get_client_id()
if request.client
else request.payload.client_id
)
user_id = request.user.get('id')
access_token = token['access_token']
refresh_token = token.get('refresh_token')
token_type = token['token_type']
scope = token['scope']
expires_in = int(token['expires_in'])
issued_at = int(now_.timestamp())
access_token_ttl = ttl(start_dt=now_, seconds=expires_in)
refresh_token_ttl = ttl(start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN)
with self._persistence_layer.transact_writer() as transact:
transact.put(
item={
'id': user_id,
'sk': f'SESSION#ACCESS_TOKEN#{access_token}',
'ttl': access_token_ttl,
'created_at': now_,
}
)
transact.put(
item={
'id': 'OAUTH2#TOKEN',
@@ -88,11 +103,19 @@ class AuthorizationServer(oauth2.AuthorizationServer):
'user': request.user,
'expires_in': expires_in,
'issued_at': issued_at,
'ttl': ttl(start_dt=now_, seconds=expires_in),
'ttl': access_token_ttl,
},
)
if refresh_token:
transact.put(
item={
'id': user_id,
'sk': f'SESSION#REFRESH_TOKEN#{refresh_token}',
'ttl': access_token_ttl,
'created_at': now_,
}
)
transact.put(
item={
'id': 'OAUTH2#TOKEN',
@@ -104,9 +127,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
'user': request.user,
'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
'issued_at': issued_at,
'ttl': ttl(
start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN
),
'ttl': refresh_token_ttl,
},
)
@@ -114,7 +135,10 @@ class AuthorizationServer(oauth2.AuthorizationServer):
def query_client(self, client_id: str):
client = self._persistence_layer.collection.get_item(
KeyPair(pk='OAUTH2', sk=f'CLIENT_ID#{client_id}'),
KeyPair(
pk='OAUTH2',
sk=f'CLIENT_ID#{client_id}',
),
exc_cls=ClientNotFoundError,
)
@@ -125,9 +149,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
redirect_uris=client['redirect_uris'],
response_types=client['response_types'],
grant_types=client['grant_types'],
token_endpoint_auth_method=client.get(
'token_endpoint_auth_method', 'client_secret_basic'
),
token_endpoint_auth_method=client.get('token_endpoint_auth_method', 'none'),
)
def create_oauth2_request(

View File

@@ -1,5 +1,5 @@
from authlib.common.urls import add_params_to_uri
from authlib.oauth2 import OAuth2Request, rfc9207
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oidc.core import OpenIDCode as OpenIDCode_
@@ -8,7 +8,12 @@ from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.dynamodb import (
DynamoDBPersistenceLayer,
KeyPair,
SortKey,
TransactKey,
)
from layercake.funcs import omit, pick
from boto3clients import dynamodb_client
@@ -22,7 +27,7 @@ from integrations.apigateway_oauth2.tokens import (
)
logger = Logger(__name__)
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
class OpenIDCode(OpenIDCode_):
@@ -30,7 +35,7 @@ class OpenIDCode(OpenIDCode_):
if not request.payload:
raise ValueError('Missing request payload')
nonce_ = oauth2_layer.get_item(
nonce_ = dyn.get_item(
KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}'),
)
@@ -65,6 +70,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
code: str,
request: OAuth2Request,
) -> None:
"""Save authorization_code for later use."""
if not request.payload:
raise ValueError('Missing request payload')
@@ -81,7 +87,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
now_ = now()
ttl_ = ttl(start_dt=now_, minutes=10)
with oauth2_layer.transact_writer() as transact:
with dyn.transact_writer() as transact:
transact.put(
item={
'id': 'OAUTH2#CODE',
@@ -116,7 +122,8 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
code: str,
client: ClientMixin,
) -> OAuth2AuthorizationCode:
auth_code = oauth2_layer.get_item(
"""Get authorization_code from previously savings."""
auth_code = dyn.get_item(
KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
)
@@ -129,16 +136,24 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
self,
authorization_code: OAuth2AuthorizationCode,
) -> None:
oauth2_layer.delete_item(
KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{authorization_code.code}'),
"""Delete authorization code from database or cache."""
dyn.delete_item(
KeyPair(
pk='OAUTH2#CODE',
sk=f'CODE#{authorization_code.code}',
),
)
def authenticate_user(
self,
authorization_code: OAuth2AuthorizationCode,
) -> dict:
user = oauth2_layer.get_item(
KeyPair(pk=authorization_code.user_id, sk='0'),
"""Authenticate the user related to this authorization_code."""
user = dyn.get_item(
KeyPair(
pk=authorization_code.user_id,
sk='0',
),
)
return pick(('id', 'name', 'email', 'email_verified'), user)
@@ -154,10 +169,13 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
'client_secret_post',
'none',
]
# The authorization server MAY issue a new refresh token
INCLUDE_NEW_REFRESH_TOKEN = True
def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
token = oauth2_layer.collection.get_item(
"""Get token information with refresh_token string."""
token = dyn.collection.get_item(
KeyPair(
pk='OAUTH2#TOKEN',
sk=f'REFRESH_TOKEN#{refresh_token}',
@@ -175,16 +193,90 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
)
def authenticate_user(self, refresh_token: TokenMixin):
"""Authenticate the user related to this credential."""
return refresh_token.get_user()
def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
logger.info('Revoking old refresh token', refresh_token=refresh_token)
token = getattr(refresh_token, 'refresh_token', None)
"""The authorization server MAY revoke the old refresh token after
issuing a new refresh token to the client."""
if token:
oauth2_layer.delete_item(
KeyPair(pk='OAUTH2#TOKEN', sk=f'REFRESH_TOKEN#{token}')
logger.debug('Revoking old refresh token', refresh_token=refresh_token)
token = getattr(refresh_token, 'refresh_token', None)
user = refresh_token.get_user()
with dyn.transact_writer() as transact:
transact.delete(
key=KeyPair(
pk='OAUTH2#TOKEN',
sk=f'REFRESH_TOKEN#{token}',
)
)
transact.delete(
key=KeyPair(
pk=user.get('id'),
sk=f'SESSION#REFRESH_TOKEN#{token}',
)
)
class RevocationEndpoint(rfc7009.RevocationEndpoint):
def query_token( # type: ignore
self,
token_string: str,
token_type_hint: str | None = None,
):
result = dyn.collection.get_items(
TransactKey('OAUTH2#TOKEN')
+ SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
+ SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
flatten_top=False,
)
if not result:
return None
logger.debug('Tokens retrieved', result=result)
if not token_type_hint:
token_type_hint = (
'refresh_token' if 'refresh_token' in result else 'access_token'
)
token = result[token_type_hint]
return OAuth2Token(
expires_in=int(token['expires_in']),
issued_at=int(token['issued_at']),
**{token_type_hint: token_string},
**omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
)
def revoke_token(
self,
token: OAuth2Token,
request: OAuth2Request,
):
user_id = token.user['id']
r = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))
with dyn.transact_writer() as transact:
# Revoke all sessions, access tokens, and refresh tokens
for x in r['items']:
pk, sk = x['id'], x['sk']
*_, kind, idx = sk.split('#')
transact.delete(key=KeyPair(pk, sk))
transact.delete(
key=KeyPair(
pk='SESSION',
sk=idx,
)
if kind == 'SESSION'
else KeyPair(
pk='OAUTH2#TOKEN',
sk=f'{kind}#{idx}',
)
)
class IssuerParameter(rfc9207.IssuerParameter):
@@ -207,7 +299,7 @@ class IssuerParameter(rfc9207.IssuerParameter):
return ISSUER
server = AuthorizationServer(persistence_layer=oauth2_layer)
server = AuthorizationServer(persistence_layer=dyn)
server.register_grant(
AuthorizationCodeGrant,
[
@@ -216,4 +308,5 @@ server.register_grant(
],
)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
server.register_extension(IssuerParameter())

View File

@@ -6,8 +6,12 @@ from authlib.oauth2.rfc6749 import errors
from authlib.oauth2.rfc6749.util import scope_to_list
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Router
from aws_lambda_powertools.event_handler.exceptions import BadRequestError, ServiceError
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from aws_lambda_powertools.event_handler.exceptions import (
BadRequestError,
ServiceError,
UnauthorizedError,
)
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair, SortKey
from boto3clients import dynamodb_client
from config import ISSUER, JWT_ALGORITHM, JWT_SECRET, OAUTH2_TABLE
@@ -15,7 +19,7 @@ from oauth2 import server
router = Router()
logger = Logger(__name__)
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
@router.get('/authorize')
@@ -36,9 +40,8 @@ def authorize():
client_scopes = set(scope_to_list(grant.client.scope))
user_scopes = set(scope_to_list(session_scope)) if session_scope else set()
if not client_scopes.issubset(
user_scopes | {'openid', 'email', 'profile', 'offline_access'}
):
# Deny authorization if user has no scopes matching the client request
if not user_scopes & client_scopes:
raise errors.InvalidScopeError(status_code=HTTPStatus.UNAUTHORIZED)
return server.create_authorization_response(
@@ -69,15 +72,27 @@ def verify_session(session_id: str) -> tuple[str, str | None]:
},
)
oauth2_layer.collection.get_item(
user = dyn.collection.get_items(
KeyPair(
pk='SESSION',
sk=payload['sid'],
rename_key='session',
)
+ KeyPair(
pk=payload['sub'],
sk=SortKey(
sk='SCOPE',
path_spec='scope',
rename_key='scope',
),
),
exc_cls=SessionRevokedError,
flatten_top=False,
)
return payload['sub'], payload.get('scope')
if 'session' not in user:
raise SessionRevokedError('Session revoked')
return payload['sub'], user.get('scope')
def _parse_cookies(cookies: list[str] | None) -> dict[str, str]:
@@ -94,6 +109,4 @@ def _parse_cookies(cookies: list[str] | None) -> dict[str, str]:
return parsed_cookies
class SessionRevokedError(BadRequestError):
def __init__(self, *_):
super().__init__('Session revoked')
class SessionRevokedError(UnauthorizedError): ...

View File

@@ -11,6 +11,7 @@ def openid_configuration():
'issuer': ISSUER,
'authorization_endpoint': f'{ISSUER}/authorize',
'token_endpoint': f'{ISSUER}/token',
'revocation_endpoint': f'{ISSUER}/revoke',
'userinfo_endpoint': f'{ISSUER}/userinfo',
'jwks_uri': f'{ISSUER}/jwks.json',
'scopes_supported': OAUTH2_SCOPES_SUPPORTED.split(),

View File

@@ -0,0 +1,13 @@
from aws_lambda_powertools.event_handler.api_gateway import Router
from oauth2 import RevocationEndpoint, server
router = Router()
@router.post('/revoke')
def revoke():
return server.create_endpoint_response(
RevocationEndpoint.ENDPOINT_NAME,
router.current_event,
)

View File

@@ -15,10 +15,16 @@ from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair, SortKey
from passlib.hash import pbkdf2_sha256
from boto3clients import dynamodb_client
from config import ISSUER, JWT_ALGORITHM, JWT_EXP_SECONDS, JWT_SECRET, OAUTH2_TABLE
from config import (
ISSUER,
JWT_ALGORITHM,
JWT_SECRET,
OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
OAUTH2_TABLE,
)
router = Router()
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
@router.post('/session')
@@ -26,11 +32,7 @@ def session(
username: Annotated[str, Body()],
password: Annotated[str, Body()],
):
(
user_id,
password_hash,
scope,
) = _get_user(username)
user_id, password_hash = _get_user(username)
if not pbkdf2_sha256.verify(password, password_hash):
raise ForbiddenError('Invalid credentials')
@@ -40,28 +42,27 @@ def session(
cookies=[
Cookie(
name='session_id',
value=new_session(user_id, scope),
value=new_session(user_id),
http_only=True,
secure=True,
same_site=None,
max_age=JWT_EXP_SECONDS,
max_age=OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
)
],
)
def _get_user(username: str) -> tuple[str, str, str | None]:
def _get_user(username: str) -> tuple[str, str]:
sk = SortKey(username, path_spec='user_id')
user = oauth2_layer.collection.get_items(
user = dyn.collection.get_items(
KeyPair(pk='email', sk=sk, rename_key=sk.path_spec)
+ KeyPair(pk='cpf', sk=sk, rename_key=sk.path_spec),
flatten_top=False,
)
if not user:
raise UserNotFoundError()
userdata = oauth2_layer.collection.get_items(
password = dyn.collection.get_item(
KeyPair(
pk=user['user_id'],
sk=SortKey(
@@ -69,46 +70,34 @@ def _get_user(username: str) -> tuple[str, str, str | None]:
path_spec='hash',
rename_key='password',
),
)
+ KeyPair(
pk=user['user_id'],
sk=SortKey(
sk='SCOPE',
path_spec='scope',
rename_key='scope',
),
),
flatten_top=False,
exc_cls=UserNotFoundError,
)
if not userdata:
raise UserNotFoundError()
return user['user_id'], userdata['password'], userdata.get('scope')
return user['user_id'], password
def new_session(sub: str, scope: str | None) -> str:
def new_session(sub: str) -> str:
session_id = str(uuid4())
now_ = now()
sid = str(uuid4())
exp = ttl(start_dt=now_, seconds=JWT_EXP_SECONDS)
exp = ttl(start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN)
token = jwt.encode(
{
'sid': sid,
'sid': session_id,
'sub': sub,
'iss': ISSUER,
'iat': int(now_.timestamp()),
'exp': exp,
'scope': scope,
},
JWT_SECRET,
algorithm=JWT_ALGORITHM,
)
with oauth2_layer.transact_writer() as transact:
with dyn.transact_writer() as transact:
transact.put(
item={
'id': 'SESSION',
'sk': sid,
'sk': session_id,
'user_id': sub,
'ttl': exp,
'created_at': now_,
@@ -117,7 +106,7 @@ def new_session(sub: str, scope: str | None) -> str:
transact.put(
item={
'id': sub,
'sk': f'SESSION#{sid}',
'sk': f'SESSION#{session_id}',
'ttl': exp,
'created_at': now_,
}