182 lines
5.6 KiB
Python
182 lines
5.6 KiB
Python
from authlib.oauth2 import OAuth2Request
|
|
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
|
|
from authlib.oauth2.rfc7636 import CodeChallenge
|
|
from authlib.oidc.core import OpenIDCode as OpenIDCode_
|
|
from authlib.oidc.core import UserInfo
|
|
from layercake.dateutils import now, ttl
|
|
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
|
from layercake.funcs import omit, pick
|
|
|
|
from boto3clients import dynamodb_client
|
|
from config import ISSUER, JWT_ALGORITHM, OAUTH2_TABLE
|
|
from integrations.apigateway_oauth2.authorization_server import (
|
|
AuthorizationServer,
|
|
)
|
|
from integrations.apigateway_oauth2.tokens import (
|
|
OAuth2AuthorizationCode,
|
|
OAuth2Token,
|
|
)
|
|
|
|
# Module-wide persistence layer bound to the OAuth2 table; the grant classes
# and the authorization server in this module read and write through it.
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
|
|
|
|
|
|
class OpenIDCode(OpenIDCode_):
    """OpenID Connect extension hooks for the authorization-code grant.

    Overrides the nonce lookup, the ID-token JWT settings and the
    construction of user-info claims.
    """

    def exists_nonce(self, nonce: str, request: OAuth2Request) -> bool:
        """Tell whether a record for ``nonce`` is already stored.

        Args:
            nonce: Nonce value taken from the authorization request.
            request: The OAuth2 request being processed.

        Returns:
            ``True`` when the lookup yields a truthy record, else ``False``.

        Raises:
            ValueError: If ``request`` has no payload.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        nonce_key = KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}')
        stored = oauth2_layer.get_item(nonce_key)

        return True if stored else False

    def get_jwt_config(self, grant):
        """Return the key, algorithm, issuer and lifetime for the ID token."""
        # NOTE(review): the signing key is a literal placeholder string;
        # presumably it should be loaded from configuration or a secret
        # store before production use — confirm.
        jwt_config = {
            'key': 'secret-key',
            'alg': JWT_ALGORITHM,
            'iss': ISSUER,
            'exp': 3600,
        }
        return jwt_config

    def generate_user_info(self, user: dict, scope: str) -> UserInfo:
        """Build the user claims and keep only those permitted by ``scope``.

        ``user`` must provide ``id``, ``name`` and ``email``;
        ``email_verified`` falls back to ``False`` when absent.
        """
        claims = UserInfo(
            sub=user['id'],
            name=user['name'],
            email=user['email'],
            email_verified=user.get('email_verified', False),
        )
        return claims.filter(scope)
|
|
|
|
|
|
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
    """Authorization-code grant persisted through ``oauth2_layer``.

    Codes are written under partition ``OAUTH2#CODE`` with sort key
    ``CODE#<code>``. When the request carries a nonce, a companion
    ``NONCE#<nonce>`` item is written in the same transaction.
    """

    # Client authentication methods accepted at the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def save_authorization_code(
        self,
        code: str,
        request: OAuth2Request,
    ) -> None:
        """Persist a freshly issued authorization code.

        Args:
            code: The authorization code to store.
            request: The authorization request; its payload and user are
                copied into the stored item.

        Raises:
            ValueError: If the request has no payload or no user.
        """
        if not request.payload:
            raise ValueError('Missing request payload')

        if not request.user:
            raise ValueError('Missing request user')

        client_id: str = request.payload.client_id
        data: dict = request.payload.data
        user: dict = request.user
        nonce: str | None = data.get('nonce')
        code_challenge: str | None = data.get('code_challenge')
        code_challenge_method: str | None = data.get('code_challenge_method')
        now_ = now()
        # Both items share the same expiry value, ten minutes after creation.
        # NOTE(review): presumably the table expires items on 'ttl' — confirm.
        ttl_ = ttl(start_dt=now_, minutes=10)

        with oauth2_layer.transact_writer() as transact:
            transact.put(
                item={
                    'id': 'OAUTH2#CODE',
                    'sk': f'CODE#{code}',
                    'redirect_uri': request.payload.redirect_uri,
                    'response_type': request.payload.response_type,
                    'scope': request.payload.scope,
                    'client_id': client_id,
                    'user_id': user['id'],
                    'nonce': nonce,
                    'code_challenge': code_challenge,
                    'code_challenge_method': code_challenge_method,
                    'created_at': now_,
                    'ttl': ttl_,
                },
            )

            # The nonce item is what OpenIDCode.exists_nonce looks up.
            if nonce:
                transact.put(
                    item={
                        'id': 'OAUTH2#CODE',
                        'sk': f'NONCE#{nonce}',
                        'client_id': client_id,
                        'code': code,
                        'created_at': now_,
                        'ttl': ttl_,
                    },
                )

    def query_authorization_code(
        self,
        code: str,
        client: ClientMixin,
    ) -> 'OAuth2AuthorizationCode | None':
        """Load the stored authorization code for ``client``.

        Args:
            code: Authorization code presented at the token endpoint.
            client: The authenticated client making the token request.

        Returns:
            The stored code wrapped in ``OAuth2AuthorizationCode``, or
            ``None`` when no record exists or the record was issued to a
            different client. The grant treats ``None`` as an invalid code.
        """
        auth_code = oauth2_layer.get_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
        )

        # Unknown code: report "not found" rather than building an object
        # from an empty record.
        if not auth_code:
            return None

        # A code may only be redeemed by the client it was issued to.
        if auth_code.get('client_id') != client.get_client_id():
            return None

        # NOTE(review): expiry is not checked here; presumably stale items
        # are removed via the 'ttl' attribute — confirm.
        return OAuth2AuthorizationCode(
            code=code,
            **auth_code,
        )

    def delete_authorization_code(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> None:
        """Remove the stored ``CODE#`` item for ``authorization_code``."""
        oauth2_layer.delete_item(
            KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{authorization_code.code}'),
        )

    def authenticate_user(
        self,
        authorization_code: OAuth2AuthorizationCode,
    ) -> dict:
        """Fetch the user the code was issued for.

        The user is read from the item whose partition key is the code's
        ``user_id`` and whose sort key is ``'0'``; only the ``id``,
        ``name``, ``email`` and ``email_verified`` fields are returned.
        """
        user = oauth2_layer.get_item(
            KeyPair(pk=authorization_code.user_id, sk='0'),
        )
        return pick(('id', 'name', 'email', 'email_verified'), user)
