This commit is contained in:
2025-09-25 23:18:02 -03:00
parent 187a064687
commit eeeccaaaa8
6 changed files with 223 additions and 86 deletions

View File

@@ -1,9 +1,8 @@
import os
from dataclasses import asdict
import authlib.oauth2 as oauth2
import authlib.oauth2.rfc6749.requests as requests
from authlib.common.security import generate_token
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler.api_gateway import Response
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
@@ -20,20 +19,10 @@ from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
GRANT_TYPES_EXPIRES_IN = {
'refresh_token': 600,
'authorization_code': 600,
}
logger = Logger(__name__)
def expires_in(client, grant_type: str) -> int:
return GRANT_TYPES_EXPIRES_IN.get(
grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
)
class AuthorizationServer(oauth2.AuthorizationServer):
def __init__(
self,
@@ -47,15 +36,6 @@ class AuthorizationServer(oauth2.AuthorizationServer):
scopes_supported=set(OAUTH2_SCOPES_SUPPORTED.split()),
)
self.register_token_generator(
'default',
BearerTokenGenerator(
access_token_generator=create_token_generator(42),
refresh_token_generator=create_token_generator(48),
expires_generator=expires_in,
),
)
def save_token(
self,
token: dict,
@@ -74,7 +54,8 @@ class AuthorizationServer(oauth2.AuthorizationServer):
if request.client
else request.payload.client_id
)
user_id = request.user.get('id')
user_id = request.user.id
user = asdict(request.user)
access_token = token['access_token']
refresh_token = token.get('refresh_token')
token_type = token['token_type']
@@ -102,7 +83,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
'token_type': token_type,
'refresh_token': refresh_token,
'scope': scope,
'user': request.user,
'user': user,
'expires_in': expires_in,
'issued_at': issued_at,
'ttl': access_token_ttl,
@@ -127,7 +108,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
'token_type': token_type,
'access_token': access_token,
'scope': scope,
'user': request.user,
'user': user,
'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
'issued_at': issued_at,
'ttl': refresh_token_ttl,
@@ -200,10 +181,3 @@ class AuthorizationServer(oauth2.AuthorizationServer):
class ClientNotFoundError(NotFoundError):
def __init__(self, *_):
super().__init__('Client not found')
def create_token_generator(length: int = 42):
def token_generator(*args, **kwargs):
return generate_token(length)
return token_generator

View File

