This commit is contained in:
2025-08-08 13:32:37 -03:00
parent 78c4a4ad30
commit b7143ea634
8 changed files with 117 additions and 93 deletions

View File

@@ -9,5 +9,4 @@ JWT_SECRET: str = os.environ.get('JWT_SECRET') # type: ignore
JWT_ALGORITHM = 'HS256' JWT_ALGORITHM = 'HS256'
JWT_EXP_SECONDS = 900 # 15 minutes JWT_EXP_SECONDS = 900 # 15 minutes
ACCESS_TOKEN_EXP_SECONDS = 3600 # 1 hour OAUTH2_REFRESH_TOKEN_EXPIRES_IN = 30 * 86400 # 30 days
REFRESH_TOKEN_EXP_SECONDS = 14 * 86400 # 14 days

View File

@@ -12,7 +12,7 @@ from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import
from layercake.dateutils import now, ttl from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from config import ACCESS_TOKEN_EXP_SECONDS, REFRESH_TOKEN_EXP_SECONDS from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN
from .client import OAuth2Client from .client import OAuth2Client
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
@@ -21,6 +21,12 @@ DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED') OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
GRANT_TYPES_EXPIRES_IN = {
'refresh_token': 900,
'authorization_code': 900,
}
class AuthorizationServer(oauth2.AuthorizationServer): class AuthorizationServer(oauth2.AuthorizationServer):
def __init__( def __init__(
self, self,
@@ -35,11 +41,17 @@ class AuthorizationServer(oauth2.AuthorizationServer):
) )
) )
def expires_in(client, grant_type: str) -> int:
return GRANT_TYPES_EXPIRES_IN.get(
grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
)
self.register_token_generator( self.register_token_generator(
'default', 'default',
BearerTokenGenerator( BearerTokenGenerator(
access_token_generator=create_token_generator(42), access_token_generator=create_token_generator(42),
refresh_token_generator=create_token_generator(48), refresh_token_generator=create_token_generator(48),
expires_generator=expires_in,
), ),
) )
@@ -60,6 +72,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
refresh_token = token.get('refresh_token') refresh_token = token.get('refresh_token')
token_type = token['token_type'] token_type = token['token_type']
scope = token['scope'] scope = token['scope']
expires_in = int(token['expires_in'])
issued_at = int(now_.timestamp()) issued_at = int(now_.timestamp())
with self._persistence_layer.transact_writer() as transact: with self._persistence_layer.transact_writer() as transact:
@@ -72,9 +85,9 @@ class AuthorizationServer(oauth2.AuthorizationServer):
'refresh_token': refresh_token, 'refresh_token': refresh_token,
'scope': scope, 'scope': scope,
'user': request.user, 'user': request.user,
'expires_in': ACCESS_TOKEN_EXP_SECONDS, 'expires_in': expires_in,
'issued_at': issued_at, 'issued_at': issued_at,
'ttl': ttl(start_dt=now_, seconds=ACCESS_TOKEN_EXP_SECONDS), 'ttl': ttl(start_dt=now_, seconds=expires_in),
}, },
) )
@@ -88,9 +101,11 @@ class AuthorizationServer(oauth2.AuthorizationServer):
'access_token': access_token, 'access_token': access_token,
'scope': scope, 'scope': scope,
'user': request.user, 'user': request.user,
'expires_in': REFRESH_TOKEN_EXP_SECONDS, 'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
'issued_at': issued_at, 'issued_at': issued_at,
'ttl': ttl(start_dt=now_, seconds=REFRESH_TOKEN_EXP_SECONDS), 'ttl': ttl(
start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN
),
}, },
) )
@@ -117,23 +132,35 @@ class AuthorizationServer(oauth2.AuthorizationServer):
) )
def create_oauth2_request( def create_oauth2_request(
self, request: APIGatewayProxyEventV2 self,
request: APIGatewayProxyEventV2,
) -> APIGatewayOAuth2Request: ) -> APIGatewayOAuth2Request:
return APIGatewayOAuth2Request(request) return APIGatewayOAuth2Request(request)
def create_json_request( def create_json_request(
self, request: APIGatewayProxyEventV2 self,
request: APIGatewayProxyEventV2,
) -> APIGatewayJsonRequest: ) -> APIGatewayJsonRequest:
return APIGatewayJsonRequest(request) return APIGatewayJsonRequest(request)
def handle_response(self, status: int, body, headers): def handle_response(
self,
status: int,
body,
headers,
):
return Response( return Response(
status_code=status, status_code=status,
body=body, body=body,
headers=headers, headers=headers,
) )
def send_signal(self, name: str, *args, **kwargs) -> None: def send_signal(
self,
name: str,
*args: object,
**kwargs: object,
) -> None:
# after_authenticate_client # after_authenticate_client
# when client is authenticated # when client is authenticated

