add revoke
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from authlib.common.urls import add_params_to_uri
|
||||
from authlib.oauth2 import OAuth2Request, rfc9207
|
||||
from authlib.oauth2 import OAuth2Request, rfc7009, rfc9207
|
||||
from authlib.oauth2.rfc6749 import ClientMixin, TokenMixin, grants
|
||||
from authlib.oauth2.rfc7636 import CodeChallenge
|
||||
from authlib.oidc.core import OpenIDCode as OpenIDCode_
|
||||
@@ -8,7 +8,12 @@ from aws_lambda_powertools import Logger
|
||||
from aws_lambda_powertools.event_handler.api_gateway import Response
|
||||
from aws_lambda_powertools.event_handler.exceptions import NotFoundError
|
||||
from layercake.dateutils import now, ttl
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
||||
from layercake.dynamodb import (
|
||||
DynamoDBPersistenceLayer,
|
||||
KeyPair,
|
||||
SortKey,
|
||||
TransactKey,
|
||||
)
|
||||
from layercake.funcs import omit, pick
|
||||
|
||||
from boto3clients import dynamodb_client
|
||||
@@ -22,7 +27,7 @@ from integrations.apigateway_oauth2.tokens import (
|
||||
)
|
||||
|
||||
logger = Logger(__name__)
|
||||
oauth2_layer = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
|
||||
dyn = DynamoDBPersistenceLayer(OAUTH2_TABLE, dynamodb_client)
|
||||
|
||||
|
||||
class OpenIDCode(OpenIDCode_):
|
||||
@@ -30,7 +35,7 @@ class OpenIDCode(OpenIDCode_):
|
||||
if not request.payload:
|
||||
raise ValueError('Missing request payload')
|
||||
|
||||
nonce_ = oauth2_layer.get_item(
|
||||
nonce_ = dyn.get_item(
|
||||
KeyPair(pk='OAUTH2#CODE', sk=f'NONCE#{nonce}'),
|
||||
)
|
||||
|
||||
@@ -65,6 +70,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
code: str,
|
||||
request: OAuth2Request,
|
||||
) -> None:
|
||||
"""Save authorization_code for later use."""
|
||||
if not request.payload:
|
||||
raise ValueError('Missing request payload')
|
||||
|
||||
@@ -81,7 +87,7 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
now_ = now()
|
||||
ttl_ = ttl(start_dt=now_, minutes=10)
|
||||
|
||||
with oauth2_layer.transact_writer() as transact:
|
||||
with dyn.transact_writer() as transact:
|
||||
transact.put(
|
||||
item={
|
||||
'id': 'OAUTH2#CODE',
|
||||
@@ -116,7 +122,8 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
code: str,
|
||||
client: ClientMixin,
|
||||
) -> OAuth2AuthorizationCode:
|
||||
auth_code = oauth2_layer.get_item(
|
||||
"""Get authorization_code from previously savings."""
|
||||
auth_code = dyn.get_item(
|
||||
KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{code}'),
|
||||
)
|
||||
|
||||
@@ -129,16 +136,24 @@ class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
|
||||
self,
|
||||
authorization_code: OAuth2AuthorizationCode,
|
||||
) -> None:
|
||||
oauth2_layer.delete_item(
|
||||
KeyPair(pk='OAUTH2#CODE', sk=f'CODE#{authorization_code.code}'),
|
||||
"""Delete authorization code from database or cache."""
|
||||
dyn.delete_item(
|
||||
KeyPair(
|
||||
pk='OAUTH2#CODE',
|
||||
sk=f'CODE#{authorization_code.code}',
|
||||
),
|
||||
)
|
||||
|
||||
def authenticate_user(
|
||||
self,
|
||||
authorization_code: OAuth2AuthorizationCode,
|
||||
) -> dict:
|
||||
user = oauth2_layer.get_item(
|
||||
KeyPair(pk=authorization_code.user_id, sk='0'),
|
||||
"""Authenticate the user related to this authorization_code."""
|
||||
user = dyn.get_item(
|
||||
KeyPair(
|
||||
pk=authorization_code.user_id,
|
||||
sk='0',
|
||||
),
|
||||
)
|
||||
return pick(('id', 'name', 'email', 'email_verified'), user)
|
||||
|
||||
@@ -154,10 +169,13 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||
'client_secret_post',
|
||||
'none',
|
||||
]
|
||||
|
||||
# The authorization server MAY issue a new refresh token
|
||||
INCLUDE_NEW_REFRESH_TOKEN = True
|
||||
|
||||
def authenticate_refresh_token(self, refresh_token: str, **kwargs) -> TokenMixin:
|
||||
token = oauth2_layer.collection.get_item(
|
||||
"""Get token information with refresh_token string."""
|
||||
token = dyn.collection.get_item(
|
||||
KeyPair(
|
||||
pk='OAUTH2#TOKEN',
|
||||
sk=f'REFRESH_TOKEN#{refresh_token}',
|
||||
@@ -175,16 +193,90 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
|
||||
)
|
||||
|
||||
def authenticate_user(self, refresh_token: TokenMixin):
|
||||
"""Authenticate the user related to this credential."""
|
||||
return refresh_token.get_user()
|
||||
|
||||
def revoke_old_credential(self, refresh_token: TokenMixin) -> None:
|
||||
logger.info('Revoking old refresh token', refresh_token=refresh_token)
|
||||
token = getattr(refresh_token, 'refresh_token', None)
|
||||
"""The authorization server MAY revoke the old refresh token after
|
||||
issuing a new refresh token to the client."""
|
||||
|
||||
if token:
|
||||
oauth2_layer.delete_item(
|
||||
KeyPair(pk='OAUTH2#TOKEN', sk=f'REFRESH_TOKEN#{token}')
|
||||
logger.debug('Revoking old refresh token', refresh_token=refresh_token)
|
||||
token = getattr(refresh_token, 'refresh_token', None)
|
||||
user = refresh_token.get_user()
|
||||
|
||||
with dyn.transact_writer() as transact:
|
||||
transact.delete(
|
||||
key=KeyPair(
|
||||
pk='OAUTH2#TOKEN',
|
||||
sk=f'REFRESH_TOKEN#{token}',
|
||||
)
|
||||
)
|
||||
transact.delete(
|
||||
key=KeyPair(
|
||||
pk=user.get('id'),
|
||||
sk=f'SESSION#REFRESH_TOKEN#{token}',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RevocationEndpoint(rfc7009.RevocationEndpoint):
|
||||
def query_token( # type: ignore
|
||||
self,
|
||||
token_string: str,
|
||||
token_type_hint: str | None = None,
|
||||
):
|
||||
result = dyn.collection.get_items(
|
||||
TransactKey('OAUTH2#TOKEN')
|
||||
+ SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
|
||||
+ SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
|
||||
flatten_top=False,
|
||||
)
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
logger.debug('Tokens retrieved', result=result)
|
||||
|
||||
if not token_type_hint:
|
||||
token_type_hint = (
|
||||
'refresh_token' if 'refresh_token' in result else 'access_token'
|
||||
)
|
||||
|
||||
token = result[token_type_hint]
|
||||
|
||||
return OAuth2Token(
|
||||
expires_in=int(token['expires_in']),
|
||||
issued_at=int(token['issued_at']),
|
||||
**{token_type_hint: token_string},
|
||||
**omit(('expires_in', 'issued_at', 'refresh_token', 'access_token'), token),
|
||||
)
|
||||
|
||||
def revoke_token(
|
||||
self,
|
||||
token: OAuth2Token,
|
||||
request: OAuth2Request,
|
||||
):
|
||||
user_id = token.user['id']
|
||||
r = dyn.collection.query(KeyPair(pk=user_id, sk='SESSION'))
|
||||
|
||||
with dyn.transact_writer() as transact:
|
||||
# Revoke all sessions, access tokens, and refresh tokens
|
||||
for x in r['items']:
|
||||
pk, sk = x['id'], x['sk']
|
||||
*_, kind, idx = sk.split('#')
|
||||
|
||||
transact.delete(key=KeyPair(pk, sk))
|
||||
transact.delete(
|
||||
key=KeyPair(
|
||||
pk='SESSION',
|
||||
sk=idx,
|
||||
)
|
||||
if kind == 'SESSION'
|
||||
else KeyPair(
|
||||
pk='OAUTH2#TOKEN',
|
||||
sk=f'{kind}#{idx}',
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class IssuerParameter(rfc9207.IssuerParameter):
|
||||
@@ -207,7 +299,7 @@ class IssuerParameter(rfc9207.IssuerParameter):
|
||||
return ISSUER
|
||||
|
||||
|
||||
server = AuthorizationServer(persistence_layer=oauth2_layer)
|
||||
server = AuthorizationServer(persistence_layer=dyn)
|
||||
server.register_grant(
|
||||
AuthorizationCodeGrant,
|
||||
[
|
||||
@@ -216,4 +308,5 @@ server.register_grant(
|
||||
],
|
||||
)
|
||||
server.register_grant(RefreshTokenGrant)
|
||||
server.register_endpoint(RevocationEndpoint)
|
||||
server.register_extension(IssuerParameter())
|
||||
|
||||
Reference in New Issue
Block a user