wip
This commit is contained in:
@@ -12,7 +12,7 @@ from aws_lambda_powertools.utilities.data_classes.api_gateway_proxy_event import
|
||||
from layercake.dateutils import now, ttl
|
||||
from layercake.dynamodb import DynamoDBPersistenceLayer, KeyPair
|
||||
|
||||
from config import ACCESS_TOKEN_EXP_SECONDS, REFRESH_TOKEN_EXP_SECONDS
|
||||
from config import OAUTH2_REFRESH_TOKEN_EXPIRES_IN
|
||||
|
||||
from .client import OAuth2Client
|
||||
from .requests import APIGatewayJsonRequest, APIGatewayOAuth2Request
|
||||
@@ -21,6 +21,12 @@ DYNAMODB_SORT_KEY = os.getenv('DYNAMODB_SORT_KEY')
|
||||
OAUTH2_SCOPES_SUPPORTED = os.getenv('OAUTH2_SCOPES_SUPPORTED')
|
||||
|
||||
|
||||
GRANT_TYPES_EXPIRES_IN = {
|
||||
'refresh_token': 900,
|
||||
'authorization_code': 900,
|
||||
}
|
||||
|
||||
|
||||
class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -35,11 +41,17 @@ class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
)
|
||||
)
|
||||
|
||||
def expires_in(client, grant_type: str) -> int:
|
||||
return GRANT_TYPES_EXPIRES_IN.get(
|
||||
grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
|
||||
)
|
||||
|
||||
self.register_token_generator(
|
||||
'default',
|
||||
BearerTokenGenerator(
|
||||
access_token_generator=create_token_generator(42),
|
||||
refresh_token_generator=create_token_generator(48),
|
||||
expires_generator=expires_in,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -60,6 +72,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
refresh_token = token.get('refresh_token')
|
||||
token_type = token['token_type']
|
||||
scope = token['scope']
|
||||
expires_in = int(token['expires_in'])
|
||||
issued_at = int(now_.timestamp())
|
||||
|
||||
with self._persistence_layer.transact_writer() as transact:
|
||||
@@ -72,9 +85,9 @@ class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
'refresh_token': refresh_token,
|
||||
'scope': scope,
|
||||
'user': request.user,
|
||||
'expires_in': ACCESS_TOKEN_EXP_SECONDS,
|
||||
'expires_in': expires_in,
|
||||
'issued_at': issued_at,
|
||||
'ttl': ttl(start_dt=now_, seconds=ACCESS_TOKEN_EXP_SECONDS),
|
||||
'ttl': ttl(start_dt=now_, seconds=expires_in),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -88,9 +101,11 @@ class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
'access_token': access_token,
|
||||
'scope': scope,
|
||||
'user': request.user,
|
||||
'expires_in': REFRESH_TOKEN_EXP_SECONDS,
|
||||
'expires_in': OAUTH2_REFRESH_TOKEN_EXPIRES_IN,
|
||||
'issued_at': issued_at,
|
||||
'ttl': ttl(start_dt=now_, seconds=REFRESH_TOKEN_EXP_SECONDS),
|
||||
'ttl': ttl(
|
||||
start_dt=now_, seconds=OAUTH2_REFRESH_TOKEN_EXPIRES_IN
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -117,23 +132,35 @@ class AuthorizationServer(oauth2.AuthorizationServer):
|
||||
)
|
||||
|
||||
def create_oauth2_request(
|
||||
self, request: APIGatewayProxyEventV2
|
||||
self,
|
||||
request: APIGatewayProxyEventV2,
|
||||
) -> APIGatewayOAuth2Request:
|
||||
return APIGatewayOAuth2Request(request)
|
||||
|
||||
def create_json_request(
|
||||
self, request: APIGatewayProxyEventV2
|
||||
self,
|
||||
request: APIGatewayProxyEventV2,
|
||||
) -> APIGatewayJsonRequest:
|
||||
return APIGatewayJsonRequest(request)
|
||||
|
||||
def handle_response(self, status: int, body, headers):
|
||||
def handle_response(
|
||||
self,
|
||||
status: int,
|
||||
body,
|
||||
headers,
|
||||
):
|
||||
return Response(
|
||||
status_code=status,
|
||||
body=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def send_signal(self, name: str, *args, **kwargs) -> None:
|
||||
def send_signal(
|
||||
self,
|
||||
name: str,
|
||||
*args: object,
|
||||
**kwargs: object,
|
||||
) -> None:
|
||||
# after_authenticate_client
|
||||
# when client is authenticated
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import secrets
|
||||
from collections.abc import Collection
|
||||
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
ClientMixin,
|
||||
@@ -29,7 +30,7 @@ class OAuth2Client(ClientMixin):
|
||||
def get_client_id(self):
|
||||
return self.client_id
|
||||
|
||||
def get_allowed_scope(self, scope) -> str:
|
||||
def get_allowed_scope(self, scope: Collection[str]) -> str:
|
||||
if not scope:
|
||||
return ''
|
||||
|
||||
@@ -37,24 +38,26 @@ class OAuth2Client(ClientMixin):
|
||||
scopes = scope_to_list(scope)
|
||||
return list_to_scope([s for s in scopes if s in allowed])
|
||||
|
||||
def get_default_redirect_uri(self) -> str: # type: ignore
|
||||
def get_default_redirect_uri(self) -> str:
|
||||
if self.redirect_uris:
|
||||
return self.redirect_uris[0]
|
||||
|
||||
def check_response_type(self, response_type):
|
||||
raise ValueError('Missing redirect_uris')
|
||||
|
||||
def check_response_type(self, response_type: str) -> bool:
|
||||
return response_type in self.response_types
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
def check_redirect_uri(self, redirect_uri: str) -> bool:
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def check_endpoint_auth_method(self, method, endpoint):
|
||||
def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
|
||||
if endpoint == 'token':
|
||||
return self.token_endpoint_auth_method == method
|
||||
|
||||
return True
|
||||
|
||||
def check_grant_type(self, grant_type):
|
||||
def check_grant_type(self, grant_type: str) -> bool:
|
||||
return grant_type in self.grant_types
|
||||
|
||||
def check_client_secret(self, client_secret):
|
||||
def check_client_secret(self, client_secret: str) -> bool:
|
||||
return secrets.compare_digest(self.client_secret, client_secret)
|
||||
|
||||
@@ -5,7 +5,7 @@ from authlib.oauth2.rfc6749 import (
|
||||
ClientMixin,
|
||||
TokenMixin,
|
||||
)
|
||||
from layercake.dateutils import fromisoformat
|
||||
from layercake.dateutils import fromisoformat, now
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(AuthorizationCodeMixin):
|
||||
@@ -32,25 +32,26 @@ class OAuth2AuthorizationCode(AuthorizationCodeMixin):
|
||||
self.code_challenge_method = code_challenge_method
|
||||
self.nonce = nonce
|
||||
|
||||
auth_time = fromisoformat(kwargs.get('created_at', '')) or now()
|
||||
created_at = kwargs.get('created_at', '')
|
||||
auth_time = fromisoformat(created_at) or now()
|
||||
self.auth_time = int(auth_time.timestamp())
|
||||
|
||||
def get_redirect_uri(self):
|
||||
def get_redirect_uri(self) -> str:
|
||||
return self.redirect_uri
|
||||
|
||||
def get_scope(self):
|
||||
def get_scope(self) -> str:
|
||||
return self.scope
|
||||
|
||||
def get_nonce(self):
|
||||
def get_nonce(self) -> str | None:
|
||||
return self.nonce
|
||||
|
||||
def get_auth_time(self):
|
||||
def get_auth_time(self) -> int:
|
||||
return self.auth_time
|
||||
|
||||
def get_acr(self):
|
||||
return '0'
|
||||
|
||||
def get_amr(self):
|
||||
def get_amr(self) -> list:
|
||||
return []
|
||||
|
||||
|
||||
@@ -77,7 +78,7 @@ class OAuth2Token(TokenMixin):
|
||||
def get_user(self) -> dict:
|
||||
return self.user
|
||||
|
||||
def check_client(self, client: ClientMixin):
|
||||
def check_client(self, client: ClientMixin) -> bool:
|
||||
return self.client_id == client.get_client_id()
|
||||
|
||||
def get_scope(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user