import secrets
from collections.abc import Collection

from authlib.oauth2.rfc6749 import (
    ClientMixin,
    list_to_scope,
    scope_to_list,
)


class OAuth2Client(ClientMixin):
    """OAuth 2.0 client record implementing authlib's ``ClientMixin``.

    All client metadata is held as plain instance attributes populated by the
    constructor; this class itself does not load from or persist to storage.
    """

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        scope: str,
        redirect_uris: list,
        response_types: list,
        grant_types: list,
        token_endpoint_auth_method: str = 'client_secret_basic',
    ) -> None:
        """Store the client's registration metadata.

        :param client_id: Identifier returned by ``get_client_id``.
        :param client_secret: Secret that ``check_client_secret`` compares
            presented secrets against.
        :param scope: Whitespace-separated scopes the client is allowed to
            request (split on whitespace by ``get_allowed_scope``).
        :param redirect_uris: Registered redirect URIs; the first entry is the
            default returned by ``get_default_redirect_uri``.
        :param response_types: Response types accepted by
            ``check_response_type``.
        :param grant_types: Grant types accepted by ``check_grant_type``.
        :param token_endpoint_auth_method: The only authentication method
            accepted for the ``'token'`` endpoint.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.scope = scope
        # NOTE(review): the lists below are stored by reference, so later
        # mutation by the caller is visible through this client.
        self.redirect_uris = redirect_uris
        self.response_types = response_types
        self.grant_types = grant_types
        self.token_endpoint_auth_method = token_endpoint_auth_method

    def get_client_id(self):
        """Return the client's identifier."""
        return self.client_id

    def get_allowed_scope(
        self,
        scope: Collection[str],
    ) -> str:
        """Filter the requested scope down to what this client may use.

        :param scope: Requested scope, in whatever form authlib's
            ``scope_to_list`` accepts (presumably a space-delimited string or
            a list of scope names — confirm against callers).
        :returns: The requested scopes that also appear in ``self.scope``,
            in request order, re-joined with ``list_to_scope``; ``''`` when
            nothing was requested.
        """
        if not scope:
            return ''
        allowed = set(self.scope.split())
        scopes = scope_to_list(scope)
        return list_to_scope([s for s in scopes if s in allowed])

    def get_default_redirect_uri(self) -> str:
        """Return the first registered redirect URI.

        :raises ValueError: If no redirect URIs are registered.
        """
        if self.redirect_uris:
            return self.redirect_uris[0]
        raise ValueError('Missing redirect_uris')

    def check_response_type(
        self,
        response_type: str,
    ) -> bool:
        """Return whether ``response_type`` is registered for this client."""
        return response_type in self.response_types

    def check_redirect_uri(
        self,
        redirect_uri: str,
    ) -> bool:
        """Return whether ``redirect_uri`` exactly matches a registered URI."""
        return redirect_uri in self.redirect_uris

    def check_endpoint_auth_method(
        self,
        method: str,
        endpoint: str,
    ) -> bool:
        """Return whether ``method`` may authenticate the client at ``endpoint``.

        Only the ``'token'`` endpoint is restricted (to
        ``token_endpoint_auth_method``); every other endpoint accepts any
        method.
        """
        if endpoint == 'token':
            return self.token_endpoint_auth_method == method
        return True

    def check_grant_type(
        self,
        grant_type: str,
    ) -> bool:
        """Return whether ``grant_type`` is registered for this client."""
        return grant_type in self.grant_types

    def check_client_secret(
        self,
        client_secret: str,
    ) -> bool:
        """Return whether ``client_secret`` matches the stored secret.

        The comparison is constant-time. A missing secret on either side is
        treated as a mismatch instead of raising.
        """
        if self.client_secret is None or client_secret is None:
            return False
        # compare_digest only accepts ASCII-only str; comparing UTF-8 bytes
        # keeps a non-ASCII secret sent by a client from raising TypeError.
        return secrets.compare_digest(
            self.client_secret.encode('utf-8'),
            client_secret.encode('utf-8'),
        )