|
|
|
|
|
|
class RefreshTokenGrant(grants.RefreshTokenGrant):
    """Refresh-token grant persisted through ``oauth2_layer``.

    Refresh tokens are looked up under partition ``OAUTH2#TOKEN`` with sort
    key ``REFRESH_TOKEN#<token>``.
    """

    # Issue a new refresh token together with each refreshed access token.
    INCLUDE_NEW_REFRESH_TOKEN = True
    # Client authentication methods accepted at the token endpoint.
    TOKEN_ENDPOINT_AUTH_METHODS = [
        'client_secret_basic',
        'client_secret_post',
        'none',
    ]

    def authenticate_refresh_token(
        self,
        refresh_token: str,
        **kwargs,
    ) -> 'TokenMixin | None':
        """Load the stored token for ``refresh_token``.

        Args:
            refresh_token: Refresh token string presented by the client.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            An ``OAuth2Token`` built from the stored item, or ``None`` when
            no record exists. The grant treats ``None`` as an invalid token.
        """
        token = oauth2_layer.get_item(
            KeyPair(
                pk='OAUTH2#TOKEN',
                sk=f'REFRESH_TOKEN#{refresh_token}',
            )
        )

        # Unknown token: report "not found" instead of failing on the
        # subscripts below.
        if not token:
            return None

        return OAuth2Token(
            # Stored numeric values are coerced to plain ints.
            expires_in=int(token['expires_in']),
            issued_at=int(token['issued_at']),
            refresh_token=refresh_token,
            # Drop every field passed explicitly above so a stored attribute
            # of the same name cannot be supplied twice.
            **omit(('expires_in', 'issued_at', 'refresh_token'), token),
        )

    def authenticate_user(self, refresh_token: TokenMixin):
        """Return the user attached to the stored token."""
        return refresh_token.get_user()

    def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
        """Delete the stored item of the refresh token being replaced.

        Does nothing when the token object has no truthy ``refresh_token``
        attribute.
        """
        if token := getattr(refresh_token, 'refresh_token', None):
            oauth2_layer.delete_item(
                KeyPair(pk='OAUTH2#TOKEN', sk=f'REFRESH_TOKEN#{token}')
            )
|
|
|
|
|
|
# Authorization server instance backed by the module's persistence layer.
server = AuthorizationServer(persistence_layer=oauth2_layer)

# Authorization-code flow with two extensions: OpenID Connect (nonce not
# required) and the RFC 7636 code challenge (required).
server.register_grant(
    AuthorizationCodeGrant,
    [OpenIDCode(require_nonce=False), CodeChallenge(required=True)],
)

# Refresh-token flow, registered without extensions.
server.register_grant(RefreshTokenGrant)
|