wip
This commit is contained in:
@@ -1,110 +1,89 @@
|
||||
from authlib.oauth2.rfc6749 import TokenMixin, grants
|
||||
from authlib.oauth2 import OAuth2Request
|
||||
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
|
||||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||
from authlib.oidc.core import OpenIDCode as OpenIDCode_
|
||||
from authlib.oidc.core import UserInfo
|
||||
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
|
||||
from layercake.dateutils import now, ttl
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
||||
from layercake.funcs import omit, pick
|
||||
from layercake.funcs import pick
|
||||
|
||||
from boto3clients import dynamodb_client
|
||||
from config import DYNAMODB_SORT_KEY, OAUTH2_TABLE
|
||||
from integrations.apigateway_oauth2 import (
|
||||
AuthorizationCode,
|
||||
from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE
|
||||
from integrations.apigateway_oauth2.authorization_server import (
|
||||
AuthorizationServer,
|
||||
OAuth2Client,
|
||||
)
|
||||
from integrations.apigateway_oauth2.tokens import (
|
||||
OAuth2AuthorizationCode,
|
||||
OAuth2Token,
|
||||
)
|
||||
|
||||
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
|
||||
|
||||
|
||||
DUMMY_JWT_CONFIG = {
|
||||
'key': 'secret-key',
|
||||
'alg': 'HS256',
|
||||
'iss': 'https://authlib.org',
|
||||
'exp': 3600,
|
||||
}
|
||||
|
||||
|
||||
def create_save_token_func(persistence_layer: DynamoDBPersistenceLayer):
|
||||
def save_token(token, request) -> OAuth2Token:
|
||||
print('save_token')
|
||||
return OAuth2Token()
|
||||
|
||||
return save_token
|
||||
|
||||
|
||||
def create_query_client_func(persistence_layer: DynamoDBPersistenceLayer):
|
||||
class ClientNotFoundError(NotFoundError):
|
||||
def __init__(self, *_):
|
||||
super().__init__('Client not found')
|
||||
|
||||
def query_client(client_id) -> OAuth2Client:
|
||||
client = persistence_layer.collection.get_item(
|
||||
KeyPair(
|
||||
pk='OAUTH2_CLIENT',
|
||||
sk=f'CLIENT_ID#{client_id}',
|
||||
),
|
||||
exc_cls=ClientNotFoundError,
|
||||
)
|
||||
|
||||
_, client_id = client.get(DYNAMODB_SORT_KEY, '').split('#')
|
||||
|
||||
return OAuth2Client(
|
||||
client_id=client_id,
|
||||
client_secret=client['secret'],
|
||||
redirect_uris=client['redirect_uris'],
|
||||
response_types=client['response_types'],
|
||||
grant_types=client['grant_types'],
|
||||
token_endpoint_auth_method=client['token_endpoint_auth_method'],
|
||||
)
|
||||
|
||||
return query_client
|
||||
|
||||
|
||||
class OpenIDCode(OpenIDCode_):
|
||||
def exists_nonce(self, nonce, request):
|
||||
def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
|
||||
if not request.payload:
|
||||
raise ValueError('Missing request payload')
|
||||
|
||||
nonce_ = oauth2_layer.get_item(
|
||||
KeyPair(
|
||||
f'OAUTH2_CODE#CLIENT_ID#{request.payload.client_id}', # type:ignore
|
||||
f'NONCE#{nonce}',
|
||||
)
|
||||
KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}'),
|
||||
)
|
||||
|
||||
return bool(nonce_)
|
||||
|
||||
def get_jwt_config(self, grant):
|
||||
return DUMMY_JWT_CONFIG
|
||||
return {
|
||||
'key': 'secret-key',
|
||||
'alg': JWT_ALGORITHM,
|
||||
'iss': ISSUER,
|
||||
'exp': 3600,
|
||||
}
|
||||
|
||||
def generate_user_info(self, user, scope):
|
||||
def generate_user_info(self, user: dict, scope: str) -> UserInfo:
|
||||
return UserInfo(
|
||||
sub=user.id,
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
sub=user['id'],
|
||||
name=user['name'],
|
||||
email=user['email'],
|
||||
email_verified=user.get('email_verified', False),
|
||||
).filter(scope)
|
||||
|
||||
|
||||
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = [
|
||||
'client_secret_basic',
|
||||
'client_secret_post',
|
||||
'none',
|
||||
]
|
||||
|
||||
def save_authorization_code(self, code: str, request):
|
||||
client_id: str = request.payload.client_id # type: ignore
|
||||
data: dict = request.payload.data # type: ignore
|
||||
user: dict = request.user # type: ignore
|
||||
def save_authorization_code(
|
||||
self,
|
||||
code: str,
|
||||
request: OAuth2Request,
|
||||
) -> None:
|
||||
if not request.payload:
|
||||
raise ValueError('Missing request payload')
|
||||
|
||||
if not request.user:
|
||||
raise ValueError('Missing request user')
|
||||
|
||||
client_id: str = request.payload.client_id
|
||||
data: dict = request.payload.data
|
||||
user: dict = request.user
|
||||
nonce: str | None = data.get('nonce')
|
||||
code_challenge: str | None = data.get('code_challenge')
|
||||
code_challenge_method: str | None = data.get('code_challenge_method')
|
||||
now_ = now()
|
||||
ttl_ = ttl(start_dt=now_, minutes=15)
|
||||
ttl_ = ttl(start_dt=now_, minutes=10)
|
||||
|
||||
with oauth2_layer.transact_writer() as transact:
|
||||
transact.put(
|
||||
item={
|
||||
'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
|
||||
'id': 'OAUTH2#CODE',
|
||||
'sk': f'CODE#{code}',
|
||||
'redirect_uri': request.payload.redirect_uri, # type: ignore
|
||||
'scope': request.payload.scope, # type: ignore
|
||||
'redirect_uri': request.payload.redirect_uri,
|
||||
'response_type': request.payload.response_type,
|
||||
'scope': request.payload.scope,
|
||||
'client_id': client_id,
|
||||
'user_id': user['id'],
|
||||
'nonce': nonce,
|
||||
'code_challenge': code_challenge,
|
||||
@@ -117,56 +96,90 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
if nonce:
|
||||
transact.put(
|
||||
item={
|
||||
'id': f'OAUTH2_CODE#CLIENT_ID#{client_id}',
|
||||
'id': 'OAUTH2#CODE',
|
||||
'sk': f'NONCE#{nonce}',
|
||||
'client_id': client_id,
|
||||
'code': code,
|
||||
'created_at': now_,
|
||||
'ttl': ttl_,
|
||||
},
|
||||
)
|
||||
|
||||
def query_authorization_code(self, code, client):
|
||||
client_id = client.get_client_id()
|
||||
def query_authorization_code(
|
||||
self,
|
||||
code: str,
|
||||
client: ClientMixin,
|
||||
) -> OAuth2AuthorizationCode:
|
||||
auth_code = oauth2_layer.get_item(
|
||||
KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
|
||||
)
|
||||
|
||||
return OAuth2AuthorizationCode(
|
||||
code=code,
|
||||
**auth_code,
|
||||
)
|
||||
|
||||
def delete_authorization_code(
|
||||
self,
|
||||
authorization_code: OAuth2AuthorizationCode,
|
||||
) -> None:
|
||||
oauth2_layer.delete_item(
|
||||
KeyPair(
|
||||
pk=f'OAUTH2_CODE#CLIENT_ID#{client_id}',
|
||||
sk=f'CODE#{code}',
|
||||
pk='OAUTH2#CODE',
|
||||
sk=f'CODE#{authorization_code.code}',
|
||||
),
|
||||
)
|
||||
|
||||
return AuthorizationCode(
|
||||
client_id=client_id,
|
||||
code=code,
|
||||
**omit(('id', 'sk'), auth_code),
|
||||
)
|
||||
|
||||
def delete_authorization_code(self, authorization_code):
|
||||
print('authorization_code')
|
||||
|
||||
def authenticate_user(self, authorization_code):
|
||||
def authenticate_user(
|
||||
self,
|
||||
authorization_code: OAuth2AuthorizationCode,
|
||||
) -> dict:
|
||||
user = oauth2_layer.get_item(
|
||||
KeyPair(
|
||||
pk=authorization_code.user_id,
|
||||
sk='0',
|
||||
),
|
||||
)
|
||||
return pick(('id', 'name', 'email'), user)
|
||||
return pick(('id', 'name', 'email', 'email_verified'), user)
|
||||
|
||||
|
||||
class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||
TOKEN_ENDPOINT_AUTH_METHODS = ['client_secret_basic', 'client_secret_post', 'none']
|
||||
INCLUDE_NEW_REFRESH_TOKEN = True
|
||||
|
||||
def authenticate_refresh_token(self, refresh_token: str) -> TokenMixin: ...
|
||||
def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
|
||||
token = oauth2_layer.get_item(
|
||||
KeyPair(
|
||||
pk='OAUTH2#TOKEN',
|
||||
sk=f'REFRESH_TOKEN#{refresh_token}',
|
||||
)
|
||||
)
|
||||
|
||||
def authenticate_user(self, refresh_token): ...
|
||||
return OAuth2Token(
|
||||
client_id=token['client_id'],
|
||||
scope=token['scope'],
|
||||
expires_in=int(token['expires_in']),
|
||||
issued_at=int(token['issued_at']),
|
||||
user=token['user'],
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
def revoke_old_credential(self, refresh_token: TokenMixin) -> None: ...
|
||||
def authenticate_user(self, refresh_token: TokenMixin):
|
||||
return refresh_token.get_user()
|
||||
|
||||
def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
|
||||
refresh_token_ = getattr(refresh_token, 'refresh_token')
|
||||
|
||||
if refresh_token_:
|
||||
oauth2_layer.delete_item(
|
||||
KeyPair(
|
||||
pk='OAUTH2#TOKEN',
|
||||
sk=f'REFRESH_TOKEN#{refresh_token_}',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
server = AuthorizationServer(
|
||||
query_client=create_query_client_func(oauth2_layer),
|
||||
save_token=create_save_token_func(oauth2_layer),
|
||||
)
|
||||
server = AuthorizationServer(persistence_layer=oauth2_layer)
|
||||
server.register_grant(
|
||||
AuthorizationCodeGrant,
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user