update focus

This commit is contained in:
2025-09-18 00:39:36 -03:00
parent b4c7e191fe
commit 0ebf108a94
4 changed files with 48 additions and 22 deletions

View File

@@ -28,6 +28,12 @@ GRANT_TYPES_EXPIRES_IN = {
logger = Logger(__name__) logger = Logger(__name__)
def expires_in(client, grant_type: str) -> int:
return GRANT_TYPES_EXPIRES_IN.get(
grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
)
class AuthorizationServer(oauth2.AuthorizationServer): class AuthorizationServer(oauth2.AuthorizationServer):
def __init__( def __init__(
self, self,
@@ -36,15 +42,9 @@ class AuthorizationServer(oauth2.AuthorizationServer):
) -> None: ) -> None:
self._persistence_layer = persistence_layer self._persistence_layer = persistence_layer
super().__init__( if OAUTH2_SCOPES_SUPPORTED:
scopes_supported=( super().__init__(
set(OAUTH2_SCOPES_SUPPORTED.split()) if OAUTH2_SCOPES_SUPPORTED else [] scopes_supported=set(OAUTH2_SCOPES_SUPPORTED.split()),
)
)
def expires_in(client, grant_type: str) -> int:
return GRANT_TYPES_EXPIRES_IN.get(
grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN
) )
self.register_token_generator( self.register_token_generator(
@@ -61,6 +61,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
token: dict, token: dict,
request: requests.OAuth2Request, request: requests.OAuth2Request,
) -> None: ) -> None:
"""Define function to save the generated token into database."""
if not request.payload: if not request.payload:
raise ValueError('Missing request payload') raise ValueError('Missing request payload')
@@ -88,6 +89,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
item={ item={
'id': user_id, 'id': user_id,
'sk': f'SESSION#ACCESS_TOKEN#{access_token}', 'sk': f'SESSION#ACCESS_TOKEN#{access_token}',
'client_id': client_id,
'ttl': access_token_ttl, 'ttl': access_token_ttl,
'created_at': now_, 'created_at': now_,
} }
@@ -112,6 +114,7 @@ class AuthorizationServer(oauth2.AuthorizationServer):
item={ item={
'id': user_id, 'id': user_id,
'sk': f'SESSION#REFRESH_TOKEN#{refresh_token}', 'sk': f'SESSION#REFRESH_TOKEN#{refresh_token}',
'client_id': client_id,
'ttl': access_token_ttl, 'ttl': access_token_ttl,
'created_at': now_, 'created_at': now_,
} }
@@ -133,7 +136,11 @@ class AuthorizationServer(oauth2.AuthorizationServer):
return None return None
def query_client(self, client_id: str): def query_client(
self,
client_id: str,
):
"""Query OAuth client by client_id."""
client = self._persistence_layer.collection.get_item( client = self._persistence_layer.collection.get_item(
KeyPair( KeyPair(
pk='OAUTH2', pk='OAUTH2',

View File

@@ -30,7 +30,10 @@ class OAuth2Client(ClientMixin):
def get_client_id(self): def get_client_id(self):
return self.client_id return self.client_id
def get_allowed_scope(self, scope: Collection[str]) -> str: def get_allowed_scope(
self,
scope: Collection[str],
) -> str:
if not scope: if not scope:
return '' return ''
@@ -44,20 +47,36 @@ class OAuth2Client(ClientMixin):
raise ValueError('Missing redirect_uris') raise ValueError('Missing redirect_uris')
def check_response_type(self, response_type: str) -> bool: def check_response_type(
self,
response_type: str,
) -> bool:
return response_type in self.response_types return response_type in self.response_types
def check_redirect_uri(self, redirect_uri: str) -> bool: def check_redirect_uri(
self,
redirect_uri: str,
) -> bool:
return redirect_uri in self.redirect_uris return redirect_uri in self.redirect_uris
def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool: def check_endpoint_auth_method(
self,
method: str,
endpoint: str,
) -> bool:
if endpoint == 'token': if endpoint == 'token':
return self.token_endpoint_auth_method == method return self.token_endpoint_auth_method == method
return True return True
def check_grant_type(self, grant_type: str) -> bool: def check_grant_type(
self,
grant_type: str,
) -> bool:
return grant_type in self.grant_types return grant_type in self.grant_types
def check_client_secret(self, client_secret: str) -> bool: def check_client_secret(
self,
client_secret: str,
) -> bool:
return secrets.compare_digest(self.client_secret, client_secret) return secrets.compare_digest(self.client_secret, client_secret)

View File

@@ -225,24 +225,24 @@ class RevocationEndpoint(rfc7009.RevocationEndpoint):
token_string: str, token_string: str,
token_type_hint: str | None = None, token_type_hint: str | None = None,
): ):
result = dyn.collection.get_items( t = dyn.collection.get_items(
TransactKey('OAUTH2#TOKEN') TransactKey('OAUTH2#TOKEN')
+ SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token') + SortKey(sk=f'REFRESH_TOKEN#{token_string}', rename_key='refresh_token')
+ SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'), + SortKey(sk=f'ACCESS_TOKEN#{token_string}', rename_key='access_token'),
flatten_top=False, flatten_top=False,
) )
if not result: if not t:
return None return None
logger.debug('Tokens retrieved', result=result) logger.debug('Tokens retrieved', result=t)
if not token_type_hint: if not token_type_hint:
token_type_hint = ( token_type_hint = (
'refresh_token' if 'refresh_token' in result else 'access_token' 'refresh_token' if 'refresh_token' in t else 'access_token'
) )
token = result[token_type_hint] token = t[token_type_hint]
return OAuth2Token( return OAuth2Token(
expires_in=int(token['expires_in']), expires_in=int(token['expires_in']),

View File

@@ -144,7 +144,7 @@ export default function Index({}: Route.ComponentProps) {
<FormItem> <FormItem>
<FormLabel>Email ou CPF</FormLabel> <FormLabel>Email ou CPF</FormLabel>
<FormControl> <FormControl>
<Input {...field} /> <Input autoFocus={true} {...field} />
</FormControl> </FormControl>
<FormMessage /> <FormMessage />
</FormItem> </FormItem>