venv added, updated
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
from .client_mixin import OAuth2ClientMixin
|
||||
from .tokens_mixins import OAuth2AuthorizationCodeMixin, OAuth2TokenMixin
|
||||
from .functions import (
|
||||
create_query_client_func,
|
||||
create_save_token_func,
|
||||
create_query_token_func,
|
||||
create_revocation_endpoint,
|
||||
create_bearer_token_validator,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuth2ClientMixin', 'OAuth2AuthorizationCodeMixin', 'OAuth2TokenMixin',
|
||||
'create_query_client_func', 'create_save_token_func',
|
||||
'create_query_token_func', 'create_revocation_endpoint',
|
||||
'create_bearer_token_validator',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,138 @@
|
||||
import secrets
|
||||
|
||||
from sqlalchemy import Column, String, Text, Integer
|
||||
from authlib.common.encoding import json_loads, json_dumps
|
||||
from authlib.oauth2.rfc6749 import ClientMixin
|
||||
from authlib.oauth2.rfc6749 import scope_to_list, list_to_scope
|
||||
|
||||
|
||||
class OAuth2ClientMixin(ClientMixin):
|
||||
client_id = Column(String(48), index=True)
|
||||
client_secret = Column(String(120))
|
||||
client_id_issued_at = Column(Integer, nullable=False, default=0)
|
||||
client_secret_expires_at = Column(Integer, nullable=False, default=0)
|
||||
_client_metadata = Column('client_metadata', Text)
|
||||
|
||||
@property
|
||||
def client_info(self):
|
||||
"""Implementation for Client Info in OAuth 2.0 Dynamic Client
|
||||
Registration Protocol via `Section 3.2.1`_.
|
||||
|
||||
.. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
|
||||
"""
|
||||
return dict(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
client_id_issued_at=self.client_id_issued_at,
|
||||
client_secret_expires_at=self.client_secret_expires_at,
|
||||
)
|
||||
|
||||
@property
|
||||
def client_metadata(self):
|
||||
if 'client_metadata' in self.__dict__:
|
||||
return self.__dict__['client_metadata']
|
||||
if self._client_metadata:
|
||||
data = json_loads(self._client_metadata)
|
||||
self.__dict__['client_metadata'] = data
|
||||
return data
|
||||
return {}
|
||||
|
||||
def set_client_metadata(self, value):
|
||||
self._client_metadata = json_dumps(value)
|
||||
if 'client_metadata' in self.__dict__:
|
||||
del self.__dict__['client_metadata']
|
||||
|
||||
@property
|
||||
def redirect_uris(self):
|
||||
return self.client_metadata.get('redirect_uris', [])
|
||||
|
||||
@property
|
||||
def token_endpoint_auth_method(self):
|
||||
return self.client_metadata.get(
|
||||
'token_endpoint_auth_method',
|
||||
'client_secret_basic'
|
||||
)
|
||||
|
||||
@property
|
||||
def grant_types(self):
|
||||
return self.client_metadata.get('grant_types', [])
|
||||
|
||||
@property
|
||||
def response_types(self):
|
||||
return self.client_metadata.get('response_types', [])
|
||||
|
||||
@property
|
||||
def client_name(self):
|
||||
return self.client_metadata.get('client_name')
|
||||
|
||||
@property
|
||||
def client_uri(self):
|
||||
return self.client_metadata.get('client_uri')
|
||||
|
||||
@property
|
||||
def logo_uri(self):
|
||||
return self.client_metadata.get('logo_uri')
|
||||
|
||||
@property
|
||||
def scope(self):
|
||||
return self.client_metadata.get('scope', '')
|
||||
|
||||
@property
|
||||
def contacts(self):
|
||||
return self.client_metadata.get('contacts', [])
|
||||
|
||||
@property
|
||||
def tos_uri(self):
|
||||
return self.client_metadata.get('tos_uri')
|
||||
|
||||
@property
|
||||
def policy_uri(self):
|
||||
return self.client_metadata.get('policy_uri')
|
||||
|
||||
@property
|
||||
def jwks_uri(self):
|
||||
return self.client_metadata.get('jwks_uri')
|
||||
|
||||
@property
|
||||
def jwks(self):
|
||||
return self.client_metadata.get('jwks', [])
|
||||
|
||||
@property
|
||||
def software_id(self):
|
||||
return self.client_metadata.get('software_id')
|
||||
|
||||
@property
|
||||
def software_version(self):
|
||||
return self.client_metadata.get('software_version')
|
||||
|
||||
def get_client_id(self):
|
||||
return self.client_id
|
||||
|
||||
def get_default_redirect_uri(self):
|
||||
if self.redirect_uris:
|
||||
return self.redirect_uris[0]
|
||||
|
||||
def get_allowed_scope(self, scope):
|
||||
if not scope:
|
||||
return ''
|
||||
allowed = set(self.scope.split())
|
||||
scopes = scope_to_list(scope)
|
||||
return list_to_scope([s for s in scopes if s in allowed])
|
||||
|
||||
def check_redirect_uri(self, redirect_uri):
|
||||
return redirect_uri in self.redirect_uris
|
||||
|
||||
def check_client_secret(self, client_secret):
|
||||
return secrets.compare_digest(self.client_secret, client_secret)
|
||||
|
||||
def check_endpoint_auth_method(self, method, endpoint):
|
||||
if endpoint == 'token':
|
||||
return self.token_endpoint_auth_method == method
|
||||
# TODO
|
||||
return True
|
||||
|
||||
def check_response_type(self, response_type):
|
||||
return response_type in self.response_types
|
||||
|
||||
def check_grant_type(self, grant_type):
|
||||
return grant_type in self.grant_types
|
||||
@@ -0,0 +1,101 @@
|
||||
import time
|
||||
|
||||
|
||||
def create_query_client_func(session, client_model):
|
||||
"""Create an ``query_client`` function that can be used in authorization
|
||||
server.
|
||||
|
||||
:param session: SQLAlchemy session
|
||||
:param client_model: Client model class
|
||||
"""
|
||||
def query_client(client_id):
|
||||
q = session.query(client_model)
|
||||
return q.filter_by(client_id=client_id).first()
|
||||
return query_client
|
||||
|
||||
|
||||
def create_save_token_func(session, token_model):
|
||||
"""Create an ``save_token`` function that can be used in authorization
|
||||
server.
|
||||
|
||||
:param session: SQLAlchemy session
|
||||
:param token_model: Token model class
|
||||
"""
|
||||
def save_token(token, request):
|
||||
if request.user:
|
||||
user_id = request.user.get_user_id()
|
||||
else:
|
||||
user_id = None
|
||||
client = request.client
|
||||
item = token_model(
|
||||
client_id=client.client_id,
|
||||
user_id=user_id,
|
||||
**token
|
||||
)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
return save_token
|
||||
|
||||
|
||||
def create_query_token_func(session, token_model):
|
||||
"""Create an ``query_token`` function for revocation, introspection
|
||||
token endpoints.
|
||||
|
||||
:param session: SQLAlchemy session
|
||||
:param token_model: Token model class
|
||||
"""
|
||||
def query_token(token, token_type_hint):
|
||||
q = session.query(token_model)
|
||||
if token_type_hint == 'access_token':
|
||||
return q.filter_by(access_token=token).first()
|
||||
elif token_type_hint == 'refresh_token':
|
||||
return q.filter_by(refresh_token=token).first()
|
||||
# without token_type_hint
|
||||
item = q.filter_by(access_token=token).first()
|
||||
if item:
|
||||
return item
|
||||
return q.filter_by(refresh_token=token).first()
|
||||
return query_token
|
||||
|
||||
|
||||
def create_revocation_endpoint(session, token_model):
|
||||
"""Create a revocation endpoint class with SQLAlchemy session
|
||||
and token model.
|
||||
|
||||
:param session: SQLAlchemy session
|
||||
:param token_model: Token model class
|
||||
"""
|
||||
from authlib.oauth2.rfc7009 import RevocationEndpoint
|
||||
query_token = create_query_token_func(session, token_model)
|
||||
|
||||
class _RevocationEndpoint(RevocationEndpoint):
|
||||
def query_token(self, token, token_type_hint):
|
||||
return query_token(token, token_type_hint)
|
||||
|
||||
def revoke_token(self, token, request):
|
||||
now = int(time.time())
|
||||
hint = request.form.get('token_type_hint')
|
||||
token.access_token_revoked_at = now
|
||||
if hint != 'access_token':
|
||||
token.refresh_token_revoked_at = now
|
||||
session.add(token)
|
||||
session.commit()
|
||||
|
||||
return _RevocationEndpoint
|
||||
|
||||
|
||||
def create_bearer_token_validator(session, token_model):
|
||||
"""Create an bearer token validator class with SQLAlchemy session
|
||||
and token model.
|
||||
|
||||
:param session: SQLAlchemy session
|
||||
:param token_model: Token model class
|
||||
"""
|
||||
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
||||
|
||||
class _BearerTokenValidator(BearerTokenValidator):
|
||||
def authenticate_token(self, token_string):
|
||||
q = session.query(token_model)
|
||||
return q.filter_by(access_token=token_string).first()
|
||||
|
||||
return _BearerTokenValidator
|
||||
@@ -0,0 +1,70 @@
|
||||
import time
|
||||
from sqlalchemy import Column, String, Text, Integer
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
TokenMixin,
|
||||
AuthorizationCodeMixin,
|
||||
)
|
||||
|
||||
|
||||
class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
|
||||
code = Column(String(120), unique=True, nullable=False)
|
||||
client_id = Column(String(48))
|
||||
redirect_uri = Column(Text, default='')
|
||||
response_type = Column(Text, default='')
|
||||
scope = Column(Text, default='')
|
||||
nonce = Column(Text)
|
||||
auth_time = Column(
|
||||
Integer, nullable=False,
|
||||
default=lambda: int(time.time())
|
||||
)
|
||||
|
||||
code_challenge = Column(Text)
|
||||
code_challenge_method = Column(String(48))
|
||||
|
||||
def is_expired(self):
|
||||
return self.auth_time + 300 < time.time()
|
||||
|
||||
def get_redirect_uri(self):
|
||||
return self.redirect_uri
|
||||
|
||||
def get_scope(self):
|
||||
return self.scope
|
||||
|
||||
def get_auth_time(self):
|
||||
return self.auth_time
|
||||
|
||||
def get_nonce(self):
|
||||
return self.nonce
|
||||
|
||||
|
||||
class OAuth2TokenMixin(TokenMixin):
|
||||
client_id = Column(String(48))
|
||||
token_type = Column(String(40))
|
||||
access_token = Column(String(255), unique=True, nullable=False)
|
||||
refresh_token = Column(String(255), index=True)
|
||||
scope = Column(Text, default='')
|
||||
issued_at = Column(
|
||||
Integer, nullable=False, default=lambda: int(time.time())
|
||||
)
|
||||
access_token_revoked_at = Column(Integer, nullable=False, default=0)
|
||||
refresh_token_revoked_at = Column(Integer, nullable=False, default=0)
|
||||
expires_in = Column(Integer, nullable=False, default=0)
|
||||
|
||||
def check_client(self, client):
|
||||
return self.client_id == client.get_client_id()
|
||||
|
||||
def get_scope(self):
|
||||
return self.scope
|
||||
|
||||
def get_expires_in(self):
|
||||
return self.expires_in
|
||||
|
||||
def is_revoked(self):
|
||||
return self.access_token_revoked_at or self.refresh_token_revoked_at
|
||||
|
||||
def is_expired(self):
|
||||
if not self.expires_in:
|
||||
return False
|
||||
|
||||
expires_at = self.issued_at + self.expires_in
|
||||
return expires_at < time.time()
|
||||
Reference in New Issue
Block a user