View File

@@ -1,4 +1,5 @@
import secrets import secrets
from collections.abc import Collection
from authlib.oauth2.rfc6749 import ( from authlib.oauth2.rfc6749 import (
ClientMixin, ClientMixin,
@@ -29,7 +30,7 @@ class OAuth2Client(ClientMixin):
def get_client_id(self): def get_client_id(self):
return self.client_id return self.client_id
def get_allowed_scope(self, scope) -> str: def get_allowed_scope(self, scope: Collection[str]) -> str:
if not scope: if not scope:
return '' return ''
@@ -37,24 +38,26 @@ class OAuth2Client(ClientMixin):
scopes = scope_to_list(scope) scopes = scope_to_list(scope)
return list_to_scope([s for s in scopes if s in allowed]) return list_to_scope([s for s in scopes if s in allowed])
def get_default_redirect_uri(self) -> str: # type: ignore def get_default_redirect_uri(self) -> str:
if self.redirect_uris: if self.redirect_uris:
return self.redirect_uris[0] return self.redirect_uris[0]
def check_response_type(self, response_type): raise ValueError('Missing redirect_uris')
def check_response_type(self, response_type: str) -> bool:
return response_type in self.response_types return response_type in self.response_types
def check_redirect_uri(self, redirect_uri): def check_redirect_uri(self, redirect_uri: str) -> bool:
return redirect_uri in self.redirect_uris return redirect_uri in self.redirect_uris
def check_endpoint_auth_method(self, method, endpoint): def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
if endpoint == 'token': if endpoint == 'token':
return self.token_endpoint_auth_method == method return self.token_endpoint_auth_method == method
return True return True
def check_grant_type(self, grant_type): def check_grant_type(self, grant_type: str) -> bool:
return grant_type in self.grant_types return grant_type in self.grant_types
def check_client_secret(self, client_secret): def check_client_secret(self, client_secret: str) -> bool:
return secrets.compare_digest(self.client_secret, client_secret) return secrets.compare_digest(self.client_secret, client_secret)

View File

@@ -5,7 +5,7 @@ from authlib.oauth2.rfc6749 import (
ClientMixin, ClientMixin,
TokenMixin, TokenMixin,
) )
from layercake.dateutils import fromisoformat from layercake.dateutils import fromisoformat, now
class OAuth2AuthorizationCode(AuthorizationCodeMixin): class OAuth2AuthorizationCode(AuthorizationCodeMixin):
@@ -32,25 +32,26 @@ class OAuth2AuthorizationCode(AuthorizationCodeMixin):
self.code_challenge_method = code_challenge_method self.code_challenge_method = code_challenge_method
self.nonce = nonce self.nonce = nonce
auth_time = fromisoformat(kwargs.get('created_at', '')) or now() created_at = kwargs.get('created_at', '')
auth_time = fromisoformat(created_at) or now()
self.auth_time = int(auth_time.timestamp()) self.auth_time = int(auth_time.timestamp())
def get_redirect_uri(self): def get_redirect_uri(self) -> str:
return self.redirect_uri return self.redirect_uri
def get_scope(self): def get_scope(self) -> str:
return self.scope return self.scope
def get_nonce(self): def get_nonce(self) -> str | None:
return self.nonce return self.nonce
def get_auth_time(self): def get_auth_time(self) -> int:
return self.auth_time return self.auth_time
def get_acr(self): def get_acr(self):
return '0' return '0'
def get_amr(self): def get_amr(self) -> list:
return [] return []
@@ -77,7 +78,7 @@ class OAuth2Token(TokenMixin):
def get_user(self) -> dict: def get_user(self) -> dict:
return self.user return self.user
def check_client(self, client: ClientMixin): def check_client(self, client: ClientMixin) -> bool:
return self.client_id == client.get_client_id() return self.client_id == client.get_client_id()
def get_scope(self) -> str: def get_scope(self) -> str:

View File

@@ -1,7 +1,7 @@
from datetime import timedelta from datetime import timedelta
from aws_lambda_powertools.event_handler.exceptions import ForbiddenError from aws_lambda_powertools.event_handler.exceptions import ForbiddenError
from jose import ExpiredSignatureError, JWTError, jwt from jose import jwt
from layercake.dateutils import now from layercake.dateutils import now
from config import ( from config import (
@@ -9,7 +9,7 @@ from config import (
JWT_ALGORITHM, JWT_ALGORITHM,
JWT_EXP_SECONDS, JWT_EXP_SECONDS,
JWT_SECRET, JWT_SECRET,
REFRESH_TOKEN_EXP_SECONDS, OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
) )
@@ -27,10 +27,11 @@ def generate_jwt(user_id: str, email: str) -> str:
def generate_refresh_token(user_id: str) -> str: def generate_refresh_token(user_id: str) -> str:
now_ = now() now_ = now()
exp = now_ + timedelta(seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN)
payload = { payload = {
'sub': user_id, 'sub': user_id,
'iat': int(now_.timestamp()), 'iat': int(now_.timestamp()),
'exp': int((now_ + timedelta(seconds=REFRESH_TOKEN_EXP_SECONDS)).timestamp()), 'exp': int(exp.timestamp()),
'iss': ISSUER, 'iss': ISSUER,
'typ': 'refresh', 'typ': 'refresh',
} }
@@ -38,19 +39,14 @@ def generate_refresh_token(user_id: str) -> str:
def verify_jwt(token: str) -> dict: def verify_jwt(token: str) -> dict:
try: payload = jwt.decode(
payload = jwt.decode( token,
token, JWT_SECRET,
JWT_SECRET, algorithms=[JWT_ALGORITHM],
algorithms=[JWT_ALGORITHM], issuer=ISSUER,
issuer=ISSUER, options={
options={ 'require': ['exp', 'sub', 'iss'],
'require': ['exp', 'sub', 'iss'], 'leeway': 60,
'leeway': 60, },
}, )
) return payload
return payload
except ExpiredSignatureError:
raise ForbiddenError('Token expired')
except JWTError:
raise ForbiddenError('Invalid token')

View File

@@ -5,7 +5,7 @@ from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo from authlib.oidc.core import UserInfo
from layercake.dateutils import now, ttl from layercake.dateutils import now, ttl
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
from layercake.funcs import pick from layercake.funcs import omit, pick
from boto3clients import dynamodb_client from boto3clients import dynamodb_client
from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE
@@ -124,10 +124,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
authorization_code: OAuth2AuthorizationCode, authorization_code: OAuth2AuthorizationCode,
) -> None: ) -> None:
oauth2_layer.delete_item( oauth2_layer.delete_item(
KeyPair( KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{authorization_code.code}'),
pk='OAUTH2#CODE',
sk=f'CODE#{authorization_code.code}',
),
) )
def authenticate_user( def authenticate_user(
@@ -135,17 +132,18 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
authorization_code: OAuth2AuthorizationCode, authorization_code: OAuth2AuthorizationCode,
) -> dict: ) -> dict:
user = oauth2_layer.get_item( user = oauth2_layer.get_item(
KeyPair( KeyPair(pk=authorization_code.user_id, sk='0'),
pk=authorization_code.user_id,
sk='0',
),
) )
return pick(('id', 'name', 'email', 'email_verified'), user) return pick(('id', 'name', 'email', 'email_verified'), user)
class RefreshTokenGrant(grants.RefreshTokenGrant): class RefreshTokenGrant(grants.RefreshTokenGrant):
TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']
INCLUDE_NEW_REFRESH_TOKEN = True INCLUDE_NEW_REFRESH_TOKEN = True
TOKEN_ENDPOINT_AUTH_METHODS = [
'client_secret_basic',
'client_secret_post',
'none',
]
def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin: def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
token = oauth2_layer.get_item( token = oauth2_layer.get_item(
@@ -156,26 +154,19 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
) )
return OAuth2Token( return OAuth2Token(
client_id=token['client_id'],
scope=token['scope'],
expires_in=int(token['expires_in']), expires_in=int(token['expires_in']),
issued_at=int(token['issued_at']), issued_at=int(token['issued_at']),
user=token['user'],
refresh_token=refresh_token, refresh_token=refresh_token,
**omit(('expires_in', 'issued_at'), token),
) )
def authenticate_user(self, refresh_token: TokenMixin): def authenticate_user(self, refresh_token: TokenMixin):
return refresh_token.get_user() return refresh_token.get_user()
def revoke_old_credential(self, refresh_token: TokenMixin) -> None: def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
refresh_token_ = getattr(refresh_token, 'refresh_token') if token := getattr(refresh_token, 'refresh_token', None):
if refresh_token_:
oauth2_layer.delete_item( oauth2_layer.delete_item(
KeyPair( KeyPair(pk='OAUTH2#TOKEN', sk=f'REFRESH_TOKEN#{token}')
pk='OAUTH2#TOKEN',
sk=f'REFRESH_TOKEN#{refresh_token_}',
)
) )

