add claim
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
AuthorizationCodeMixin,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from functools import reduce
|
||||
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.common.urls import add_params_to_uri
|
||||
from authlib.jose import JsonWebKey
|
||||
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
|
||||
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
|
||||
from authlib.oauth2.rfc6749.hooks import hooked
|
||||
from authlib.oauth2.rfc6749.util import scope_to_list
|
||||
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
||||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||
from authlib.oauth2.rfc9068 import JWTBearerTokenGenerator as JWTBearerTokenGenerator_
|
||||
@@ -46,6 +49,15 @@ GRANT_TYPES_EXPIRES_IN = {
|
||||
}
|
||||
|
||||
|
||||
def get_user_scope(user_id: str) -> set:
|
||||
items = dyn.collection.query(
|
||||
KeyPair(pk=user_id, sk='SCOPE#'),
|
||||
).get('items', [])
|
||||
scope = reduce(lambda acc, cur: acc + scope_to_list(cur['scope']), items, [])
|
||||
|
||||
return OAUTH2_DEFAULT_SCOPES | set(scope)
|
||||
|
||||
|
||||
class OpenIDCode(OpenIDCode_):
|
||||
def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
|
||||
if not request.payload:
|
||||
@@ -182,16 +194,10 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
authorization_code: OAuth2AuthorizationCode,
|
||||
) -> User:
|
||||
"""Authenticate the user related to this authorization_code."""
|
||||
user = dyn.collection.get_items(
|
||||
TransactKey(authorization_code.user_id)
|
||||
+ SortKey('0')
|
||||
+ SortKey(
|
||||
sk='SCOPE',
|
||||
path_spec='scope',
|
||||
rename_key='scope',
|
||||
),
|
||||
user = dyn.collection.get_item(
|
||||
KeyPair(pk=authorization_code.user_id, sk='0'),
|
||||
)
|
||||
scope = set(user.get('scope', [])) | OAUTH2_DEFAULT_SCOPES
|
||||
scope = get_user_scope(authorization_code.user_id)
|
||||
|
||||
return User(
|
||||
**pick(('id', 'name', 'email', 'email_verified'), user),
|
||||
@@ -395,6 +401,13 @@ class JWTBearerTokenGenerator(JWTBearerTokenGenerator_):
|
||||
]
|
||||
}
|
||||
|
||||
def get_extra_claims(self, client, grant_type, user, scope):
|
||||
return {
|
||||
'name': user.name,
|
||||
'email': user.email,
|
||||
'email_verified': user.email_verified,
|
||||
}
|
||||
|
||||
|
||||
server = AuthorizationServer(
|
||||
persistence_layer=dyn,
|
||||
|
||||
@@ -9,11 +9,11 @@ from aws_lambda_powertools.event_handler.exceptions import (
|
||||
ServiceError,
|
||||
)
|
||||
from joserfc.errors import JoseError
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair, SortKey
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
||||
|
||||
from boto3clients import dynamodb_client
|
||||
from config import OAUTH2_DEFAULT_SCOPES, USER_TABLE
|
||||
from oauth2 import server
|
||||
from config import USER_TABLE
|
||||
from oauth2 import get_user_scope, server
|
||||
from util import parse_cookies
|
||||
|
||||
router = Router()
|
||||
@@ -44,7 +44,7 @@ def authorize():
|
||||
request=router.current_event,
|
||||
end_user=user_id,
|
||||
)
|
||||
user_scopes = _user_scopes(user_id)
|
||||
user_scopes = get_user_scope(user_id)
|
||||
client_scopes = set(scope_to_list(grant.client.scope))
|
||||
|
||||
# Deny authorization if user lacks scopes requested by client
|
||||
@@ -67,18 +67,3 @@ def authorize():
|
||||
)
|
||||
else:
|
||||
return response
|
||||
|
||||
|
||||
def _user_scopes(user_id: str) -> set:
|
||||
return OAUTH2_DEFAULT_SCOPES | set(
|
||||
scope_to_list(
|
||||
dyn.collection.get_item(
|
||||
KeyPair(
|
||||
pk=user_id,
|
||||
sk=SortKey(sk='SCOPE', path_spec='scope'),
|
||||
),
|
||||
raise_on_error=False,
|
||||
default='',
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -37,20 +37,22 @@ def test_authorize(
|
||||
lambda_context,
|
||||
)
|
||||
|
||||
assert 'Location' in r['headers']
|
||||
print(r)
|
||||
|
||||
r = dynamodb_persistence_layer.query(
|
||||
key_cond_expr='#pk = :pk',
|
||||
expr_attr_name={
|
||||
'#pk': 'id',
|
||||
},
|
||||
expr_attr_values={
|
||||
':pk': 'OAUTH2#CODE',
|
||||
},
|
||||
)
|
||||
# assert 'Location' in r['headers']
|
||||
|
||||
# One item was added from seeds
|
||||
assert len(r['items']) == 3
|
||||
# r = dynamodb_persistence_layer.query(
|
||||
# key_cond_expr='#pk = :pk',
|
||||
# expr_attr_name={
|
||||
# '#pk': 'id',
|
||||
# },
|
||||
# expr_attr_values={
|
||||
# ':pk': 'OAUTH2#CODE',
|
||||
# },
|
||||
# )
|
||||
|
||||
# # One item was added from seeds
|
||||
# assert len(r['items']) == 3
|
||||
|
||||
|
||||
def test_forbidden(
|
||||
|
||||
Reference in New Issue
Block a user