wip
This commit is contained in:
@@ -11,7 +11,7 @@ from aws_lambda_powertools.event_handler.exceptions import ServiceError
|
||||
from aws_lambda_powertools.logging import correlation_paths
|
||||
from aws_lambda_powertools.utilities.typing import LambdaContext
|
||||
|
||||
from middlewares import AuthorizerMiddleware
|
||||
from middlewares import AuthorizerMiddleware, TenantMiddleware
|
||||
from routes import courses, enrollments, lookup, me, orders, users, webhooks
|
||||
|
||||
DEBUG = os.getenv('LOG_LEVEL') == 'DEBUG'
|
||||
@@ -19,7 +19,7 @@ DEBUG = os.getenv('LOG_LEVEL') == 'DEBUG'
|
||||
tracer = Tracer()
|
||||
logger = Logger(__name__)
|
||||
app = APIGatewayHttpResolver(enable_validation=True, debug=DEBUG)
|
||||
app.use(middlewares=[AuthorizerMiddleware()])
|
||||
app.use(middlewares=[AuthorizerMiddleware(), TenantMiddleware()])
|
||||
app.include_router(courses.router, prefix='/courses')
|
||||
app.include_router(enrollments.router, prefix='/enrollments')
|
||||
app.include_router(orders.router, prefix='/orders')
|
||||
|
||||
@@ -24,6 +24,7 @@ Example
|
||||
"""
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from aws_lambda_powertools import Logger, Tracer
|
||||
@@ -53,31 +54,46 @@ collect = DynamoDBCollection(user_layer)
|
||||
@tracer.capture_lambda_handler
|
||||
@logger.inject_lambda_context
|
||||
@event_source(data_class=APIGatewayAuthorizerEventV2)
|
||||
def lambda_handler(event: APIGatewayAuthorizerEventV2, context: LambdaContext):
|
||||
def lambda_handler(event: APIGatewayAuthorizerEventV2, context: LambdaContext) -> dict:
|
||||
bearer = _parse_bearer_token(event.headers.get('authorization', ''))
|
||||
|
||||
if not bearer:
|
||||
return APIGatewayAuthorizerResponseV2(authorize=False).asdict()
|
||||
|
||||
kwargs = asdict(_authorizer(bearer))
|
||||
return APIGatewayAuthorizerResponseV2(**kwargs).asdict()
|
||||
attrs = _authorizer(bearer).asdict()
|
||||
return APIGatewayAuthorizerResponseV2(**attrs).asdict()
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
API_KEY = 'API_KEY'
|
||||
USER_TOKEN = 'USER_TOKEN'
|
||||
class AuthFlowType(str, Enum):
|
||||
USER_AUTH = 'USER_AUTH'
|
||||
API_AUTH = 'API_AUTH'
|
||||
|
||||
|
||||
@dataclass
|
||||
class BearerToken:
|
||||
auth_type: TokenType
|
||||
auth_flow_type: AuthFlowType
|
||||
token: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Authorizer:
|
||||
authorize: bool = False
|
||||
context: dict | None = None
|
||||
context: dict[str, Any] | None = None
|
||||
auth_flow_type: AuthFlowType = AuthFlowType.USER_AUTH
|
||||
|
||||
def asdict(self) -> dict:
|
||||
data = asdict(self)
|
||||
auth_flow_type = data.pop('auth_flow_type')
|
||||
|
||||
# If authorization is enabled, add `auth_flow_type` to the context
|
||||
if self.authorize:
|
||||
data['context'].update(auth_flow_type=auth_flow_type)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _get_apikey(token: str) -> dict[str, dict | str]:
|
||||
return collect.get_item(KeyPair('apikey', token))
|
||||
|
||||
|
||||
def _authorizer(bearer: BearerToken) -> Authorizer:
|
||||
@@ -95,22 +111,13 @@ def _authorizer(bearer: BearerToken) -> Authorizer:
|
||||
An Authorizer object with the appropriate authorization status and context.
|
||||
"""
|
||||
try:
|
||||
match bearer.auth_type:
|
||||
case TokenType.USER_TOKEN:
|
||||
user = get_user(bearer.token, idp_client=idp_client)
|
||||
return Authorizer(True, {'user': user})
|
||||
case TokenType.API_KEY:
|
||||
apikey = collect.get_item(KeyPair('apikey', bearer.token))
|
||||
return Authorizer(
|
||||
True,
|
||||
pick(
|
||||
(
|
||||
'user',
|
||||
'tenant',
|
||||
),
|
||||
apikey,
|
||||
),
|
||||
)
|
||||
if bearer.auth_flow_type == AuthFlowType.USER_AUTH:
|
||||
user = get_user(bearer.token, idp_client)
|
||||
return Authorizer(True, {'user': user})
|
||||
|
||||
apikey = _get_apikey(bearer.token)
|
||||
context = pick(('tenant', 'user'), apikey)
|
||||
return Authorizer(True, context, AuthFlowType.API_AUTH)
|
||||
except Exception:
|
||||
return Authorizer()
|
||||
|
||||
@@ -118,14 +125,14 @@ def _authorizer(bearer: BearerToken) -> Authorizer:
|
||||
def _parse_bearer_token(s: str) -> BearerToken | None:
|
||||
"""Parses and identifies a bearer token as either an API key or a user token."""
|
||||
try:
|
||||
_, bearer_token = s.split(' ')
|
||||
_, token = s.split(' ')
|
||||
|
||||
if bearer_token.startswith(APIKEY_PREFIX):
|
||||
if token.startswith(APIKEY_PREFIX):
|
||||
return BearerToken(
|
||||
TokenType.API_KEY,
|
||||
bearer_token.removeprefix(APIKEY_PREFIX),
|
||||
AuthFlowType.API_AUTH,
|
||||
token.removeprefix(APIKEY_PREFIX),
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
return BearerToken(TokenType.USER_TOKEN, bearer_token)
|
||||
return BearerToken(AuthFlowType.USER_AUTH, token)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
class UnauthorizedError(Exception): ...
|
||||
|
||||
|
||||
def get_user(access_token: str, *, idp_client) -> dict | None:
|
||||
def get_user(access_token: str, /, idp_client) -> dict[str, str]:
|
||||
"""Gets the user attributes and metadata for a user."""
|
||||
try:
|
||||
user = idp_client.get_user(AccessToken=access_token)
|
||||
|
||||
@@ -12,15 +12,21 @@ from aws_lambda_powertools.shared.functions import (
|
||||
from layercake.dateutils import now, ttl
|
||||
from layercake.dynamodb import ComposeKey, DynamoDBCollection, KeyPair
|
||||
from layercake.funcs import pick
|
||||
from pydantic import UUID4, BaseModel, Field
|
||||
from pydantic import UUID4, BaseModel, EmailStr, Field
|
||||
|
||||
from auth import AuthFlowType
|
||||
|
||||
LOG_RETENTION_DAYS = 365 * 2 # 2 years
|
||||
|
||||
|
||||
class AuthenticatedUser(BaseModel):
|
||||
id: str = Field(alias='custom:user_id')
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class CognitoUser(User):
|
||||
id: str = Field(alias='custom:user_id')
|
||||
email_verified: bool
|
||||
sub: UUID4
|
||||
|
||||
@@ -33,12 +39,31 @@ class AuthorizerMiddleware(BaseMiddlewareHandler):
|
||||
) -> Response:
|
||||
# Gets the Lambda authorizer associated with the current API Gateway event.
|
||||
# You can check the file `auth.py` for more details.
|
||||
authorizer = app.current_event.request_context.authorizer.get_lambda
|
||||
context = app.current_event.request_context.authorizer.get_lambda
|
||||
auth_flow_type = context.get('auth_flow_type')
|
||||
|
||||
if 'user' in authorizer:
|
||||
user = authorizer['user']
|
||||
app.append_context(authenticated_user=AuthenticatedUser(**user))
|
||||
if not auth_flow_type:
|
||||
return next_middleware(app)
|
||||
|
||||
cls = {
|
||||
AuthFlowType.USER_AUTH: CognitoUser,
|
||||
AuthFlowType.API_AUTH: User,
|
||||
}.get(auth_flow_type)
|
||||
|
||||
if cls:
|
||||
app.append_context(user=cls(**context['user']))
|
||||
|
||||
return next_middleware(app)
|
||||
|
||||
|
||||
class TenantMiddleware(BaseMiddlewareHandler):
|
||||
def handler(
|
||||
self,
|
||||
app: APIGatewayHttpResolver,
|
||||
next_middleware: NextMiddleware,
|
||||
) -> Response:
|
||||
context = app.current_event.request_context.authorizer.get_lambda
|
||||
auth_flow_type = context.get('auth_flow_type')
|
||||
return next_middleware(app)
|
||||
|
||||
|
||||
@@ -78,11 +103,10 @@ class AuditLogMiddleware(BaseMiddlewareHandler):
|
||||
app: APIGatewayHttpResolver,
|
||||
next_middleware: NextMiddleware,
|
||||
) -> Response:
|
||||
collect = self.collect
|
||||
user = app.context.get('user')
|
||||
req_context = app.current_event.request_context
|
||||
ip_addr = req_context.http.source_ip
|
||||
response = next_middleware(app)
|
||||
user = app.context.get('authenticated_user')
|
||||
request_ctx = app.current_event.request_context
|
||||
ip_addr = request_ctx.http.source_ip
|
||||
|
||||
# Successful request
|
||||
if 200 <= response.status_code < 300 and user:
|
||||
@@ -98,7 +122,7 @@ class AuditLogMiddleware(BaseMiddlewareHandler):
|
||||
else None
|
||||
)
|
||||
|
||||
collect.put_item(
|
||||
self.collect.put_item(
|
||||
key=KeyPair(
|
||||
pk=ComposeKey(user.id, prefix='logs'),
|
||||
sk=now_.isoformat(),
|
||||
|
||||
@@ -8,7 +8,7 @@ from layercake.dynamodb import (
|
||||
|
||||
import konviva
|
||||
from boto3clients import dynamodb_client
|
||||
from middlewares import AuthenticatedUser
|
||||
from middlewares import User
|
||||
from settings import USER_TABLE
|
||||
|
||||
router = Router()
|
||||
@@ -21,7 +21,7 @@ LIMIT = 25
|
||||
|
||||
@router.get('/', include_in_schema=False)
|
||||
def me():
|
||||
user: AuthenticatedUser = router.context['authenticated_user']
|
||||
user: User = router.context['user']
|
||||
acls = collect.get_items(
|
||||
KeyPair(user.id, PrefixKey('acls')),
|
||||
limit=LIMIT,
|
||||
@@ -39,7 +39,7 @@ def me():
|
||||
|
||||
@router.get('/konviva', include_in_schema=False)
|
||||
def konviva_():
|
||||
user: AuthenticatedUser = router.context['authenticated_user']
|
||||
user: User = router.context['user']
|
||||
token = konviva.token(user.email)
|
||||
|
||||
return {'redirect_uri': konviva.redirect_uri(token)}
|
||||
|
||||
@@ -59,6 +59,7 @@ class HttpApiProxy:
|
||||
'custom:user_id': '5OxmMjL-ujoR5IMGegQz',
|
||||
'sub': 'c4f30dbd-083e-4b84-aa50-c31afe9b9c01',
|
||||
},
|
||||
'auth_flow_type': 'USER_AUTH',
|
||||
},
|
||||
'jwt': {
|
||||
'claims': {'claim1': 'value1', 'claim2': 'value2'},
|
||||
|
||||
@@ -26,7 +26,8 @@ def test_bearer_jwt(lambda_context: LambdaContext):
|
||||
'sub': '58efed8d-d276-41a8-8502-4ab8b5a6415e',
|
||||
'name': 'pytest',
|
||||
'custom:user_id': '5OxmMjL-ujoR5IMGegQz',
|
||||
}
|
||||
},
|
||||
'auth_flow_type': 'USER_AUTH',
|
||||
},
|
||||
}
|
||||
|
||||
@@ -55,10 +56,11 @@ def test_bearer_apikey(
|
||||
'name': 'Sérgio R Siqueira',
|
||||
'email': 'sergio@somosbeta.com.br',
|
||||
},
|
||||
'auth_flow_type': 'API_AUTH',
|
||||
},
|
||||
}
|
||||
|
||||
# # This data was added from seeds
|
||||
# This data was added from seeds
|
||||
assert app.lambda_handler(
|
||||
{
|
||||
'headers': {
|
||||
@@ -75,11 +77,11 @@ def test_parse_bearer_token_api_key():
|
||||
)
|
||||
|
||||
assert bearer.token == '35433970-6857-4062-bb43-f71683b2f68e' # type: ignore
|
||||
assert bearer.auth_type == 'API_KEY' # type: ignore
|
||||
assert bearer.auth_flow_type == 'API_AUTH' # type: ignore
|
||||
|
||||
|
||||
def test_parse_bearer_token_user_token():
|
||||
bearer = _parse_bearer_token('Bearer d977f5a2-0302-4dd2-87c7-57414264d27a')
|
||||
|
||||
assert bearer.token == 'd977f5a2-0302-4dd2-87c7-57414264d27a' # type: ignore
|
||||
assert bearer.auth_type == 'USER_TOKEN' # type: ignore
|
||||
assert bearer.auth_flow_type == 'USER_AUTH' # type: ignore
|
||||
|
||||
2
http-api/uv.lock
generated
2
http-api/uv.lock
generated
@@ -444,7 +444,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "layercake"
|
||||
version = "0.1.9"
|
||||
version = "0.1.11"
|
||||
source = { directory = "../layercake" }
|
||||
dependencies = [
|
||||
{ name = "aws-lambda-powertools", extra = ["all"] },
|
||||
|
||||
Reference in New Issue
Block a user