View File

@@ -7,6 +7,7 @@ from authlib.oauth2.rfc6749 import errors
from aws_lambda_powertools import Logger from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler import Response
from aws_lambda_powertools.event_handler.api_gateway import Router from aws_lambda_powertools.event_handler.api_gateway import Router
from jose.exceptions import JWTError
from jose_ import verify_jwt from jose_ import verify_jwt
from oauth2 import server from oauth2 import server
@@ -18,33 +19,21 @@ logger = Logger(__name__)
@router.get('/authorize') @router.get('/authorize')
def authorize(): def authorize():
current_event = router.current_event current_event = router.current_event
cookies = _parse_cookies(current_event.get('cookies', [])) # type: ignore cookies = _parse_cookies(current_event.get('cookies', []))
id_token = cookies.get('id_token') id_token = cookies.get('id_token')
continue_url = quote( continue_url = _build_continue_url(
urlunparse( current_event.path,
ParseResult( current_event.query_string_parameters,
scheme='',
netloc='',
path=current_event.path,
params='',
query=urlencode(current_event.query_string_parameters),
fragment='',
)
),
safe='',
) )
login_url = f'/login?continue={continue_url}' login_url = f'/login?continue={continue_url}'
if not id_token:
return Response(
status_code=HTTPStatus.FOUND,
headers={'Location': login_url},
)
try: try:
if not id_token:
raise ValueError('Missing id_token')
user = verify_jwt(id_token) user = verify_jwt(id_token)
except Exception as exc: except (ValueError, JWTError):
logger.exception(exc)
return Response( return Response(
status_code=HTTPStatus.FOUND, status_code=HTTPStatus.FOUND,
headers={'Location': login_url}, headers={'Location': login_url},
@@ -82,3 +71,11 @@ def _parse_cookies(cookies: list[str] | None) -> dict[str, str]:
parsed_cookies.update({k: morsel.value for k, morsel in c.items()}) parsed_cookies.update({k: morsel.value for k, morsel in c.items()})
return parsed_cookies return parsed_cookies
def _build_continue_url(
path: str,
query_string_parameters: dict,
) -> str:
query = urlencode(query_string_parameters)
return quote(urlunparse(ParseResult('', '', path, '', query, '')), safe='')

View File

@@ -1,5 +1,4 @@
import json import json
import pprint
from http import HTTPMethod, HTTPStatus from http import HTTPMethod, HTTPStatus
from urllib.parse import urlencode from urllib.parse import urlencode
@@ -36,10 +35,11 @@ def test_token(
), ),
lambda_context, lambda_context,
) )
assert r['statusCode'] == HTTPStatus.OK auth_token = json.loads(r['body'])
data = json.loads(r['body'])
assert r['statusCode'] == HTTPStatus.OK
assert auth_token['expires_in'] == 900
# print(data)
r = dynamodb_persistence_layer.query( r = dynamodb_persistence_layer.query(
key_cond_expr='#pk = :pk', key_cond_expr='#pk = :pk',
expr_attr_name={ expr_attr_name={
@@ -49,7 +49,7 @@ def test_token(
':pk': 'OAUTH2#TOKEN', ':pk': 'OAUTH2#TOKEN',
}, },
) )
# pprint.pp(r['items']) assert len(r['items']) == 2
r = app.lambda_handler( r = app.lambda_handler(
http_api_proxy( http_api_proxy(
@@ -61,7 +61,7 @@ def test_token(
body=urlencode( body=urlencode(
{ {
'grant_type': 'refresh_token', 'grant_type': 'refresh_token',
'refresh_token': data['refresh_token'], 'refresh_token': auth_token['refresh_token'],
'client_id': client_id, 'client_id': client_id,
} }
), ),
@@ -70,4 +70,14 @@ def test_token(
) )
assert r['statusCode'] == HTTPStatus.OK assert r['statusCode'] == HTTPStatus.OK
# print(r['body'])
r = dynamodb_persistence_layer.query(
key_cond_expr='#pk = :pk',
expr_attr_name={
'#pk': 'id',
},
expr_attr_values={
':pk': 'OAUTH2#TOKEN',
},
)
assert len(r['items']) == 3