diff --git a/id.saladeaula.digital/app/config.py b/id.saladeaula.digital/app/config.py index 0e63d61..0c8206a 100644 --- a/id.saladeaula.digital/app/config.py +++ b/id.saladeaula.digital/app/config.py @@ -9,5 +9,4 @@ JWT_SECRET: str = os.environ.get('JWT_SECRET') # type: ignore JWT_ALGORITHM = 'HS256' JWT_EXP_SECONDS = 900 # 15 minutes -ACCESS_TOKEN_EXP_SECONDS = 3600 # 1 hour -REFRESH_TOKEN_EXP_SECONDS = 14 * 86400 # 14 days +OAUTH2_REFRESH_TOKEN_EXPIRES_IN = 30 * 86400 # 30 days diff --git a/id.saladeaula.digital/app/integrations/apigateway_oauth2/authorization_server.py b/id.saladeaula.digital/app/integrations/apigateway_oauth2/authorization_server.py index 451be90..ff9a022 100644 --- a/id.saladeaula.digital/app/integrations/apigateway_oauth2/authorization_server.py +++ b/id.saladeaula.digital/app/integrations/apigateway_oauth2/authorization_server.py @@ -12,7 +12,7 @@ from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import from layercake.dateutils import now, ttl from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair -from config import ACCESS_TOKEN_EXP_SECONDS, REFRESH_TOKEN_EXP_SECONDS +from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN from .client import OAuth2Client from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request @@ -21,6 +21,12 @@ DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY') OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED') +GRANT_TYPES_EXPIRES_IN = { + 'refresh_token': 900, + 'authorization_code': 900, +} + + class AuthorizationServer(oauth2.AuthorizationServer): def __init__( self, @@ -35,11 +41,17 @@ class AuthorizationServer(oauth2.AuthorizationServer): ) ) + def expires_in(client, grant_type: str) -> int: + return GRANT_TYPES_EXPIRES_IN.get( + grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN + ) + self.register_token_generator( 'default', BearerTokenGenerator( access_token_generator=create_token_generator(42), refresh_token_generator=create_token_generator(48), + expires_generator=expires_in, ), ) @@ -60,6 +72,7 @@ class AuthorizationServer(oauth2.AuthorizationServer): refresh_token = token.get('refresh_token') token_type = token['token_type'] scope = token['scope'] + expires_in = int(token['expires_in']) issued_at = int(now_.timestamp()) with self._persistence_layer.transact_writer() as transact: @@ -72,9 +85,9 @@ class AuthorizationServer(oauth2.AuthorizationServer): 'refresh_token': refresh_token, 'scope': scope, 'user': request.user, - 'expires_in': ACCESS_TOKEN_EXP_SECONDS, + 'expires_in': expires_in, 'issued_at': issued_at, - 'ttl': ttl(start_dt=now_, seconds=ACCESS_TOKEN_EXP_SECONDS), + 'ttl': ttl(start_dt=now_, seconds=expires_in), }, ) @@ -88,9 +101,11 @@ class AuthorizationServer(oauth2.AuthorizationServer): 'access_token': access_token, 'scope': scope, 'user': request.user, - 'expires_in': REFRESH_TOKEN_EXP_SECONDS, + 'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN, 'issued_at': issued_at, - 'ttl': ttl(start_dt=now_, seconds=REFRESH_TOKEN_EXP_SECONDS), + 'ttl': ttl( + start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN + ), }, ) @@ -117,23 +132,35 @@ class AuthorizationServer(oauth2.AuthorizationServer): ) def create_oauth2_request( - self, request: APIGatewayProxyEventV2 + self, + request: APIGatewayProxyEventV2, ) -> APIGatewayOAuth2Request: return APIGatewayOAuth2Request(request) def create_json_request( - self, request: APIGatewayProxyEventV2 + self, + request: APIGatewayProxyEventV2, ) -> APIGatewayJsonRequest: return APIGatewayJsonRequest(request) - def handle_response(self, status: int, body, headers): + def handle_response( + self, + status: int, + body, + headers, + ): return Response( status_code=status, body=body, headers=headers, ) - def send_signal(self, name: str, *args, **kwargs) -> None: + def send_signal( + self, + name: str, + *args: object, + **kwargs: object, + ) -> None: # after_authenticate_client # when client is authenticated diff --git a/id.saladeaula.digital/app/integrations/apigateway_oauth2/client.py b/id.saladeaula.digital/app/integrations/apigateway_oauth2/client.py index 0480fd8..dec62f9 100644 --- a/id.saladeaula.digital/app/integrations/apigateway_oauth2/client.py +++ b/id.saladeaula.digital/app/integrations/apigateway_oauth2/client.py @@ -1,4 +1,5 @@ import secrets +from collections.abc import Collection from authlib.oauth2.rfc6749 import ( ClientMixin, @@ -29,7 +30,7 @@ class OAuth2Client(ClientMixin): def get_client_id(self): return self.client_id - def get_allowed_scope(self, scope) -> str: + def get_allowed_scope(self, scope: Collection[str]) -> str: if not scope: return '' @@ -37,24 +38,26 @@ class OAuth2Client(ClientMixin): scopes = scope_to_list(scope) return list_to_scope([s for s in scopes if s in allowed]) - def get_default_redirect_uri(self) -> str: # type: ignore + def get_default_redirect_uri(self) -> str: if self.redirect_uris: return self.redirect_uris[0] - def check_response_type(self, response_type): + raise ValueError('Missing redirect_uris') + + def check_response_type(self, response_type: str) -> bool: return response_type in self.response_types - def check_redirect_uri(self, redirect_uri): + def check_redirect_uri(self, redirect_uri: str) -> bool: return redirect_uri in self.redirect_uris - def check_endpoint_auth_method(self, method, endpoint): + def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool: if endpoint == 'token': return self.token_endpoint_auth_method == method return True - def check_grant_type(self, grant_type): + def check_grant_type(self, grant_type: str) -> bool: return grant_type in self.grant_types - def check_client_secret(self, client_secret): + def check_client_secret(self, client_secret: str) -> bool: return secrets.compare_digest(self.client_secret, client_secret) diff --git a/id.saladeaula.digital/app/integrations/apigateway_oauth2/tokens.py b/id.saladeaula.digital/app/integrations/apigateway_oauth2/tokens.py index 2698a15..8bc6cfb 100644 --- a/id.saladeaula.digital/app/integrations/apigateway_oauth2/tokens.py +++ b/id.saladeaula.digital/app/integrations/apigateway_oauth2/tokens.py @@ -5,7 +5,7 @@ from authlib.oauth2.rfc6749 import ( ClientMixin, TokenMixin, ) -from layercake.dateutils import fromisoformat +from layercake.dateutils import fromisoformat, now class OAuth2AuthorizationCode(AuthorizationCodeMixin): @@ -32,25 +32,26 @@ class OAuth2AuthorizationCode(AuthorizationCodeMixin): self.code_challenge_method = code_challenge_method self.nonce = nonce - auth_time = fromisoformat(kwargs.get('created_at', '')) or now() + created_at = kwargs.get('created_at', '') + auth_time = fromisoformat(created_at) or now() self.auth_time = int(auth_time.timestamp()) - def get_redirect_uri(self): + def get_redirect_uri(self) -> str: return self.redirect_uri - def get_scope(self): + def get_scope(self) -> str: return self.scope - def get_nonce(self): + def get_nonce(self) -> str | None: return self.nonce - def get_auth_time(self): + def get_auth_time(self) -> int: return self.auth_time def get_acr(self): return '0' - def get_amr(self): + def get_amr(self) -> list: return [] @@ -77,7 +78,7 @@ class OAuth2Token(TokenMixin): def get_user(self) -> dict: return self.user - def check_client(self, client: ClientMixin): + def check_client(self, client: ClientMixin) -> bool: return self.client_id == client.get_client_id() def get_scope(self) -> str: diff --git a/id.saladeaula.digital/app/jose_.py b/id.saladeaula.digital/app/jose_.py index 4ab0d77..e7d15f9 100644 --- a/id.saladeaula.digital/app/jose_.py +++ b/id.saladeaula.digital/app/jose_.py @@ -1,7 +1,7 @@ from datetime import timedelta from aws_lambda_powertools.event_handler.exceptions import ForbiddenError -from jose import ExpiredSignatureError, JWTError, jwt +from jose import jwt from layercake.dateutils import now from config import ( @@ -9,7 +9,7 @@ from config import ( JWT_ALGORITHM, JWT_EXP_SECONDS, JWT_SECRET, - REFRESH_TOKEN_EXP_SECONDS, + OAUTH2_REFRESH_TOKEN_EXPIRES_IN, ) @@ -27,10 +27,11 @@ def generate_jwt(user_id: str, email: str) -> str: def generate_refresh_token(user_id: str) -> str: now_ = now() + exp = now_ + timedelta(seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN) payload = { 'sub': user_id, 'iat': int(now_.timestamp()), - 'exp': int((now_ + timedelta(seconds=REFRESH_TOKEN_EXP_SECONDS)).timestamp()), + 'exp': int(exp.timestamp()), 'iss': ISSUER, 'typ': 'refresh', } @@ -38,19 +39,14 @@ def generate_refresh_token(user_id: str) -> str: def verify_jwt(token: str) -> dict: - try: - payload = jwt.decode( - token, - JWT_SECRET, - algorithms=[JWT_ALGORITHM], - issuer=ISSUER, - options={ - 'require': ['exp', 'sub', 'iss'], - 'leeway': 60, - }, - ) - return payload - except ExpiredSignatureError: - raise ForbiddenError('Token expired') - except JWTError: - raise ForbiddenError('Invalid token') + payload = jwt.decode( + token, + JWT_SECRET, + algorithms=[JWT_ALGORITHM], + issuer=ISSUER, + options={ + 'require': ['exp', 'sub', 'iss'], + 'leeway': 60, + }, + ) + return payload diff --git a/id.saladeaula.digital/app/oauth2.py b/id.saladeaula.digital/app/oauth2.py index ec0d6fa..96347f4 100644 --- a/id.saladeaula.digital/app/oauth2.py +++ b/id.saladeaula.digital/app/oauth2.py @@ -5,7 +5,7 @@ from authlib.oidc.core import OpenIDCode as OpenIDCode_ from authlib.oidc.core import UserInfo from layercake.dateutils import now, ttl from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair -from layercake.funcs import pick +from layercake.funcs import omit, pick from boto3clients import dynamodb_client from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE @@ -124,10 +124,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): authorization_code: OAuth2AuthorizationCode, ) -> None: oauth2_layer.delete_item( - KeyPair( - pk='OAUTH2#CODE', - sk=f'CODE#{authorization_code.code}', - ), + KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{authorization_code.code}'), ) def authenticate_user( @@ -135,17 +132,18 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): authorization_code: OAuth2AuthorizationCode, ) -> dict: user = oauth2_layer.get_item( - KeyPair( - pk=authorization_code.user_id, - sk='0', - ), + KeyPair(pk=authorization_code.user_id, sk='0'), ) return pick(('id', 'name', 'email', 'email_verified'), user) class RefreshTokenGrant(grants.RefreshTokenGrant): - TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none'] INCLUDE_NEW_REFRESH_TOKEN = True + TOKEN_ENDPOINT_AUTH_METHODS = [ + 'client_secret_basic', + 'client_secret_post', + 'none', + ] def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin: token = oauth2_layer.get_item( @@ -156,26 +154,19 @@ class RefreshTokenGrant(grants.RefreshTokenGrant): ) return OAuth2Token( - client_id=token['client_id'], - scope=token['scope'], expires_in=int(token['expires_in']), issued_at=int(token['issued_at']), - user=token['user'], refresh_token=refresh_token, + **omit(('expires_in', 'issued_at'), token), ) def authenticate_user(self, refresh_token: TokenMixin): return refresh_token.get_user() def revoke_old_credential(self, refresh_token: TokenMixin) -> None: - refresh_token_ = getattr(refresh_token, 'refresh_token') - - if refresh_token_: + if token := getattr(refresh_token, 'refresh_token', None): oauth2_layer.delete_item( - KeyPair( - pk='OAUTH2#TOKEN', - sk=f'REFRESH_TOKEN#{refresh_token_}', - ) + KeyPair(pk='OAUTH2#TOKEN', sk=f'REFRESH_TOKEN#{token}') ) diff --git a/id.saladeaula.digital/app/routes/authorize.py b/id.saladeaula.digital/app/routes/authorize.py index da8a20e..64dcacd 100644 --- a/id.saladeaula.digital/app/routes/authorize.py +++ b/id.saladeaula.digital/app/routes/authorize.py @@ -7,6 +7,7 @@ from authlib.oauth2.rfc6749 import errors from aws_lambda_powertools import Logger from aws_lambda_powertools.event_handler import Response from aws_lambda_powertools.event_handler.api_gateway import Router +from jose.exceptions import JWTError from jose_ import verify_jwt from oauth2 import server @@ -18,33 +19,21 @@ logger = Logger(__name__) @router.get('/authorize') def authorize(): current_event = router.current_event - cookies = _parse_cookies(current_event.get('cookies', [])) # type: ignore + cookies = _parse_cookies(current_event.get('cookies', [])) id_token = cookies.get('id_token') - continue_url = quote( - urlunparse( - ParseResult( - scheme='', - netloc='', - path=current_event.path, - params='', - query=urlencode(current_event.query_string_parameters), - fragment='', - ) - ), - safe='', + continue_url = _build_continue_url( + current_event.path, + current_event.query_string_parameters, ) + login_url = f'/login?continue={continue_url}' - if not id_token: - return Response( - status_code=HTTPStatus.FOUND, - headers={'Location': login_url}, - ) - try: + if not id_token: + raise ValueError('Missing id_token') + user = verify_jwt(id_token) - except Exception as exc: - logger.exception(exc) + except (ValueError, JWTError): return Response( status_code=HTTPStatus.FOUND, headers={'Location': login_url}, @@ -82,3 +71,11 @@ def _parse_cookies(cookies: list[str] | None) -> dict[str, str]: parsed_cookies.update({k: morsel.value for k, morsel in c.items()}) return parsed_cookies + + +def _build_continue_url( + path: str, + query_string_parameters: dict, +) -> str: + query = urlencode(query_string_parameters) + return quote(urlunparse(ParseResult('', '', path, '', query, '')), safe='') diff --git a/id.saladeaula.digital/tests/routes/test_token.py b/id.saladeaula.digital/tests/routes/test_token.py index cb27ccf..c74b627 100644 --- a/id.saladeaula.digital/tests/routes/test_token.py +++ b/id.saladeaula.digital/tests/routes/test_token.py @@ -1,5 +1,4 @@ import json -import pprint from http import HTTPMethod, HTTPStatus from urllib.parse import urlencode @@ -36,10 +35,11 @@ def test_token( ), lambda_context, ) - assert r['statusCode'] == HTTPStatus.OK - data = json.loads(r['body']) + auth_token = json.loads(r['body']) + + assert r['statusCode'] == HTTPStatus.OK + assert auth_token['expires_in'] == 900 - # print(data) r = dynamodb_persistence_layer.query( key_cond_expr='#pk = :pk', expr_attr_name={ @@ -49,7 +49,7 @@ def test_token( ':pk': 'OAUTH2#TOKEN', }, ) - # pprint.pp(r['items']) + assert len(r['items']) == 2 r = app.lambda_handler( http_api_proxy( @@ -61,7 +61,7 @@ def test_token( body=urlencode( { 'grant_type': 'refresh_token', - 'refresh_token': data['refresh_token'], + 'refresh_token': auth_token['refresh_token'], 'client_id': client_id, } ), @@ -70,4 +70,14 @@ def test_token( ) assert r['statusCode'] == HTTPStatus.OK - # print(r['body']) + + r = dynamodb_persistence_layer.query( + key_cond_expr='#pk = :pk', + expr_attr_name={ + '#pk': 'id', + }, + expr_attr_values={ + ':pk': 'OAUTH2#TOKEN', + }, + ) + assert len(r['items']) == 3