This commit is contained in:
2025-09-26 14:45:27 -03:00
parent eeeccaaaa8
commit 1b6e4b7b5e
13 changed files with 131 additions and 212 deletions

View File

@@ -7,6 +7,6 @@ OAUTH2_SCOPES_SUPPORTED: str = os.getenv('OAUTH2_SCOPES_SUPPORTED', '')
JWT_SECRET: str = os.environ.get('JWT_SECRET') # type: ignore
JWT_ALGORITHM = 'HS256'
JWT_EXP_SECONDS = 900 # 15 minutes
# JWT_EXP_SECONDS = 900 # 15 minutes
OAUTH2_REFRESH_TOKEN_EXPIRES_IN = 30 * 86400 # 30 days
SESSION_EXPIRES_IN = 86400 * 30 # 30 days

View File

@@ -17,7 +17,6 @@ from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN
from .client import OAuth2Client
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
logger = Logger(__name__)
@@ -121,7 +120,6 @@ class AuthorizationServer(oauth2.AuthorizationServer):
self,
client_id: str,
):
"""Query OAuth client by client_id."""
client = self._persistence_layer.collection.get_item(
KeyPair(
pk='OAUTH2',

View File

@@ -1,4 +1,5 @@
import time
from dataclasses import dataclass
from authlib.oauth2.rfc6749 import (
AuthorizationCodeMixin,
@@ -8,6 +9,17 @@ from authlib.oauth2.rfc6749 import (
from layercake.dateutils import fromisoformat, now
@dataclass(frozen=True)
class User:
id: str
name: str
email: str
email_verified: bool = False
def get_user_id(self):
return self.id
class OAuth2AuthorizationCode(AuthorizationCodeMixin):
def __init__(
self,
@@ -75,8 +87,8 @@ class OAuth2Token(TokenMixin):
self.access_token = access_token
self.refresh_token = refresh_token
def get_user(self) -> dict:
return self.user
def get_user(self) -> User:
return User(**self.user)
def check_client(self, client: ClientMixin) -> bool:
return self.client_id == client.get_client_id()

View File

@@ -1,9 +1,6 @@
import time
from dataclasses import dataclass
from authlib.common.security import generate_token
from authlib.common.urls import add_params_to_uri
from authlib.jose import jwt
from authlib.jose import JsonWebKey
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc6750 import BearerTokenGenerator
@@ -31,22 +28,14 @@ from integrations.apigateway_oauth2.authorization_server import (
from integrations.apigateway_oauth2.tokens import (
OAuth2AuthorizationCode,
OAuth2Token,
User,
)
from util import read_file_path
logger = Logger(__name__)
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
@dataclass(frozen=True)
class User:
id: str
name: str
email: str
email_verified: bool = False
def get_user_id(self):
return self.id
private_key = read_file_path('private.pem')
private_jwk = JsonWebKey.import_key(private_key)
class OpenIDCode(OpenIDCode_):
@@ -62,11 +51,10 @@ class OpenIDCode(OpenIDCode_):
def get_jwt_config(self, grant):
return {
'key': read_file_path('private.pem'),
'key': private_key,
'alg': 'RS256',
'kid': 'test',
'iss': ISSUER,
'exp': 3600 * 2,
'exp': 3600,
}
def generate_user_info(self, user: User, scope: str) -> UserInfo:
@@ -234,7 +222,7 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
)
transact.delete(
key=KeyPair(
pk=user.get('id'),
pk=user.get_user_id(),
sk=f'SESSION#REFRESH_TOKEN#{token}',
)
)
@@ -320,104 +308,9 @@ class IssuerParameter(rfc9207.IssuerParameter):
return ISSUER
class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
def get_jwks(self):
return read_file_path('private.pem')
def access_token_generator(self, client, grant_type, user, scope):
now = int(time.time())
expires_in = now + self._get_expires_in(client, grant_type)
token_data = {
'iss': self.issuer,
'exp': expires_in,
'client_id': client.get_client_id(),
'iat': now,
'jti': self.get_jti(client, grant_type, user, scope),
'scope': scope,
}
# In cases of access tokens obtained through grants where a resource owner is
# involved, such as the authorization code grant, the value of 'sub' SHOULD
# correspond to the subject identifier of the resource owner.
if user:
token_data['sub'] = user.get_user_id()
# In cases of access tokens obtained through grants where no resource owner is
# involved, such as the client credentials grant, the value of 'sub' SHOULD
# correspond to an identifier the authorization server uses to indicate the
# client application.
else:
token_data['sub'] = client.get_client_id()
# If the request includes a 'resource' parameter (as defined in [RFC8707]), the
# resulting JWT access token 'aud' claim SHOULD have the same value as the
# 'resource' parameter in the request.
# TODO: Implement this with RFC8707
if False: # pragma: no cover
...
# If the request does not include a 'resource' parameter, the authorization
# server MUST use a default resource indicator in the 'aud' claim. If a 'scope'
# parameter is present in the request, the authorization server SHOULD use it to
# infer the value of the default resource indicator to be used in the 'aud'
# claim. The mechanism through which scopes are associated with default resource
# indicator values is outside the scope of this specification.
else:
token_data['aud'] = self.get_audiences(client, user, scope)
# If the values in the 'scope' parameter refer to different default resource
# indicator values, the authorization server SHOULD reject the request with
# 'invalid_scope' as described in Section 4.1.2.1 of [RFC6749].
# TODO: Implement this with RFC8707
if auth_time := self.get_auth_time(user):
token_data['auth_time'] = auth_time
# The meaning and processing of acr Claim Values is out of scope for this
# specification.
if acr := self.get_acr(user):
token_data['acr'] = acr
# The definition of particular values to be used in the amr Claim is beyond the
# scope of this specification.
if amr := self.get_amr(user):
token_data['amr'] = amr
# Authorization servers MAY return arbitrary attributes not defined in any
# existing specification, as long as the corresponding claim names are collision
# resistant or the access tokens are meant to be used only within a private
# subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details.
token_data.update(self.get_extra_claims(client, grant_type, user, scope))
# This specification registers the 'application/at+jwt' media type, which can
# be used to indicate that the content is a JWT access token. JWT access tokens
# MUST include this media type in the 'typ' header parameter to explicitly
# declare that the JWT represents an access token complying with this profile.
# Per the definition of 'typ' in Section 4.1.9 of [RFC7515], it is RECOMMENDED
# that the 'application/' prefix be omitted. Therefore, the 'typ' value used
# SHOULD be 'at+jwt'.
header = {'alg': self.alg, 'typ': 'at+jwt', 'kid': 'k1'}
access_token = jwt.encode(
header,
token_data,
key=self.get_jwks(),
check=False,
)
return access_token.decode()
GRANT_TYPES_EXPIRES_IN = {
'refresh_token': 600,
'authorization_code': 60 * 10, # 10 minutes
'refresh_token': 86_400 * 7, # 7 days
}
@@ -434,6 +327,16 @@ def create_token_generator(length: int = 42):
return token_generator
class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
def get_jwks(self) -> dict[str, list]: # type: ignore
"""Return the JWKs that will be used to sign the JWT access token."""
return {
'keys': [
private_jwk.as_dict(is_private=True),
]
}
server = AuthorizationServer(persistence_layer=dyn)
server.register_grant(
AuthorizationCodeGrant,

View File

@@ -4,14 +4,13 @@ from aws_lambda_powertools.event_handler.api_gateway import Router
from util import read_file_path
router = Router()
public_jwk = JsonWebKey.import_key(read_file_path('public.pem'), {'kty': 'RSA'})
public_jwk = JsonWebKey.import_key(read_file_path('public.pem'))
@router.get('/.well-known/jwks.json')
def jwks():
key = public_jwk.as_dict()
key['use'] = 'sig'
key['kid'] = 'k1'
return {'keys': [key]}
return {
'keys': [
public_jwk.as_dict(),
]
}

View File

@@ -19,8 +19,8 @@ from config import (
ISSUER,
JWT_ALGORITHM,
JWT_SECRET,
OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
OAUTH2_TABLE,
SESSION_EXPIRES_IN,
)
router = Router()
@@ -46,7 +46,7 @@ def session(
http_only=True,
secure=True,
same_site=None,
max_age=OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
max_age=SESSION_EXPIRES_IN,
)
],
)
@@ -80,7 +80,7 @@ def _get_user(username: str) -> tuple[str, str]:
def new_session(sub: str) -> str:
session_id = str(uuid4())
now_ = now()
exp = ttl(start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN)
exp = ttl(start_dt=now_, seconds=SESSION_EXPIRES_IN)
token = jwt.encode(
{
'sid': session_id,

View File

@@ -3,10 +3,10 @@ import os
ROOT = os.path.abspath(os.path.dirname(__file__))
def get_file_path(name):
def get_file_path(name: str) -> str:
return os.path.join(ROOT, name)
def read_file_path(name):
def read_file_path(name: str) -> str:
with open(get_file_path(name)) as f:
return f.read()