@@ -1,7 +1,14 @@
import time
from dataclasses import dataclass
from authlib.common.security import generate_token
from authlib.common.urls import add_params_to_uri
from authlib.jose import jwt
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
from authlib.oidc.core import OpenIDCode as OpenIDCode_
from authlib.oidc.core import UserInfo
from aws_lambda_powertools import Logger
@@ -17,7 +24,7 @@ from layercake.dynamodb import (
from layercake.funcs import omit, pick
from boto3clients import dynamodb_client
from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE
from config import ISSUER, OAUTH2_TABLE
from integrations.apigateway_oauth2.authorization_server import (
AuthorizationServer,
)
@@ -25,11 +32,23 @@ from integrations.apigateway_oauth2.tokens import (
OAuth2AuthorizationCode,
OAuth2Token,
)
from util import read_file_path
logger = Logger(__name__)
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
@dataclass(frozen=True)
class User:
id: str
name: str
email: str
email_verified: bool = False
def get_user_id(self):
return self.id
class OpenIDCode(OpenIDCode_):
def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
if not request.payload:
@@ -43,18 +62,19 @@ class OpenIDCode(OpenIDCode_):
def get_jwt_config(self, grant):
return {
'key': 'secret-key',
'alg': JWT_ALGORITHM,
'key': read_file_path('private.pem'),
'alg': 'RS256',
'kid': 'test',
'iss': ISSUER,
'exp': 3600,
'exp': 3600 * 2,
}
def generate_user_info(self, user: dict, scope: str) -> UserInfo:
def generate_user_info(self, user: User, scope: str) -> UserInfo:
return UserInfo(
sub=user['id'],
name=user['name'],
email=user['email'],
email_verified=user.get('email_verified', False),
sub=user.id,
name=user.name,
email=user.email,
email_verified=user.email_verified,
).filter(scope)
@@ -147,7 +167,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
def authenticate_user(
self,
authorization_code: OAuth2AuthorizationCode,
) -> dict:
) -> User:
"""Authenticate the user related to this authorization_code."""
user = dyn.get_item(
KeyPair(
@@ -155,7 +175,8 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
sk='0',
),
)
return pick(('id', 'name', 'email', 'email_verified'), user)
return User(**pick(('id', 'name', 'email', 'email_verified'), user))
class RefreshTokenNotFoundError(NotFoundError):
@@ -299,6 +320,120 @@ class IssuerParameter(rfc9207.IssuerParameter):
return ISSUER
class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
def get_jwks(self):
return read_file_path('private.pem')
def access_token_generator(self, client, grant_type, user, scope):
now = int(time.time())
expires_in = now + self._get_expires_in(client, grant_type)
token_data = {
'iss': self.issuer,
'exp': expires_in,
'client_id': client.get_client_id(),
'iat': now,
'jti': self.get_jti(client, grant_type, user, scope),
'scope': scope,
}
# In cases of access tokens obtained through grants where a resource owner is
# involved, such as the authorization code grant, the value of 'sub' SHOULD
# correspond to the subject identifier of the resource owner.
if user:
token_data['sub'] = user.get_user_id()
# In cases of access tokens obtained through grants where no resource owner is
# involved, such as the client credentials grant, the value of 'sub' SHOULD
# correspond to an identifier the authorization server uses to indicate the
# client application.
else:
token_data['sub'] = client.get_client_id()
# If the request includes a 'resource' parameter (as defined in [RFC8707]), the
# resulting JWT access token 'aud' claim SHOULD have the same value as the
# 'resource' parameter in the request.
# TODO: Implement this with RFC8707
if False: # pragma: no cover
...
# If the request does not include a 'resource' parameter, the authorization
# server MUST use a default resource indicator in the 'aud' claim. If a 'scope'
# parameter is present in the request, the authorization server SHOULD use it to
# infer the value of the default resource indicator to be used in the 'aud'
# claim. The mechanism through which scopes are associated with default resource
# indicator values is outside the scope of this specification.
else:
token_data['aud'] = self.get_audiences(client, user, scope)
# If the values in the 'scope' parameter refer to different default resource
# indicator values, the authorization server SHOULD reject the request with
# 'invalid_scope' as described in Section 4.1.2.1 of [RFC6749].
# TODO: Implement this with RFC8707
if auth_time := self.get_auth_time(user):
token_data['auth_time'] = auth_time
# The meaning and processing of acr Claim Values is out of scope for this
# specification.
if acr := self.get_acr(user):
token_data['acr'] = acr
# The definition of particular values to be used in the amr Claim is beyond the
# scope of this specification.
if amr := self.get_amr(user):
token_data['amr'] = amr
# Authorization servers MAY return arbitrary attributes not defined in any
# existing specification, as long as the corresponding claim names are collision
# resistant or the access tokens are meant to be used only within a private
# subsystem. Please refer to Sections 4.2 and 4.3 of [RFC7519] for details.
token_data.update(self.get_extra_claims(client, grant_type, user, scope))
# This specification registers the 'application/at+jwt' media type, which can
# be used to indicate that the content is a JWT access token. JWT access tokens
# MUST include this media type in the 'typ' header parameter to explicitly
# declare that the JWT represents an access token complying with this profile.
# Per the definition of 'typ' in Section 4.1.9 of [RFC7515], it is RECOMMENDED
# that the 'application/' prefix be omitted. Therefore, the 'typ' value used
# SHOULD be 'at+jwt'.
header = {'alg': self.alg, 'typ': 'at+jwt', 'kid': 'k1'}
access_token = jwt.encode(
header,
token_data,
key=self.get_jwks(),
check=False,
)
return access_token.decode()
GRANT_TYPES_EXPIRES_IN = {
'refresh_token': 600,
}
def expires_in(client, grant_type: str) -> int:
return GRANT_TYPES_EXPIRES_IN.get(
grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
)
def create_token_generator(length: int = 42):
def token_generator(*args, **kwargs):
return generate_token(length)
return token_generator
server = AuthorizationServer(persistence_layer=dyn)
server.register_grant(
AuthorizationCodeGrant,
@@ -307,6 +442,14 @@ server.register_grant(
CodeChallenge(required=True),
],
)
server.register_token_generator(
'default',
JWTBearerTokenGenerator(
issuer=ISSUER,
refresh_token_generator=create_token_generator(48),
expires_generator=expires_in,
),
)
server.register_grant(RefreshTokenGrant)
server.register_endpoint(RevocationEndpoint)
server.register_extension(IssuerParameter())

View File

@@ -1,8 +1,17 @@
from authlib.jose import JsonWebKey
from aws_lambda_powertools.event_handler.api_gateway import Router
from util import read_file_path
router = Router()
public_jwk = JsonWebKey.import_key(read_file_path('public.pem'), {'kty': 'RSA'})
@router.get('/.well-known/jwks.json')
def jwks():
return {}
key = public_jwk.as_dict()
key['use'] = 'sig'
key['kid'] = 'k1'
return {'keys': [key]}

View File

@@ -1,6 +1,6 @@
from aws_lambda_powertools.event_handler.api_gateway import Router
from config import ISSUER, JWT_ALGORITHM, OAUTH2_SCOPES_SUPPORTED
from config import ISSUER, OAUTH2_SCOPES_SUPPORTED
router = Router()
@@ -13,12 +13,12 @@ def openid_configuration():
'token_endpoint': f'{ISSUER}/token',
'revocation_endpoint': f'{ISSUER}/revoke',
'userinfo_endpoint': f'{ISSUER}/userinfo',
'jwks_uri': f'{ISSUER}/jwks.json',
'jwks_uri': f'{ISSUER}/.well-known/jwks.json',
'scopes_supported': OAUTH2_SCOPES_SUPPORTED.split(),
'response_types_supported': ['code'],
'grant_types_supported': ['authorization_code', 'refresh_token'],
'subject_types_supported': ['public'],
'id_token_signing_alg_values_supported': [JWT_ALGORITHM],
'id_token_signing_alg_values_supported': ['RS256'],
'token_endpoint_auth_methods_supported': [
'client_secret_basic',
'client_secret_post',

View File

@@ -0,0 +1,12 @@
import os
ROOT = os.path.abspath(os.path.dirname(__file__))
def get_file_path(name):
return os.path.join(ROOT, name)
def read_file_path(name):
with open(get_file_path(name)) as f:
return f.read()

View File

@@ -36,49 +36,48 @@ def test_token(
lambda_context,
)
# print(r)
assert r['statusCode'] == HTTPStatus.OK
r = json.loads(r['body'])
assert r['expires_in'] == 600
print(r)
# assert r['expires_in'] == 600
tokens = dynamodb_persistence_layer.query(
key_cond_expr='#pk = :pk',
expr_attr_name={
'#pk': 'id',
},
expr_attr_values={
':pk': 'OAUTH2#TOKEN',
},
)
assert len(tokens['items']) == 2
# tokens = dynamodb_persistence_layer.query(
# key_cond_expr='#pk = :pk',
# expr_attr_name={
# '#pk': 'id',
# },
# expr_attr_values={
# ':pk': 'OAUTH2#TOKEN',
# },
# )
# assert len(tokens['items']) == 2
r = app.lambda_handler(
http_api_proxy(
raw_path='/token',
method=HTTPMethod.POST,
headers={
'Content-Type': 'application/x-www-form-urlencoded',
},
body=urlencode(
{
'grant_type': 'refresh_token',
'refresh_token': r['refresh_token'],
'client_id': client_id,
}
),
),
lambda_context,
)
# r = app.lambda_handler(
# http_api_proxy(
# raw_path='/token',
# method=HTTPMethod.POST,
# headers={
# 'Content-Type': 'application/x-www-form-urlencoded',
# },
# body=urlencode(
# {
# 'grant_type': 'refresh_token',
# 'refresh_token': r['refresh_token'],
# 'client_id': client_id,
# }
# ),
# ),
# lambda_context,
# )
r = dynamodb_persistence_layer.query(
key_cond_expr='#pk = :pk',
expr_attr_name={
'#pk': 'id',
},
expr_attr_values={
':pk': 'OAUTH2#TOKEN',
},
)
assert len(r['items']) == 3
# r = dynamodb_persistence_layer.query(
# key_cond_expr='#pk = :pk',
# expr_attr_name={
# '#pk': 'id',
# },
# expr_attr_values={
# ':pk': 'OAUTH2#TOKEN',
# },
# )
# assert len(r['items']) == 3