venv added, updated

This commit is contained in:
Norbert
2024-09-13 09:46:28 +02:00
parent 577596d9f3
commit 82af8c809a
4812 changed files with 640223 additions and 2 deletions

View File

@@ -0,0 +1,18 @@
from .registry import BaseOAuth
from .sync_app import BaseApp, OAuth1Mixin, OAuth2Mixin
from .sync_openid import OpenIDMixin
from .framework_integration import FrameworkIntegration
from .errors import (
OAuthError, MissingRequestTokenError, MissingTokenError,
TokenExpiredError, InvalidTokenError, UnsupportedTokenTypeError,
MismatchingStateError,
)
__all__ = [
'BaseOAuth',
'BaseApp', 'OAuth1Mixin', 'OAuth2Mixin',
'OpenIDMixin', 'FrameworkIntegration',
'OAuthError', 'MissingRequestTokenError', 'MissingTokenError',
'TokenExpiredError', 'InvalidTokenError', 'UnsupportedTokenTypeError',
'MismatchingStateError',
]

View File

@@ -0,0 +1,144 @@
import time
import logging
from authlib.common.urls import urlparse
from .errors import (
MissingRequestTokenError,
MissingTokenError,
)
from .sync_app import OAuth1Base, OAuth2Base
log = logging.getLogger(__name__)
__all__ = ['AsyncOAuth1Mixin', 'AsyncOAuth2Mixin']
class AsyncOAuth1Mixin(OAuth1Base):
async def request(self, method, url, token=None, **kwargs):
async with self._get_oauth_client() as session:
return await _http_request(self, session, method, url, token, kwargs)
async def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
if not self.authorize_url:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
async with self._get_oauth_client() as client:
client.redirect_uri = redirect_uri
params = {}
if self.request_token_params:
params.update(self.request_token_params)
request_token = await client.fetch_request_token(self.request_token_url, **params)
log.debug(f'Fetch request token: {request_token!r}')
url = client.create_authorization_url(self.authorize_url, **kwargs)
state = request_token['oauth_token']
return {'url': url, 'request_token': request_token, 'state': state}
async def fetch_access_token(self, request_token=None, **kwargs):
"""Fetch access token in one step.
:param request_token: A previous request token for OAuth 1.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
async with self._get_oauth_client() as client:
if request_token is None:
raise MissingRequestTokenError()
# merge request token with verifier
token = {}
token.update(request_token)
token.update(kwargs)
client.token = token
params = self.access_token_params or {}
token = await client.fetch_access_token(self.access_token_url, **params)
return token
class AsyncOAuth2Mixin(OAuth2Base):
async def _on_update_token(self, token, refresh_token=None, access_token=None):
if self._update_token:
await self._update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
async def load_server_metadata(self):
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
async with self.client_cls(**self.client_kwargs) as client:
resp = await client.request('GET', self._server_metadata_url, withhold_token=True)
resp.raise_for_status()
metadata = resp.json()
metadata['_loaded_at'] = time.time()
self.server_metadata.update(metadata)
return self.server_metadata
async def request(self, method, url, token=None, **kwargs):
metadata = await self.load_server_metadata()
async with self._get_oauth_client(**metadata) as session:
return await _http_request(self, session, method, url, token, kwargs)
async def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
metadata = await self.load_server_metadata()
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
if not authorization_endpoint:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
async with self._get_oauth_client(**metadata) as client:
client.redirect_uri = redirect_uri
return self._create_oauth2_authorization_url(
client, authorization_endpoint, **kwargs)
async def fetch_access_token(self, redirect_uri=None, **kwargs):
"""Fetch access token in the final step.
:param redirect_uri: Callback or Redirect URI that is used in
previous :meth:`authorize_redirect`.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
metadata = await self.load_server_metadata()
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
async with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
params = {}
if self.access_token_params:
params.update(self.access_token_params)
params.update(kwargs)
token = await client.fetch_token(token_endpoint, **params)
return token
async def _http_request(ctx, session, method, url, token, kwargs):
request = kwargs.pop('request', None)
withhold_token = kwargs.get('withhold_token')
if ctx.api_base_url and not url.startswith(('https://', 'http://')):
url = urlparse.urljoin(ctx.api_base_url, url)
if withhold_token:
return await session.request(method, url, **kwargs)
if token is None and ctx._fetch_token and request:
token = await ctx._fetch_token(request)
if token is None:
raise MissingTokenError()
session.token = token
return await session.request(method, url, **kwargs)

View File

@@ -0,0 +1,79 @@
from authlib.jose import JsonWebToken, JsonWebKey
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
__all__ = ['AsyncOpenIDMixin']
class AsyncOpenIDMixin:
async def fetch_jwk_set(self, force=False):
metadata = await self.load_server_metadata()
jwk_set = metadata.get('jwks')
if jwk_set and not force:
return jwk_set
uri = metadata.get('jwks_uri')
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
async with self.client_cls(**self.client_kwargs) as client:
resp = await client.request('GET', uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()
self.server_metadata['jwks'] = jwk_set
return jwk_set
async def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
metadata = await self.load_server_metadata()
resp = await self.get(metadata['userinfo_endpoint'], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
async def parse_id_token(self, token, nonce, claims_options=None):
"""Return an instance of UserInfo from token's ``id_token``."""
claims_params = dict(
nonce=nonce,
client_id=self.client_id,
)
if 'access_token' in token:
claims_params['access_token'] = token['access_token']
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
metadata = await self.load_server_metadata()
if claims_options is None and 'issuer' in metadata:
claims_options = {'iss': {'values': [metadata['issuer']]}}
alg_values = metadata.get('id_token_signing_alg_values_supported')
if not alg_values:
alg_values = ['RS256']
jwt = JsonWebToken(alg_values)
jwk_set = await self.fetch_jwk_set()
try:
claims = jwt.decode(
token['id_token'],
key=JsonWebKey.import_key_set(jwk_set),
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
except ValueError:
jwk_set = await self.fetch_jwk_set(force=True)
claims = jwt.decode(
token['id_token'],
key=JsonWebKey.import_key_set(jwk_set),
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
# https://github.com/lepture/authlib/issues/259
if claims.get('nonce_supported') is False:
claims.params['nonce'] = None
claims.validate(leeway=120)
return UserInfo(claims)

View File

@@ -0,0 +1,30 @@
from authlib.common.errors import AuthlibBaseError
class OAuthError(AuthlibBaseError):
error = 'oauth_error'
class MissingRequestTokenError(OAuthError):
error = 'missing_request_token'
class MissingTokenError(OAuthError):
error = 'missing_token'
class TokenExpiredError(OAuthError):
error = 'token_expired'
class InvalidTokenError(OAuthError):
error = 'token_invalid'
class UnsupportedTokenTypeError(OAuthError):
error = 'unsupported_token_type'
class MismatchingStateError(OAuthError):
error = 'mismatching_state'
description = 'CSRF Warning! State not equal in request and response.'

View File

@@ -0,0 +1,64 @@
import json
import time
class FrameworkIntegration:
expires_in = 3600
def __init__(self, name, cache=None):
self.name = name
self.cache = cache
def _get_cache_data(self, key):
value = self.cache.get(key)
if not value:
return None
try:
return json.loads(value)
except (TypeError, ValueError):
return None
def _clear_session_state(self, session):
now = time.time()
for key in dict(session):
if '_authlib_' in key:
# TODO: remove in future
session.pop(key)
elif key.startswith('_state_'):
value = session[key]
exp = value.get('exp')
if not exp or exp < now:
session.pop(key)
def get_state_data(self, session, state):
key = f'_state_{self.name}_{state}'
if self.cache:
value = self._get_cache_data(key)
else:
value = session.get(key)
if value:
return value.get('data')
return None
def set_state_data(self, session, state, data):
key = f'_state_{self.name}_{state}'
if self.cache:
self.cache.set(key, json.dumps({'data': data}), self.expires_in)
else:
now = time.time()
session[key] = {'data': data, 'exp': now + self.expires_in}
def clear_state_data(self, session, state):
key = f'_state_{self.name}_{state}'
if self.cache:
self.cache.delete(key)
else:
session.pop(key, None)
self._clear_session_state(session)
def update_token(self, token, refresh_token=None, access_token=None):
raise NotImplementedError()
@staticmethod
def load_config(oauth, name, params):
raise NotImplementedError()

View File

@@ -0,0 +1,131 @@
import functools
from .framework_integration import FrameworkIntegration
__all__ = ['BaseOAuth']
OAUTH_CLIENT_PARAMS = (
'client_id', 'client_secret',
'request_token_url', 'request_token_params',
'access_token_url', 'access_token_params',
'refresh_token_url', 'refresh_token_params',
'authorize_url', 'authorize_params',
'api_base_url', 'client_kwargs',
'server_metadata_url',
)
class BaseOAuth:
"""Registry for oauth clients.
Create an instance for registry::
oauth = OAuth()
"""
oauth1_client_cls = None
oauth2_client_cls = None
framework_integration_cls = FrameworkIntegration
def __init__(self, cache=None, fetch_token=None, update_token=None):
self._registry = {}
self._clients = {}
self.cache = cache
self.fetch_token = fetch_token
self.update_token = update_token
def create_client(self, name):
"""Create or get the given named OAuth client. For instance, the
OAuth registry has ``.register`` a twitter client, developers may
access the client with::
client = oauth.create_client('twitter')
:param: name: Name of the remote application
:return: OAuth remote app
"""
if name in self._clients:
return self._clients[name]
if name not in self._registry:
return None
overwrite, config = self._registry[name]
client_cls = config.pop('client_cls', None)
if client_cls and client_cls.OAUTH_APP_CONFIG:
kwargs = client_cls.OAUTH_APP_CONFIG
kwargs.update(config)
else:
kwargs = config
kwargs = self.generate_client_kwargs(name, overwrite, **kwargs)
framework = self.framework_integration_cls(name, self.cache)
if client_cls:
client = client_cls(framework, name, **kwargs)
elif kwargs.get('request_token_url'):
client = self.oauth1_client_cls(framework, name, **kwargs)
else:
client = self.oauth2_client_cls(framework, name, **kwargs)
self._clients[name] = client
return client
def register(self, name, overwrite=False, **kwargs):
"""Registers a new remote application.
:param name: Name of the remote application.
:param overwrite: Overwrite existing config with framework settings.
:param kwargs: Parameters for :class:`RemoteApp`.
Find parameters for the given remote app class. When a remote app is
registered, it can be accessed with *named* attribute::
oauth.register('twitter', client_id='', ...)
oauth.twitter.get('timeline')
"""
self._registry[name] = (overwrite, kwargs)
return self.create_client(name)
def generate_client_kwargs(self, name, overwrite, **kwargs):
fetch_token = kwargs.pop('fetch_token', None)
update_token = kwargs.pop('update_token', None)
config = self.load_config(name, OAUTH_CLIENT_PARAMS)
if config:
kwargs = _config_client(config, kwargs, overwrite)
if not fetch_token and self.fetch_token:
fetch_token = functools.partial(self.fetch_token, name)
kwargs['fetch_token'] = fetch_token
if not kwargs.get('request_token_url'):
if not update_token and self.update_token:
update_token = functools.partial(self.update_token, name)
kwargs['update_token'] = update_token
return kwargs
def load_config(self, name, params):
return self.framework_integration_cls.load_config(self, name, params)
def __getattr__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError:
if key in self._registry:
return self.create_client(key)
raise AttributeError('No such client: %s' % key)
def _config_client(config, kwargs, overwrite):
for k in OAUTH_CLIENT_PARAMS:
v = config.get(k, None)
if k not in kwargs:
kwargs[k] = v
elif overwrite and v:
if isinstance(kwargs[k], dict):
kwargs[k].update(v)
else:
kwargs[k] = v
return kwargs

View File

@@ -0,0 +1,348 @@
import time
import logging
from authlib.common.urls import urlparse
from authlib.consts import default_user_agent
from authlib.common.security import generate_token
from .errors import (
MismatchingStateError,
MissingRequestTokenError,
MissingTokenError,
)
log = logging.getLogger(__name__)
class BaseApp:
client_cls = None
OAUTH_APP_CONFIG = None
def request(self, method, url, token=None, **kwargs):
raise NotImplementedError()
def get(self, url, **kwargs):
"""Invoke GET http request.
If ``api_base_url`` configured, shortcut is available::
client.get('users/lepture')
"""
return self.request('GET', url, **kwargs)
def post(self, url, **kwargs):
"""Invoke POST http request.
If ``api_base_url`` configured, shortcut is available::
client.post('timeline', json={'text': 'Hi'})
"""
return self.request('POST', url, **kwargs)
def patch(self, url, **kwargs):
"""Invoke PATCH http request.
If ``api_base_url`` configured, shortcut is available::
client.patch('profile', json={'name': 'Hsiaoming Yang'})
"""
return self.request('PATCH', url, **kwargs)
def put(self, url, **kwargs):
"""Invoke PUT http request.
If ``api_base_url`` configured, shortcut is available::
client.put('profile', json={'name': 'Hsiaoming Yang'})
"""
return self.request('PUT', url, **kwargs)
def delete(self, url, **kwargs):
"""Invoke DELETE http request.
If ``api_base_url`` configured, shortcut is available::
client.delete('posts/123')
"""
return self.request('DELETE', url, **kwargs)
class _RequestMixin:
def _get_requested_token(self, request):
if self._fetch_token and request:
return self._fetch_token(request)
def _send_token_request(self, session, method, url, token, kwargs):
request = kwargs.pop('request', None)
withhold_token = kwargs.get('withhold_token')
if self.api_base_url and not url.startswith(('https://', 'http://')):
url = urlparse.urljoin(self.api_base_url, url)
if withhold_token:
return session.request(method, url, **kwargs)
if token is None:
token = self._get_requested_token(request)
if token is None:
raise MissingTokenError()
session.token = token
return session.request(method, url, **kwargs)
class OAuth1Base:
client_cls = None
def __init__(
self, framework, name=None, fetch_token=None,
client_id=None, client_secret=None,
request_token_url=None, request_token_params=None,
access_token_url=None, access_token_params=None,
authorize_url=None, authorize_params=None,
api_base_url=None, client_kwargs=None, user_agent=None, **kwargs):
self.framework = framework
self.name = name
self.client_id = client_id
self.client_secret = client_secret
self.request_token_url = request_token_url
self.request_token_params = request_token_params
self.access_token_url = access_token_url
self.access_token_params = access_token_params
self.authorize_url = authorize_url
self.authorize_params = authorize_params
self.api_base_url = api_base_url
self.client_kwargs = client_kwargs or {}
self._fetch_token = fetch_token
self._user_agent = user_agent or default_user_agent
self._kwargs = kwargs
def _get_oauth_client(self):
session = self.client_cls(self.client_id, self.client_secret, **self.client_kwargs)
session.headers['User-Agent'] = self._user_agent
return session
class OAuth1Mixin(_RequestMixin, OAuth1Base):
def request(self, method, url, token=None, **kwargs):
with self._get_oauth_client() as session:
return self._send_token_request(session, method, url, token, kwargs)
def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
if not self.authorize_url:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
with self._get_oauth_client() as client:
client.redirect_uri = redirect_uri
params = self.request_token_params or {}
request_token = client.fetch_request_token(self.request_token_url, **params)
log.debug(f'Fetch request token: {request_token!r}')
url = client.create_authorization_url(self.authorize_url, **kwargs)
state = request_token['oauth_token']
return {'url': url, 'request_token': request_token, 'state': state}
def fetch_access_token(self, request_token=None, **kwargs):
"""Fetch access token in one step.
:param request_token: A previous request token for OAuth 1.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
with self._get_oauth_client() as client:
if request_token is None:
raise MissingRequestTokenError()
# merge request token with verifier
token = {}
token.update(request_token)
token.update(kwargs)
client.token = token
params = self.access_token_params or {}
token = client.fetch_access_token(self.access_token_url, **params)
return token
class OAuth2Base:
client_cls = None
def __init__(
self, framework, name=None, fetch_token=None, update_token=None,
client_id=None, client_secret=None,
access_token_url=None, access_token_params=None,
authorize_url=None, authorize_params=None,
api_base_url=None, client_kwargs=None, server_metadata_url=None,
compliance_fix=None, client_auth_methods=None, user_agent=None, **kwargs):
self.framework = framework
self.name = name
self.client_id = client_id
self.client_secret = client_secret
self.access_token_url = access_token_url
self.access_token_params = access_token_params
self.authorize_url = authorize_url
self.authorize_params = authorize_params
self.api_base_url = api_base_url
self.client_kwargs = client_kwargs or {}
self.compliance_fix = compliance_fix
self.client_auth_methods = client_auth_methods
self._fetch_token = fetch_token
self._update_token = update_token
self._user_agent = user_agent or default_user_agent
self._server_metadata_url = server_metadata_url
self.server_metadata = kwargs
def _on_update_token(self, token, refresh_token=None, access_token=None):
raise NotImplementedError()
def _get_oauth_client(self, **metadata):
client_kwargs = {}
client_kwargs.update(self.client_kwargs)
client_kwargs.update(metadata)
if self.authorize_url:
client_kwargs['authorization_endpoint'] = self.authorize_url
if self.access_token_url:
client_kwargs['token_endpoint'] = self.access_token_url
session = self.client_cls(
client_id=self.client_id,
client_secret=self.client_secret,
update_token=self._on_update_token,
**client_kwargs
)
if self.client_auth_methods:
for f in self.client_auth_methods:
session.register_client_auth_method(f)
if self.compliance_fix:
self.compliance_fix(session)
session.headers['User-Agent'] = self._user_agent
return session
@staticmethod
def _format_state_params(state_data, params):
if state_data is None:
raise MismatchingStateError()
code_verifier = state_data.get('code_verifier')
if code_verifier:
params['code_verifier'] = code_verifier
redirect_uri = state_data.get('redirect_uri')
if redirect_uri:
params['redirect_uri'] = redirect_uri
return params
@staticmethod
def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs):
rv = {}
if client.code_challenge_method:
code_verifier = kwargs.get('code_verifier')
if not code_verifier:
code_verifier = generate_token(48)
kwargs['code_verifier'] = code_verifier
rv['code_verifier'] = code_verifier
log.debug(f'Using code_verifier: {code_verifier!r}')
scope = kwargs.get('scope', client.scope)
scope = (
(scope if isinstance(scope, (list, tuple)) else scope.split())
if scope
else None
)
if scope and "openid" in scope:
# this is an OpenID Connect service
nonce = kwargs.get('nonce')
if not nonce:
nonce = generate_token(20)
kwargs['nonce'] = nonce
rv['nonce'] = nonce
url, state = client.create_authorization_url(
authorization_endpoint, **kwargs)
rv['url'] = url
rv['state'] = state
return rv
class OAuth2Mixin(_RequestMixin, OAuth2Base):
def _on_update_token(self, token, refresh_token=None, access_token=None):
if callable(self._update_token):
self._update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
self.framework.update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
def request(self, method, url, token=None, **kwargs):
metadata = self.load_server_metadata()
with self._get_oauth_client(**metadata) as session:
return self._send_token_request(session, method, url, token, kwargs)
def load_server_metadata(self):
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
with self.client_cls(**self.client_kwargs) as session:
resp = session.request('GET', self._server_metadata_url, withhold_token=True)
resp.raise_for_status()
metadata = resp.json()
metadata['_loaded_at'] = time.time()
self.server_metadata.update(metadata)
return self.server_metadata
def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
metadata = self.load_server_metadata()
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
if not authorization_endpoint:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
return self._create_oauth2_authorization_url(
client, authorization_endpoint, **kwargs)
def fetch_access_token(self, redirect_uri=None, **kwargs):
"""Fetch access token in the final step.
:param redirect_uri: Callback or Redirect URI that is used in
previous :meth:`authorize_redirect`.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
metadata = self.load_server_metadata()
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
params = {}
if self.access_token_params:
params.update(self.access_token_params)
params.update(kwargs)
token = client.fetch_token(token_endpoint, **params)
return token

View File

@@ -0,0 +1,82 @@
from authlib.jose import jwt, JsonWebToken, JsonWebKey
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
class OpenIDMixin:
def fetch_jwk_set(self, force=False):
metadata = self.load_server_metadata()
jwk_set = metadata.get('jwks')
if jwk_set and not force:
return jwk_set
uri = metadata.get('jwks_uri')
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
with self.client_cls(**self.client_kwargs) as session:
resp = session.request('GET', uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()
self.server_metadata['jwks'] = jwk_set
return jwk_set
def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
metadata = self.load_server_metadata()
resp = self.get(metadata['userinfo_endpoint'], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
"""Return an instance of UserInfo from token's ``id_token``."""
if 'id_token' not in token:
return None
load_key = self.create_load_key()
claims_params = dict(
nonce=nonce,
client_id=self.client_id,
)
if 'access_token' in token:
claims_params['access_token'] = token['access_token']
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
metadata = self.load_server_metadata()
if claims_options is None and 'issuer' in metadata:
claims_options = {'iss': {'values': [metadata['issuer']]}}
alg_values = metadata.get('id_token_signing_alg_values_supported')
if alg_values:
_jwt = JsonWebToken(alg_values)
else:
_jwt = jwt
claims = _jwt.decode(
token['id_token'], key=load_key,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
# https://github.com/lepture/authlib/issues/259
if claims.get('nonce_supported') is False:
claims.params['nonce'] = None
claims.validate(leeway=leeway)
return UserInfo(claims)
def create_load_key(self):
def load_key(header, _):
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
try:
return jwk_set.find_by_kid(header.get('kid'))
except ValueError:
# re-try with new jwk set
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
return jwk_set.find_by_kid(header.get('kid'))
return load_key

View File

@@ -0,0 +1,19 @@
# flake8: noqa
from .integration import DjangoIntegration, token_update
from .apps import DjangoOAuth1App, DjangoOAuth2App
from ..base_client import BaseOAuth, OAuthError
class OAuth(BaseOAuth):
oauth1_client_cls = DjangoOAuth1App
oauth2_client_cls = DjangoOAuth2App
framework_integration_cls = DjangoIntegration
__all__ = [
'OAuth',
'DjangoOAuth1App', 'DjangoOAuth2App',
'DjangoIntegration',
'token_update', 'OAuthError',
]

View File

@@ -0,0 +1,87 @@
from django.http import HttpResponseRedirect
from ..requests_client import OAuth1Session, OAuth2Session
from ..base_client import (
BaseApp, OAuthError,
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
)
class DjangoAppMixin:
def save_authorize_data(self, request, **kwargs):
state = kwargs.pop('state', None)
if state:
self.framework.set_state_data(request.session, state, kwargs)
else:
raise RuntimeError('Missing state value')
def authorize_redirect(self, request, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param request: HTTP request instance from Django view.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
rv = self.create_authorization_url(redirect_uri, **kwargs)
self.save_authorize_data(request, redirect_uri=redirect_uri, **rv)
return HttpResponseRedirect(rv['url'])
class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp):
client_cls = OAuth1Session
def authorize_access_token(self, request, **kwargs):
"""Fetch access token in one step.
:param request: HTTP request instance from Django view.
:return: A token dict.
"""
params = request.GET.dict()
state = params.get('oauth_token')
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = self.framework.get_state_data(request.session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params['request_token'] = data['request_token']
params.update(kwargs)
self.framework.clear_state_data(request.session, state)
return self.fetch_access_token(**params)
class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
client_cls = OAuth2Session
def authorize_access_token(self, request, **kwargs):
"""Fetch access token in one step.
:param request: HTTP request instance from Django view.
:return: A token dict.
"""
if request.method == 'GET':
error = request.GET.get('error')
if error:
description = request.GET.get('error_description')
raise OAuthError(error=error, description=description)
params = {
'code': request.GET.get('code'),
'state': request.GET.get('state'),
}
else:
params = {
'code': request.POST.get('code'),
'state': request.POST.get('state'),
}
claims_options = kwargs.pop('claims_options', None)
state_data = self.framework.get_state_data(request.session, params.get('state'))
self.framework.clear_state_data(request.session, params.get('state'))
params = self._format_state_params(state_data, params)
token = self.fetch_access_token(**params, **kwargs)
if 'id_token' in token and 'nonce' in state_data:
userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
token['userinfo'] = userinfo
return token

View File

@@ -0,0 +1,22 @@
from django.conf import settings
from django.dispatch import Signal
from ..base_client import FrameworkIntegration
token_update = Signal()
class DjangoIntegration(FrameworkIntegration):
def update_token(self, token, refresh_token=None, access_token=None):
token_update.send(
sender=self.__class__,
name=self.name,
token=token,
refresh_token=refresh_token,
access_token=access_token,
)
@staticmethod
def load_config(oauth, name, params):
config = getattr(settings, 'AUTHLIB_OAUTH_CLIENTS', None)
if config:
return config.get(name)

View File

@@ -0,0 +1,9 @@
# flake8: noqa
from .authorization_server import (
BaseServer, CacheAuthorizationServer
)
from .resource_protector import ResourceProtector
__all__ = ['BaseServer', 'CacheAuthorizationServer', 'ResourceProtector']

View File

@@ -0,0 +1,125 @@
import logging
from authlib.oauth1 import (
OAuth1Request,
AuthorizationServer as _AuthorizationServer,
)
from authlib.oauth1 import TemporaryCredential
from authlib.common.security import generate_token
from authlib.common.urls import url_encode
from django.core.cache import cache
from django.conf import settings
from django.http import HttpResponse
from .nonce import exists_nonce_in_cache
log = logging.getLogger(__name__)
class BaseServer(_AuthorizationServer):
def __init__(self, client_model, token_model, token_generator=None):
self.client_model = client_model
self.token_model = token_model
if token_generator is None:
def token_generator():
return {
'oauth_token': generate_token(42),
'oauth_token_secret': generate_token(48)
}
self.token_generator = token_generator
self._config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {})
self._nonce_expires_in = self._config.get('nonce_expires_in', 86400)
methods = self._config.get('signature_methods')
if methods:
self.SUPPORTED_SIGNATURE_METHODS = methods
def get_client_by_id(self, client_id):
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def exists_nonce(self, nonce, request):
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
def create_token_credential(self, request):
temporary_credential = request.credential
token = self.token_generator()
item = self.token_model(
oauth_token=token['oauth_token'],
oauth_token_secret=token['oauth_token_secret'],
user_id=temporary_credential.get_user_id(),
client_id=temporary_credential.get_client_id()
)
item.save()
return item
def check_authorization_request(self, request):
req = self.create_oauth1_request(request)
self.validate_authorization_request(req)
return req
def create_oauth1_request(self, request):
if request.method == 'POST':
body = request.POST.dict()
else:
body = None
url = request.build_absolute_uri()
return OAuth1Request(request.method, url, body, request.headers)
def handle_response(self, status_code, payload, headers):
resp = HttpResponse(url_encode(payload), status=status_code)
for k, v in headers:
resp[k] = v
return resp
class CacheAuthorizationServer(BaseServer):
def __init__(self, client_model, token_model, token_generator=None):
super().__init__(
client_model, token_model, token_generator)
self._temporary_expires_in = self._config.get(
'temporary_credential_expires_in', 86400)
self._temporary_credential_key_prefix = self._config.get(
'temporary_credential_key_prefix', 'temporary_credential:')
def create_temporary_credential(self, request):
key_prefix = self._temporary_credential_key_prefix
token = self.token_generator()
client_id = request.client_id
redirect_uri = request.redirect_uri
key = key_prefix + token['oauth_token']
token['client_id'] = client_id
if redirect_uri:
token['oauth_callback'] = redirect_uri
cache.set(key, token, timeout=self._temporary_expires_in)
return TemporaryCredential(token)
def get_temporary_credential(self, request):
if not request.token:
return None
key_prefix = self._temporary_credential_key_prefix
key = key_prefix + request.token
value = cache.get(key)
if value:
return TemporaryCredential(value)
def delete_temporary_credential(self, request):
if request.token:
key_prefix = self._temporary_credential_key_prefix
key = key_prefix + request.token
cache.delete(key)
def create_authorization_verifier(self, request):
key_prefix = self._temporary_credential_key_prefix
verifier = generate_token(36)
credential = request.credential
user = request.user
key = key_prefix + credential.get_oauth_token()
credential['oauth_verifier'] = verifier
credential['user_id'] = user.pk
cache.set(key, credential, timeout=self._temporary_expires_in)
return verifier

View File

@@ -0,0 +1,15 @@
from django.core.cache import cache
def exists_nonce_in_cache(nonce, request, timeout):
key_prefix = 'nonce:'
timestamp = request.timestamp
client_id = request.client_id
token = request.token
key = f'{key_prefix}{nonce}-{timestamp}-{client_id}'
if token:
key = f'{key}-{token}'
rv = bool(cache.get(key))
cache.set(key, 1, timeout=timeout)
return rv

View File

@@ -0,0 +1,64 @@
import functools
from authlib.oauth1.errors import OAuth1Error
from authlib.oauth1 import ResourceProtector as _ResourceProtector
from django.http import JsonResponse
from django.conf import settings
from .nonce import exists_nonce_in_cache
class ResourceProtector(_ResourceProtector):
def __init__(self, client_model, token_model):
self.client_model = client_model
self.token_model = token_model
config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {})
methods = config.get('signature_methods', [])
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self._nonce_expires_in = config.get('nonce_expires_in', 86400)
def get_client_by_id(self, client_id):
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def get_token_credential(self, request):
try:
return self.token_model.objects.get(
client_id=request.client_id,
oauth_token=request.token
)
except self.token_model.DoesNotExist:
return None
def exists_nonce(self, nonce, request):
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
def acquire_credential(self, request):
if request.method in ['POST', 'PUT']:
body = request.POST.dict()
else:
body = None
url = request.build_absolute_uri()
req = self.validate_request(request.method, url, body, request.headers)
return req.credential
def __call__(self, realm=None):
def wrapper(f):
@functools.wraps(f)
def decorated(request, *args, **kwargs):
try:
credential = self.acquire_credential(request)
request.oauth1_credential = credential
except OAuth1Error as error:
body = dict(error.get_body())
resp = JsonResponse(body, status=error.status_code)
resp['Cache-Control'] = 'no-store'
resp['Pragma'] = 'no-cache'
return resp
return f(request, *args, **kwargs)
return decorated
return wrapper

View File

@@ -0,0 +1,10 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import ResourceProtector, BearerTokenValidator
from .endpoints import RevocationEndpoint
from .signals import (
client_authenticated,
token_authenticated,
token_revoked
)

View File

@@ -0,0 +1,118 @@
from django.http import HttpResponse
from django.utils.module_loading import import_string
from django.conf import settings
from authlib.oauth2 import (
AuthorizationServer as _AuthorizationServer,
)
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.common.security import generate_token as _generate_token
from authlib.common.encoding import json_dumps
from .requests import DjangoOAuth2Request, DjangoJsonRequest
from .signals import client_authenticated, token_revoked
class AuthorizationServer(_AuthorizationServer):
"""Django implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
Initialize it with client model and token model::
from authlib.integrations.django_oauth2 import AuthorizationServer
from your_project.models import OAuth2Client, OAuth2Token
server = AuthorizationServer(OAuth2Client, OAuth2Token)
"""
def __init__(self, client_model, token_model):
self.config = getattr(settings, 'AUTHLIB_OAUTH2_PROVIDER', {})
self.client_model = client_model
self.token_model = token_model
scopes_supported = self.config.get('scopes_supported')
super().__init__(scopes_supported=scopes_supported)
# add default token generator
self.register_token_generator('default', self.create_bearer_token_generator())
def query_client(self, client_id):
"""Default method for ``AuthorizationServer.query_client``. Developers MAY
rewrite this function to meet their own needs.
"""
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def save_token(self, token, request):
"""Default method for ``AuthorizationServer.save_token``. Developers MAY
rewrite this function to meet their own needs.
"""
client = request.client
if request.user:
user_id = request.user.pk
else:
user_id = client.user_id
item = self.token_model(
client_id=client.client_id,
user_id=user_id,
**token
)
item.save()
return item
def create_oauth2_request(self, request):
return DjangoOAuth2Request(request)
def create_json_request(self, request):
return DjangoJsonRequest(request)
def handle_response(self, status_code, payload, headers):
if isinstance(payload, dict):
payload = json_dumps(payload)
resp = HttpResponse(payload, status=status_code)
for k, v in headers:
resp[k] = v
return resp
def send_signal(self, name, *args, **kwargs):
if name == 'after_authenticate_client':
client_authenticated.send(sender=self.__class__, *args, **kwargs)
elif name == 'after_revoke_token':
token_revoked.send(sender=self.__class__, *args, **kwargs)
def create_bearer_token_generator(self):
"""Default method to create BearerToken generator."""
conf = self.config.get('access_token_generator', True)
access_token_generator = create_token_generator(conf, 42)
conf = self.config.get('refresh_token_generator', False)
refresh_token_generator = create_token_generator(conf, 48)
conf = self.config.get('token_expires_in')
expires_generator = create_token_expires_in_generator(conf)
return BearerTokenGenerator(
access_token_generator=access_token_generator,
refresh_token_generator=refresh_token_generator,
expires_generator=expires_generator,
)
def create_token_generator(token_generator_conf, length=42):
if callable(token_generator_conf):
return token_generator_conf
if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)
elif token_generator_conf is True:
def token_generator(*args, **kwargs):
return _generate_token(length)
return token_generator
def create_token_expires_in_generator(expires_in_conf=None):
data = {}
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
if expires_in_conf:
data.update(expires_in_conf)
def expires_in(client, grant_type):
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
return expires_in

View File

@@ -0,0 +1,56 @@
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
class RevocationEndpoint(_RevocationEndpoint):
"""The revocation endpoint for OAuth authorization servers allows clients
to notify the authorization server that a previously obtained refresh or
access token is no longer needed.
Register it into authorization server, and create token endpoint response
for token revocation::
from django.views.decorators.http import require_http_methods
# see register into authorization server instance
server.register_endpoint(RevocationEndpoint)
@require_http_methods(["POST"])
def revoke_token(request):
return server.create_endpoint_response(
RevocationEndpoint.ENDPOINT_NAME,
request
)
"""
def query_token(self, token, token_type_hint):
"""Query requested token from database."""
token_model = self.server.token_model
if token_type_hint == 'access_token':
rv = _query_access_token(token_model, token)
elif token_type_hint == 'refresh_token':
rv = _query_refresh_token(token_model, token)
else:
rv = _query_access_token(token_model, token)
if not rv:
rv = _query_refresh_token(token_model, token)
return rv
def revoke_token(self, token, request):
"""Mark the give token as revoked."""
token.revoked = True
token.save()
def _query_access_token(token_model, token):
try:
return token_model.objects.get(access_token=token)
except token_model.DoesNotExist:
return None
def _query_refresh_token(token_model, token):
try:
return token_model.objects.get(refresh_token=token)
except token_model.DoesNotExist:
return None

View File

@@ -0,0 +1,46 @@
from collections import defaultdict
from django.http import HttpRequest
from django.utils.functional import cached_property
from authlib.common.encoding import json_loads
from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest
class DjangoOAuth2Request(OAuth2Request):
def __init__(self, request: HttpRequest):
super().__init__(request.method, request.build_absolute_uri(), None, request.headers)
self._request = request
@property
def args(self):
return self._request.GET
@property
def form(self):
return self._request.POST
@cached_property
def data(self):
data = {}
data.update(self._request.GET.dict())
data.update(self._request.POST.dict())
return data
@cached_property
def datalist(self):
values = defaultdict(list)
for k in self.args:
values[k].extend(self.args.getlist(k))
for k in self.form:
values[k].extend(self.form.getlist(k))
return values
class DjangoJsonRequest(JsonRequest):
def __init__(self, request: HttpRequest):
super().__init__(request.method, request.build_absolute_uri(), None, request.headers)
self._request = request
@cached_property
def data(self):
return json_loads(self._request.body)

View File

@@ -0,0 +1,75 @@
import functools
from django.http import JsonResponse
from authlib.oauth2 import (
OAuth2Error,
ResourceProtector as _ResourceProtector,
)
from authlib.oauth2.rfc6749 import (
MissingAuthorizationError,
)
from authlib.oauth2.rfc6750 import (
BearerTokenValidator as _BearerTokenValidator
)
from .requests import DjangoJsonRequest
from .signals import token_authenticated
class ResourceProtector(_ResourceProtector):
def acquire_token(self, request, scopes=None, **kwargs):
"""A method to acquire current valid token with the given scope.
:param request: Django HTTP request instance
:param scopes: a list of scope values
:return: token object
"""
req = DjangoJsonRequest(request)
# backward compatibility
kwargs['scopes'] = scopes
for claim in kwargs:
if isinstance(kwargs[claim], str):
kwargs[claim] = [kwargs[claim]]
token = self.validate_request(request=req, **kwargs)
token_authenticated.send(sender=self.__class__, token=token)
return token
def __call__(self, scopes=None, optional=False, **kwargs):
claims = kwargs
# backward compatibility
claims['scopes'] = scopes
def wrapper(f):
@functools.wraps(f)
def decorated(request, *args, **kwargs):
try:
token = self.acquire_token(request, **claims)
request.oauth_token = token
except MissingAuthorizationError as error:
if optional:
request.oauth_token = None
return f(request, *args, **kwargs)
return return_error_response(error)
except OAuth2Error as error:
return return_error_response(error)
return f(request, *args, **kwargs)
return decorated
return wrapper
class BearerTokenValidator(_BearerTokenValidator):
def __init__(self, token_model, realm=None, **extra_attributes):
self.token_model = token_model
super().__init__(realm, **extra_attributes)
def authenticate_token(self, token_string):
try:
return self.token_model.objects.get(access_token=token_string)
except self.token_model.DoesNotExist:
return None
def return_error_response(error):
body = dict(error.get_body())
resp = JsonResponse(body, status=error.status_code)
headers = error.get_headers()
for k, v in headers:
resp[k] = v
return resp

View File

@@ -0,0 +1,11 @@
from django.dispatch import Signal
#: signal when client is authenticated
client_authenticated = Signal()
#: signal when token is revoked
token_revoked = Signal()
#: signal when token is authenticated
token_authenticated = Signal()

View File

@@ -0,0 +1,51 @@
from werkzeug.local import LocalProxy
from .integration import FlaskIntegration, token_update
from .apps import FlaskOAuth1App, FlaskOAuth2App
from ..base_client import BaseOAuth, OAuthError
class OAuth(BaseOAuth):
oauth1_client_cls = FlaskOAuth1App
oauth2_client_cls = FlaskOAuth2App
framework_integration_cls = FlaskIntegration
def __init__(self, app=None, cache=None, fetch_token=None, update_token=None):
super().__init__(
cache=cache, fetch_token=fetch_token, update_token=update_token)
self.app = app
if app:
self.init_app(app)
def init_app(self, app, cache=None, fetch_token=None, update_token=None):
"""Initialize lazy for Flask app. This is usually used for Flask application
factory pattern.
"""
self.app = app
if cache is not None:
self.cache = cache
if fetch_token:
self.fetch_token = fetch_token
if update_token:
self.update_token = update_token
app.extensions = getattr(app, 'extensions', {})
app.extensions['authlib.integrations.flask_client'] = self
def create_client(self, name):
if not self.app:
raise RuntimeError('OAuth is not init with Flask app.')
return super().create_client(name)
def register(self, name, overwrite=False, **kwargs):
self._registry[name] = (overwrite, kwargs)
if self.app:
return self.create_client(name)
return LocalProxy(lambda: self.create_client(name))
__all__ = [
'OAuth', 'FlaskIntegration',
'FlaskOAuth1App', 'FlaskOAuth2App',
'token_update', 'OAuthError',
]

View File

@@ -0,0 +1,107 @@
from flask import g, redirect, request, session
from ..requests_client import OAuth1Session, OAuth2Session
from ..base_client import (
BaseApp, OAuthError,
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
)
class FlaskAppMixin:
@property
def token(self):
attr = f'_oauth_token_{self.name}'
token = g.get(attr)
if token:
return token
if self._fetch_token:
token = self._fetch_token()
self.token = token
return token
@token.setter
def token(self, token):
attr = f'_oauth_token_{self.name}'
setattr(g, attr, token)
def _get_requested_token(self, *args, **kwargs):
return self.token
def save_authorize_data(self, **kwargs):
state = kwargs.pop('state', None)
if state:
self.framework.set_state_data(session, state, kwargs)
else:
raise RuntimeError('Missing state value')
def authorize_redirect(self, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
rv = self.create_authorization_url(redirect_uri, **kwargs)
self.save_authorize_data(redirect_uri=redirect_uri, **rv)
return redirect(rv['url'])
class FlaskOAuth1App(FlaskAppMixin, OAuth1Mixin, BaseApp):
client_cls = OAuth1Session
def authorize_access_token(self, **kwargs):
"""Fetch access token in one step.
:return: A token dict.
"""
params = request.args.to_dict(flat=True)
state = params.get('oauth_token')
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = self.framework.get_state_data(session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params['request_token'] = data['request_token']
params.update(kwargs)
self.framework.clear_state_data(session, state)
token = self.fetch_access_token(**params)
self.token = token
return token
class FlaskOAuth2App(FlaskAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
client_cls = OAuth2Session
def authorize_access_token(self, **kwargs):
"""Fetch access token in one step.
:return: A token dict.
"""
if request.method == 'GET':
error = request.args.get('error')
if error:
description = request.args.get('error_description')
raise OAuthError(error=error, description=description)
params = {
'code': request.args.get('code'),
'state': request.args.get('state'),
}
else:
params = {
'code': request.form.get('code'),
'state': request.form.get('state'),
}
claims_options = kwargs.pop('claims_options', None)
state_data = self.framework.get_state_data(session, params.get('state'))
self.framework.clear_state_data(session, params.get('state'))
params = self._format_state_params(state_data, params)
token = self.fetch_access_token(**params, **kwargs)
self.token = token
if 'id_token' in token and 'nonce' in state_data:
userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
token['userinfo'] = userinfo
return token

View File

@@ -0,0 +1,28 @@
from flask import current_app
from flask.signals import Namespace
from ..base_client import FrameworkIntegration
_signal = Namespace()
#: signal when token is updated
token_update = _signal.signal('token_update')
class FlaskIntegration(FrameworkIntegration):
def update_token(self, token, refresh_token=None, access_token=None):
token_update.send(
current_app,
name=self.name,
token=token,
refresh_token=refresh_token,
access_token=access_token,
)
@staticmethod
def load_config(oauth, name, params):
rv = {}
for k in params:
conf_key = f'{name}_{k}'.upper()
v = oauth.app.config.get(conf_key, None)
if v is not None:
rv[k] = v
return rv

View File

@@ -0,0 +1,9 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import ResourceProtector, current_credential
from .cache import (
register_nonce_hooks,
register_temporary_credential_hooks,
create_exists_nonce_func,
)

View File

@@ -0,0 +1,182 @@
import logging
from werkzeug.utils import import_string
from flask import Response
from flask import request as flask_req
from authlib.oauth1 import (
OAuth1Request,
AuthorizationServer as _AuthorizationServer,
)
from authlib.common.security import generate_token
from authlib.common.urls import url_encode
log = logging.getLogger(__name__)
class AuthorizationServer(_AuthorizationServer):
"""Flask implementation of :class:`authlib.rfc5849.AuthorizationServer`.
Initialize it with Flask app instance, client model class and cache::
server = AuthorizationServer(app=app, query_client=query_client)
# or initialize lazily
server = AuthorizationServer()
server.init_app(app, query_client=query_client)
:param app: A Flask app instance
:param query_client: A function to get client by client_id. The client
model class MUST implement the methods described by
:class:`~authlib.oauth1.rfc5849.ClientMixin`.
:param token_generator: A function to generate token
"""
def __init__(self, app=None, query_client=None, token_generator=None):
self.app = app
self.query_client = query_client
self.token_generator = token_generator
self._hooks = {
'exists_nonce': None,
'create_temporary_credential': None,
'get_temporary_credential': None,
'delete_temporary_credential': None,
'create_authorization_verifier': None,
'create_token_credential': None,
}
if app is not None:
self.init_app(app)
def init_app(self, app, query_client=None, token_generator=None):
if query_client is not None:
self.query_client = query_client
if token_generator is not None:
self.token_generator = token_generator
if self.token_generator is None:
self.token_generator = self.create_token_generator(app)
methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS')
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self.app = app
def register_hook(self, name, func):
if name not in self._hooks:
raise ValueError('Invalid "name" of hook')
self._hooks[name] = func
def create_token_generator(self, app):
token_generator = app.config.get('OAUTH1_TOKEN_GENERATOR')
if isinstance(token_generator, str):
token_generator = import_string(token_generator)
else:
length = app.config.get('OAUTH1_TOKEN_LENGTH', 42)
def token_generator():
return generate_token(length)
secret_generator = app.config.get('OAUTH1_TOKEN_SECRET_GENERATOR')
if isinstance(secret_generator, str):
secret_generator = import_string(secret_generator)
else:
length = app.config.get('OAUTH1_TOKEN_SECRET_LENGTH', 48)
def secret_generator():
return generate_token(length)
def create_token():
return {
'oauth_token': token_generator(),
'oauth_token_secret': secret_generator()
}
return create_token
def get_client_by_id(self, client_id):
return self.query_client(client_id)
def exists_nonce(self, nonce, request):
func = self._hooks['exists_nonce']
if callable(func):
timestamp = request.timestamp
client_id = request.client_id
token = request.token
return func(nonce, timestamp, client_id, token)
raise RuntimeError('"exists_nonce" hook is required.')
def create_temporary_credential(self, request):
func = self._hooks['create_temporary_credential']
if callable(func):
token = self.token_generator()
return func(token, request.client_id, request.redirect_uri)
raise RuntimeError(
'"create_temporary_credential" hook is required.'
)
def get_temporary_credential(self, request):
func = self._hooks['get_temporary_credential']
if callable(func):
return func(request.token)
raise RuntimeError(
'"get_temporary_credential" hook is required.'
)
def delete_temporary_credential(self, request):
func = self._hooks['delete_temporary_credential']
if callable(func):
return func(request.token)
raise RuntimeError(
'"delete_temporary_credential" hook is required.'
)
def create_authorization_verifier(self, request):
func = self._hooks['create_authorization_verifier']
if callable(func):
verifier = generate_token(36)
func(request.credential, request.user, verifier)
return verifier
raise RuntimeError(
'"create_authorization_verifier" hook is required.'
)
def create_token_credential(self, request):
func = self._hooks['create_token_credential']
if callable(func):
temporary_credential = request.credential
token = self.token_generator()
return func(token, temporary_credential)
raise RuntimeError(
'"create_token_credential" hook is required.'
)
def check_authorization_request(self):
req = self.create_oauth1_request(None)
self.validate_authorization_request(req)
return req
def create_authorization_response(self, request=None, grant_user=None):
return super()\
.create_authorization_response(request, grant_user)
def create_token_response(self, request=None):
return super().create_token_response(request)
def create_oauth1_request(self, request):
if request is None:
request = flask_req
if request.method in ('POST', 'PUT'):
body = request.form.to_dict(flat=True)
else:
body = None
return OAuth1Request(request.method, request.url, body, request.headers)
def handle_response(self, status_code, payload, headers):
return Response(
url_encode(payload),
status=status_code,
headers=headers
)

View File

@@ -0,0 +1,80 @@
from authlib.oauth1 import TemporaryCredential
def register_temporary_credential_hooks(
authorization_server, cache, key_prefix='temporary_credential:'):
"""Register temporary credential related hooks to authorization server.
:param authorization_server: AuthorizationServer instance
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
"""
def create_temporary_credential(token, client_id, redirect_uri):
key = key_prefix + token['oauth_token']
token['client_id'] = client_id
if redirect_uri:
token['oauth_callback'] = redirect_uri
cache.set(key, token, timeout=86400) # cache for one day
return TemporaryCredential(token)
def get_temporary_credential(oauth_token):
if not oauth_token:
return None
key = key_prefix + oauth_token
value = cache.get(key)
if value:
return TemporaryCredential(value)
def delete_temporary_credential(oauth_token):
if oauth_token:
key = key_prefix + oauth_token
cache.delete(key)
def create_authorization_verifier(credential, grant_user, verifier):
key = key_prefix + credential.get_oauth_token()
credential['oauth_verifier'] = verifier
credential['user_id'] = grant_user.get_user_id()
cache.set(key, credential, timeout=86400)
return credential
authorization_server.register_hook(
'create_temporary_credential', create_temporary_credential)
authorization_server.register_hook(
'get_temporary_credential', get_temporary_credential)
authorization_server.register_hook(
'delete_temporary_credential', delete_temporary_credential)
authorization_server.register_hook(
'create_authorization_verifier', create_authorization_verifier)
def create_exists_nonce_func(cache, key_prefix='nonce:', expires=86400):
"""Create an ``exists_nonce`` function that can be used in hooks and
resource protector.
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
:param expires: Expire time for nonce
"""
def exists_nonce(nonce, timestamp, client_id, oauth_token):
key = f'{key_prefix}{nonce}-{timestamp}-{client_id}'
if oauth_token:
key = f'{key}-{oauth_token}'
rv = cache.has(key)
cache.set(key, 1, timeout=expires)
return rv
return exists_nonce
def register_nonce_hooks(
authorization_server, cache, key_prefix='nonce:', expires=86400):
"""Register nonce related hooks to authorization server.
:param authorization_server: AuthorizationServer instance
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
:param expires: Expire time for nonce
"""
exists_nonce = create_exists_nonce_func(cache, key_prefix, expires)
authorization_server.register_hook('exists_nonce', exists_nonce)

View File

@@ -0,0 +1,113 @@
import functools
from flask import g, json, Response
from flask import request as _req
from werkzeug.local import LocalProxy
from authlib.consts import default_json_headers
from authlib.oauth1 import ResourceProtector as _ResourceProtector
from authlib.oauth1.errors import OAuth1Error
class ResourceProtector(_ResourceProtector):
"""A protecting method for resource servers. Initialize a resource
protector with the these method:
1. query_client
2. query_token,
3. exists_nonce
Usually, a ``query_client`` method would look like (if using SQLAlchemy)::
def query_client(client_id):
return Client.query.filter_by(client_id=client_id).first()
A ``query_token`` method accept two parameters, ``client_id`` and ``oauth_token``::
def query_token(client_id, oauth_token):
return Token.query.filter_by(client_id=client_id, oauth_token=oauth_token).first()
And for ``exists_nonce``, if using cache, we have a built-in hook to create this method::
from authlib.integrations.flask_oauth1 import create_exists_nonce_func
exists_nonce = create_exists_nonce_func(cache)
Then initialize the resource protector with those methods::
require_oauth = ResourceProtector(
app, query_client=query_client,
query_token=query_token, exists_nonce=exists_nonce,
)
"""
def __init__(self, app=None, query_client=None,
query_token=None, exists_nonce=None):
self.query_client = query_client
self.query_token = query_token
self._exists_nonce = exists_nonce
self.app = app
if app:
self.init_app(app)
def init_app(self, app, query_client=None, query_token=None,
exists_nonce=None):
if query_client is not None:
self.query_client = query_client
if query_token is not None:
self.query_token = query_token
if exists_nonce is not None:
self._exists_nonce = exists_nonce
methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS')
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self.app = app
def get_client_by_id(self, client_id):
return self.query_client(client_id)
def get_token_credential(self, request):
return self.query_token(request.client_id, request.token)
def exists_nonce(self, nonce, request):
if not self._exists_nonce:
raise RuntimeError('"exists_nonce" function is required.')
timestamp = request.timestamp
client_id = request.client_id
token = request.token
return self._exists_nonce(nonce, timestamp, client_id, token)
def acquire_credential(self):
req = self.validate_request(
_req.method,
_req.url,
_req.form.to_dict(flat=True),
_req.headers
)
g.authlib_server_oauth1_credential = req.credential
return req.credential
def __call__(self, scope=None):
def wrapper(f):
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
self.acquire_credential()
except OAuth1Error as error:
body = dict(error.get_body())
return Response(
json.dumps(body),
status=error.status_code,
headers=default_json_headers,
)
return f(*args, **kwargs)
return decorated
return wrapper
def _get_current_credential():
return g.get('authlib_server_oauth1_credential')
current_credential = LocalProxy(_get_current_credential)

View File

@@ -0,0 +1,12 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import (
ResourceProtector,
current_token,
)
from .signals import (
client_authenticated,
token_authenticated,
token_revoked,
)

View File

@@ -0,0 +1,159 @@
from werkzeug.utils import import_string
from flask import Response, json
from flask import request as flask_req
from authlib.oauth2 import (
AuthorizationServer as _AuthorizationServer,
)
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.common.security import generate_token
from .requests import FlaskOAuth2Request, FlaskJsonRequest
from .signals import client_authenticated, token_revoked
class AuthorizationServer(_AuthorizationServer):
"""Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
Initialize it with ``query_client``, ``save_token`` methods and Flask
app instance::
def query_client(client_id):
return Client.query.filter_by(client_id=client_id).first()
def save_token(token, request):
if request.user:
user_id = request.user.id
else:
user_id = None
client = request.client
tok = Token(
client_id=client.client_id,
user_id=user.id,
**token
)
db.session.add(tok)
db.session.commit()
server = AuthorizationServer(app, query_client, save_token)
# or initialize lazily
server = AuthorizationServer()
server.init_app(app, query_client, save_token)
"""
def __init__(self, app=None, query_client=None, save_token=None):
super().__init__()
self._query_client = query_client
self._save_token = save_token
self._error_uris = None
if app is not None:
self.init_app(app)
def init_app(self, app, query_client=None, save_token=None):
"""Initialize later with Flask app instance."""
if query_client is not None:
self._query_client = query_client
if save_token is not None:
self._save_token = save_token
self.register_token_generator('default', self.create_bearer_token_generator(app.config))
self.scopes_supported = app.config.get('OAUTH2_SCOPES_SUPPORTED')
self._error_uris = app.config.get('OAUTH2_ERROR_URIS')
def query_client(self, client_id):
return self._query_client(client_id)
def save_token(self, token, request):
return self._save_token(token, request)
def get_error_uri(self, request, error):
if self._error_uris:
uris = dict(self._error_uris)
return uris.get(error.error)
def create_oauth2_request(self, request):
return FlaskOAuth2Request(flask_req)
def create_json_request(self, request):
return FlaskJsonRequest(flask_req)
def handle_response(self, status_code, payload, headers):
if isinstance(payload, dict):
payload = json.dumps(payload)
return Response(payload, status=status_code, headers=headers)
def send_signal(self, name, *args, **kwargs):
if name == 'after_authenticate_client':
client_authenticated.send(self, *args, **kwargs)
elif name == 'after_revoke_token':
token_revoked.send(self, *args, **kwargs)
def create_bearer_token_generator(self, config):
"""Create a generator function for generating ``token`` value. This
method will create a Bearer Token generator with
:class:`authlib.oauth2.rfc6750.BearerToken`.
Configurable settings:
1. OAUTH2_ACCESS_TOKEN_GENERATOR: Boolean or import string, default is True.
2. OAUTH2_REFRESH_TOKEN_GENERATOR: Boolean or import string, default is False.
3. OAUTH2_TOKEN_EXPIRES_IN: Dict or import string, default is None.
By default, it will not generate ``refresh_token``, which can be turn on by
configure ``OAUTH2_REFRESH_TOKEN_GENERATOR``.
Here are some examples of the token generator::
OAUTH2_ACCESS_TOKEN_GENERATOR = 'your_project.generators.gen_token'
# and in module `your_project.generators`, you can define:
def gen_token(client, grant_type, user, scope):
# generate token according to these parameters
token = create_random_token()
return f'{client.id}-{user.id}-{token}'
Here is an example of ``OAUTH2_TOKEN_EXPIRES_IN``::
OAUTH2_TOKEN_EXPIRES_IN = {
'authorization_code': 864000,
'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600,
}
"""
conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True)
access_token_generator = create_token_generator(conf, 42)
conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False)
refresh_token_generator = create_token_generator(conf, 48)
expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN')
expires_generator = create_token_expires_in_generator(expires_conf)
return BearerTokenGenerator(
access_token_generator,
refresh_token_generator,
expires_generator
)
def create_token_expires_in_generator(expires_in_conf=None):
if isinstance(expires_in_conf, str):
return import_string(expires_in_conf)
data = {}
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
if isinstance(expires_in_conf, dict):
data.update(expires_in_conf)
def expires_in(client, grant_type):
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
return expires_in
def create_token_generator(token_generator_conf, length=42):
if callable(token_generator_conf):
return token_generator_conf
if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)
elif token_generator_conf is True:
def token_generator(*args, **kwargs):
return generate_token(length)
return token_generator

View File

@@ -0,0 +1,39 @@
import importlib.metadata
import werkzeug
from werkzeug.exceptions import HTTPException
_version = importlib.metadata.version('werkzeug').split('.')[0]
if _version in ('0', '1'):
class _HTTPException(HTTPException):
def __init__(self, code, body, headers, response=None):
super().__init__(None, response)
self.code = code
self.body = body
self.headers = headers
def get_body(self, environ=None):
return self.body
def get_headers(self, environ=None):
return self.headers
else:
class _HTTPException(HTTPException):
def __init__(self, code, body, headers, response=None):
super().__init__(None, response)
self.code = code
self.body = body
self.headers = headers
def get_body(self, environ=None, scope=None):
return self.body
def get_headers(self, environ=None, scope=None):
return self.headers
def raise_http_exception(status, body, headers):
raise _HTTPException(status, body, headers)

View File

@@ -0,0 +1,40 @@
from collections import defaultdict
from functools import cached_property
from flask.wrappers import Request
from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest
class FlaskOAuth2Request(OAuth2Request):
def __init__(self, request: Request):
super().__init__(request.method, request.url, None, request.headers)
self._request = request
@property
def args(self):
return self._request.args
@property
def form(self):
return self._request.form
@property
def data(self):
return self._request.values
@cached_property
def datalist(self):
values = defaultdict(list)
for k in self.data:
values[k].extend(self.data.getlist(k))
return values
class FlaskJsonRequest(JsonRequest):
def __init__(self, request: Request):
super().__init__(request.method, request.url, None, request.headers)
self._request = request
@property
def data(self):
return self._request.get_json()

View File

@@ -0,0 +1,114 @@
import functools
from contextlib import contextmanager
from flask import g, json
from flask import request as _req
from werkzeug.local import LocalProxy
from authlib.oauth2 import (
OAuth2Error,
ResourceProtector as _ResourceProtector
)
from authlib.oauth2.rfc6749 import (
MissingAuthorizationError,
)
from .requests import FlaskJsonRequest
from .signals import token_authenticated
from .errors import raise_http_exception
class ResourceProtector(_ResourceProtector):
"""A protecting method for resource servers. Creating a ``require_oauth``
decorator easily with ResourceProtector::
from authlib.integrations.flask_oauth2 import ResourceProtector
require_oauth = ResourceProtector()
# add bearer token validator
from authlib.oauth2.rfc6750 import BearerTokenValidator
from project.models import Token
class MyBearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string):
return Token.query.filter_by(access_token=token_string).first()
require_oauth.register_token_validator(MyBearerTokenValidator())
# protect resource with require_oauth
@app.route('/user')
@require_oauth(['profile'])
def user_profile():
user = User.get(current_token.user_id)
return jsonify(user.to_dict())
"""
def raise_error_response(self, error):
"""Raise HTTPException for OAuth2Error. Developers can re-implement
this method to customize the error response.
:param error: OAuth2Error
:raise: HTTPException
"""
status = error.status_code
body = json.dumps(dict(error.get_body()))
headers = error.get_headers()
raise_http_exception(status, body, headers)
def acquire_token(self, scopes=None, **kwargs):
"""A method to acquire current valid token with the given scope.
:param scopes: a list of scope values
:return: token object
"""
request = FlaskJsonRequest(_req)
# backward compatibility
kwargs['scopes'] = scopes
for claim in kwargs:
if isinstance(kwargs[claim], str):
kwargs[claim] = [kwargs[claim]]
token = self.validate_request(request=request, **kwargs)
token_authenticated.send(self, token=token)
g.authlib_server_oauth2_token = token
return token
@contextmanager
def acquire(self, scopes=None):
"""The with statement of ``require_oauth``. Instead of using a
decorator, you can use a with statement instead::
@app.route('/api/user')
def user_api():
with require_oauth.acquire('profile') as token:
user = User.get(token.user_id)
return jsonify(user.to_dict())
"""
try:
yield self.acquire_token(scopes)
except OAuth2Error as error:
self.raise_error_response(error)
def __call__(self, scopes=None, optional=False, **kwargs):
claims = kwargs
# backward compatibility
claims['scopes'] = scopes
def wrapper(f):
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
self.acquire_token(**claims)
except MissingAuthorizationError as error:
if optional:
return f(*args, **kwargs)
self.raise_error_response(error)
except OAuth2Error as error:
self.raise_error_response(error)
return f(*args, **kwargs)
return decorated
return wrapper
def _get_current_token():
return g.get('authlib_server_oauth2_token')
current_token = LocalProxy(_get_current_token)

View File

@@ -0,0 +1,12 @@
from flask.signals import Namespace
_signal = Namespace()
#: signal when client is authenticated
client_authenticated = _signal.signal('client_authenticated')
#: signal when token is revoked
token_revoked = _signal.signal('token_revoked')
#: signal when token is authenticated
token_authenticated = _signal.signal('token_authenticated')

View File

@@ -0,0 +1,25 @@
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_RSA_SHA1,
SIGNATURE_PLAINTEXT,
SIGNATURE_TYPE_HEADER,
SIGNATURE_TYPE_QUERY,
SIGNATURE_TYPE_BODY,
)
from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client
from .oauth2_client import (
OAuth2Auth, OAuth2Client, OAuth2ClientAuth,
AsyncOAuth2Client,
)
from .assertion_client import AssertionClient, AsyncAssertionClient
from ..base_client import OAuthError
__all__ = [
'OAuthError',
'OAuth1Auth', 'AsyncOAuth1Client',
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
'OAuth2Auth', 'OAuth2ClientAuth', 'OAuth2Client', 'AsyncOAuth2Client',
'AssertionClient', 'AsyncAssertionClient',
]

View File

@@ -0,0 +1,81 @@
import httpx
from httpx import Response, USE_CLIENT_DEFAULT
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from .utils import extract_client_kwargs
from .oauth2_client import OAuth2Auth
from ..base_client import OAuthError
__all__ = ['AsyncAssertionClient']
class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self, session=None,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
)
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs) -> Response:
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
await self.refresh_token()
auth = self.token_auth
return await super().request(
method, url, auth=auth, **kwargs)
async def _refresh_token(self, data):
resp = await self.request(
'POST', self.token_endpoint, data=data, withhold_token=True)
return self.parse_response_token(resp)
class AssertionClient(_AssertionClient, httpx.Client):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
client_kwargs = extract_client_kwargs(kwargs)
httpx.Client.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
)
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
self.refresh_token()
auth = self.token_auth
return super().request(
method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,102 @@
import typing
import httpx
from httpx import Auth, Request, Response
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_TYPE_HEADER,
)
from authlib.common.encoding import to_unicode
from authlib.oauth1 import ClientAuth
from authlib.oauth1.client import OAuth1Client as _OAuth1Client
from .utils import build_request, extract_client_kwargs
from ..base_client import OAuthError
class OAuth1Auth(Auth, ClientAuth):
"""Signs the httpx request using OAuth 1 (RFC5849)"""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield build_request(url=url, headers=headers, body=body, initial_request=request)
class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient):
auth_class = OAuth1Auth
def __init__(self, client_id, client_secret=None,
token=None, token_secret=None,
redirect_uri=None, rsa_key=None, verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False, **kwargs):
_client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self, None,
client_id=client_id, client_secret=client_secret,
token=token, token_secret=token_secret,
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
signature_method=signature_method, signature_type=signature_type,
force_include_body=force_include_body, **kwargs)
async def fetch_access_token(self, url, verifier=None, **kwargs):
"""Method for fetching an access token from the token endpoint.
This is the final step in the OAuth 1 workflow. An access token is
obtained using all previously obtained credentials, including the
verifier from the authorization step.
:param url: Access Token endpoint.
:param verifier: A verifier string to prove authorization was granted.
:param kwargs: Extra parameters to include for fetching access token.
:return: A token dict.
"""
if verifier:
self.auth.verifier = verifier
if not self.auth.verifier:
self.handle_error('missing_verifier', 'Missing "verifier" value')
token = await self._fetch_token(url, **kwargs)
self.auth.verifier = None
return token
async def _fetch_token(self, url, **kwargs):
resp = await self.post(url, **kwargs)
text = await resp.aread()
token = self.parse_response_token(resp.status_code, to_unicode(text))
self.token = token
return token
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
class OAuth1Client(_OAuth1Client, httpx.Client):
auth_class = OAuth1Auth
def __init__(self, client_id, client_secret=None,
token=None, token_secret=None,
redirect_uri=None, rsa_key=None, verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False, **kwargs):
_client_kwargs = extract_client_kwargs(kwargs)
httpx.Client.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self, self,
client_id=client_id, client_secret=client_secret,
token=token, token_secret=token_secret,
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
signature_method=signature_method, signature_type=signature_type,
force_include_body=force_include_body, **kwargs)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

View File

@@ -0,0 +1,220 @@
import typing
from contextlib import asynccontextmanager
import httpx
from httpx import Auth, Request, Response, USE_CLIENT_DEFAULT
from anyio import Lock # Import after httpx so import errors refer to httpx
from authlib.common.urls import url_decode
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
from authlib.oauth2.auth import ClientAuth, TokenAuth
from .utils import HTTPX_CLIENT_KWARGS, build_request
from ..base_client import (
OAuthError,
InvalidTokenError,
MissingTokenError,
UnsupportedTokenTypeError,
)
__all__ = [
'OAuth2Auth', 'OAuth2ClientAuth',
'AsyncOAuth2Client', 'OAuth2Client',
]
class OAuth2Auth(Auth, TokenAuth):
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
try:
url, headers, body = self.prepare(
str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield build_request(url=url, headers=headers, body=body, initial_request=request)
except KeyError as error:
description = f'Unsupported token_type: {str(error)}'
raise UnsupportedTokenTypeError(description=description)
class OAuth2ClientAuth(Auth, ClientAuth):
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield build_request(url=url, headers=headers, body=body, initial_request=request)
class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(self, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, leeway=60, **kwargs):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
# We use a Lock to synchronize coroutines to prevent
# multiple concurrent attempts to refresh the same token
self._token_refresh_lock = Lock()
_OAuth2Client.__init__(
self, session=None,
client_id=client_id, client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, leeway=leeway, **kwargs
)
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
return await super().request(
method, url, auth=auth, **kwargs)
@asynccontextmanager
async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
async with super().stream(
method, url, auth=auth, **kwargs) as resp:
yield resp
async def ensure_active_token(self, token):
async with self._token_refresh_lock:
if self.token.is_expired(leeway=self.leeway):
refresh_token = token.get('refresh_token')
url = self.metadata.get('token_endpoint')
if refresh_token and url:
await self.refresh_token(url, refresh_token=refresh_token)
elif self.metadata.get('grant_type') == 'client_credentials':
access_token = token['access_token']
new_token = await self.fetch_token(url, grant_type='client_credentials')
if self.update_token:
await self.update_token(new_token, access_token=access_token)
else:
raise InvalidTokenError()
async def _fetch_token(self, url, body='', headers=None, auth=USE_CLIENT_DEFAULT,
method='POST', **kwargs):
if method.upper() == 'POST':
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers,
auth=auth, **kwargs)
else:
if '?' in url:
url = '&'.join([url, body])
else:
url = '?'.join([url, body])
resp = await self.get(url, headers=headers, auth=auth, **kwargs)
for hook in self.compliance_hook['access_token_response']:
resp = hook(resp)
return self.parse_response_token(resp)
async def _refresh_token(self, url, refresh_token=None, body='',
headers=None, auth=USE_CLIENT_DEFAULT, **kwargs):
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers,
auth=auth, **kwargs)
for hook in self.compliance_hook['refresh_token_response']:
resp = hook(resp)
token = self.parse_response_token(resp)
if 'refresh_token' not in token:
self.token['refresh_token'] = refresh_token
if self.update_token:
await self.update_token(self.token, refresh_token=refresh_token)
return self.token
def _http_post(self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs):
return self.post(
url, data=dict(url_decode(body)),
headers=headers, auth=auth, **kwargs)
class OAuth2Client(_OAuth2Client, httpx.Client):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(self, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, **kwargs):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
httpx.Client.__init__(self, **client_kwargs)
_OAuth2Client.__init__(
self, session=self,
client_id=client_id, client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().request(
method, url, auth=auth, **kwargs)
def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().stream(
method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,30 @@
from httpx import Request
HTTPX_CLIENT_KWARGS = [
'headers', 'cookies', 'verify', 'cert', 'http1', 'http2',
'proxies', 'timeout', 'follow_redirects', 'limits', 'max_redirects',
'event_hooks', 'base_url', 'transport', 'app', 'trust_env',
]
def extract_client_kwargs(kwargs):
client_kwargs = {}
for k in HTTPX_CLIENT_KWARGS:
if k in kwargs:
client_kwargs[k] = kwargs.pop(k)
return client_kwargs
def build_request(url, headers, body, initial_request: Request) -> Request:
"""Make sure that all the data from initial request is passed to the updated object"""
updated_request = Request(
method=initial_request.method,
url=url,
headers=headers,
content=body
)
if hasattr(initial_request, 'extensions'):
updated_request.extensions = initial_request.extensions
return updated_request

View File

@@ -0,0 +1,22 @@
from .oauth1_session import OAuth1Session, OAuth1Auth
from .oauth2_session import OAuth2Session, OAuth2Auth
from .assertion_session import AssertionSession
from ..base_client import OAuthError
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_RSA_SHA1,
SIGNATURE_PLAINTEXT,
SIGNATURE_TYPE_HEADER,
SIGNATURE_TYPE_QUERY,
SIGNATURE_TYPE_BODY,
)
__all__ = [
'OAuthError',
'OAuth1Session', 'OAuth1Auth',
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
'OAuth2Session', 'OAuth2Auth',
'AssertionSession',
]

View File

@@ -0,0 +1,47 @@
from requests import Session
from authlib.oauth2.rfc7521 import AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from .oauth2_session import OAuth2Auth
from .utils import update_session_configure
class AssertionAuth(OAuth2Auth):
def ensure_active_token(self):
if self.client and (not self.token or self.token.is_expired(self.client.leeway)):
return self.client.refresh_token()
class AssertionSession(AssertionClient, Session):
"""Constructs a new Assertion Framework for OAuth 2.0 Authorization Grants
per RFC7521_.
.. _RFC7521: https://tools.ietf.org/html/rfc7521
"""
token_auth_class = AssertionAuth
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, default_timeout=None,
leeway=60, **kwargs):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, leeway=leeway, **kwargs
)
def request(self, method, url, withhold_token=False, auth=None, **kwargs):
"""Send request with auto refresh token feature."""
if self.default_timeout:
kwargs.setdefault('timeout', self.default_timeout)
if not withhold_token and auth is None:
auth = self.token_auth
return super().request(
method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,59 @@
from requests import Session
from requests.auth import AuthBase
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_TYPE_HEADER,
)
from authlib.common.encoding import to_native
from authlib.oauth1 import ClientAuth
from authlib.oauth1.client import OAuth1Client
from ..base_client import OAuthError
from .utils import update_session_configure
class OAuth1Auth(AuthBase, ClientAuth):
"""Signs the request using OAuth 1 (RFC5849)"""
def __call__(self, req):
url, headers, body = self.prepare(
req.method, req.url, req.headers, req.body)
req.url = to_native(url)
req.prepare_headers(headers)
if body:
req.body = body
return req
class OAuth1Session(OAuth1Client, Session):
auth_class = OAuth1Auth
def __init__(self, client_id, client_secret=None,
token=None, token_secret=None,
redirect_uri=None, rsa_key=None, verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False, **kwargs):
Session.__init__(self)
update_session_configure(self, kwargs)
OAuth1Client.__init__(
self, session=self,
client_id=client_id, client_secret=client_secret,
token=token, token_secret=token_secret,
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
signature_method=signature_method, signature_type=signature_type,
force_include_body=force_include_body, **kwargs)
def rebuild_auth(self, prepared_request, response):
"""When being redirected we should always strip Authorization
header, since nonce may not be reused as per OAuth spec.
"""
if 'Authorization' in prepared_request.headers:
# If we get redirected to a new host, we should strip out
# any authentication headers.
prepared_request.headers.pop('Authorization', True)
prepared_request.prepare_auth(self.auth)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

View File

@@ -0,0 +1,113 @@
from requests import Session
from requests.auth import AuthBase
from authlib.oauth2.client import OAuth2Client
from authlib.oauth2.auth import ClientAuth, TokenAuth
from ..base_client import (
OAuthError,
InvalidTokenError,
MissingTokenError,
UnsupportedTokenTypeError,
)
from .utils import update_session_configure
__all__ = ['OAuth2Session', 'OAuth2Auth']
class OAuth2Auth(AuthBase, TokenAuth):
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
def ensure_active_token(self):
if self.client and not self.client.ensure_active_token(self.token):
raise InvalidTokenError()
def __call__(self, req):
self.ensure_active_token()
try:
req.url, req.headers, req.body = self.prepare(
req.url, req.headers, req.body)
except KeyError as error:
description = f'Unsupported token_type: {str(error)}'
raise UnsupportedTokenTypeError(description=description)
return req
class OAuth2ClientAuth(AuthBase, ClientAuth):
"""Attaches OAuth Client Authentication to the given Request object.
"""
def __call__(self, req):
req.url, req.headers, req.body = self.prepare(
req.method, req.url, req.headers, req.body
)
return req
class OAuth2Session(OAuth2Client, Session):
"""Construct a new OAuth 2 client requests session.
:param client_id: Client ID, which you get from client registration.
:param client_secret: Client Secret, which you get from registration.
:param authorization_endpoint: URL of the authorization server's
authorization endpoint.
:param token_endpoint: URL of the authorization server's token endpoint.
:param token_endpoint_auth_method: client authentication method for
token endpoint.
:param revocation_endpoint: URL of the authorization server's OAuth 2.0
revocation endpoint.
:param revocation_endpoint_auth_method: client authentication method for
revocation endpoint.
:param scope: Scope that you needed to access user resources.
:param state: Shared secret to prevent CSRF attack.
:param redirect_uri: Redirect URI you registered as callback.
:param token: A dict of token attributes such as ``access_token``,
``token_type`` and ``expires_at``.
:param token_placement: The place to put token in HTTP request. Available
values: "header", "body", "uri".
:param update_token: A function for you to update token. It accept a
:class:`OAuth2Token` as parameter.
:param leeway: Time window in seconds before the actual expiration of the
authentication token, that the token is considered expired and will
be refreshed.
:param default_timeout: If settled, every requests will have a default timeout.
"""
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
SESSION_REQUEST_PARAMS = (
'allow_redirects', 'timeout', 'cookies', 'files',
'proxies', 'hooks', 'stream', 'verify', 'cert', 'json'
)
def __init__(self, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, state=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, leeway=60, default_timeout=None, **kwargs):
Session.__init__(self)
self.default_timeout = default_timeout
update_session_configure(self, kwargs)
OAuth2Client.__init__(
self, session=self,
client_id=client_id, client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, state=state, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, leeway=leeway, **kwargs
)
def fetch_access_token(self, url=None, **kwargs):
"""Alias for fetch_token."""
return self.fetch_token(url, **kwargs)
def request(self, method, url, withhold_token=False, auth=None, **kwargs):
"""Send request with auto refresh token feature (if available)."""
if self.default_timeout:
kwargs.setdefault('timeout', self.default_timeout)
if not withhold_token and auth is None:
if not self.token:
raise MissingTokenError()
auth = self.token_auth
return super().request(
method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,10 @@
REQUESTS_SESSION_KWARGS = [
'proxies', 'hooks', 'stream', 'verify', 'cert',
'max_redirects', 'trust_env',
]
def update_session_configure(session, kwargs):
for k in REQUESTS_SESSION_KWARGS:
if k in kwargs:
setattr(session, k, kwargs.pop(k))

View File

@@ -0,0 +1,17 @@
from .client_mixin import OAuth2ClientMixin
from .tokens_mixins import OAuth2AuthorizationCodeMixin, OAuth2TokenMixin
from .functions import (
create_query_client_func,
create_save_token_func,
create_query_token_func,
create_revocation_endpoint,
create_bearer_token_validator,
)
__all__ = [
'OAuth2ClientMixin', 'OAuth2AuthorizationCodeMixin', 'OAuth2TokenMixin',
'create_query_client_func', 'create_save_token_func',
'create_query_token_func', 'create_revocation_endpoint',
'create_bearer_token_validator',
]

View File

@@ -0,0 +1,138 @@
import secrets
from sqlalchemy import Column, String, Text, Integer
from authlib.common.encoding import json_loads, json_dumps
from authlib.oauth2.rfc6749 import ClientMixin
from authlib.oauth2.rfc6749 import scope_to_list, list_to_scope
class OAuth2ClientMixin(ClientMixin):
client_id = Column(String(48), index=True)
client_secret = Column(String(120))
client_id_issued_at = Column(Integer, nullable=False, default=0)
client_secret_expires_at = Column(Integer, nullable=False, default=0)
_client_metadata = Column('client_metadata', Text)
@property
def client_info(self):
"""Implementation for Client Info in OAuth 2.0 Dynamic Client
Registration Protocol via `Section 3.2.1`_.
.. _`Section 3.2.1`: https://tools.ietf.org/html/rfc7591#section-3.2.1
"""
return dict(
client_id=self.client_id,
client_secret=self.client_secret,
client_id_issued_at=self.client_id_issued_at,
client_secret_expires_at=self.client_secret_expires_at,
)
@property
def client_metadata(self):
if 'client_metadata' in self.__dict__:
return self.__dict__['client_metadata']
if self._client_metadata:
data = json_loads(self._client_metadata)
self.__dict__['client_metadata'] = data
return data
return {}
def set_client_metadata(self, value):
self._client_metadata = json_dumps(value)
if 'client_metadata' in self.__dict__:
del self.__dict__['client_metadata']
@property
def redirect_uris(self):
return self.client_metadata.get('redirect_uris', [])
@property
def token_endpoint_auth_method(self):
return self.client_metadata.get(
'token_endpoint_auth_method',
'client_secret_basic'
)
@property
def grant_types(self):
return self.client_metadata.get('grant_types', [])
@property
def response_types(self):
return self.client_metadata.get('response_types', [])
@property
def client_name(self):
return self.client_metadata.get('client_name')
@property
def client_uri(self):
return self.client_metadata.get('client_uri')
@property
def logo_uri(self):
return self.client_metadata.get('logo_uri')
@property
def scope(self):
return self.client_metadata.get('scope', '')
@property
def contacts(self):
return self.client_metadata.get('contacts', [])
@property
def tos_uri(self):
return self.client_metadata.get('tos_uri')
@property
def policy_uri(self):
return self.client_metadata.get('policy_uri')
@property
def jwks_uri(self):
return self.client_metadata.get('jwks_uri')
@property
def jwks(self):
return self.client_metadata.get('jwks', [])
@property
def software_id(self):
return self.client_metadata.get('software_id')
@property
def software_version(self):
return self.client_metadata.get('software_version')
def get_client_id(self):
return self.client_id
def get_default_redirect_uri(self):
if self.redirect_uris:
return self.redirect_uris[0]
def get_allowed_scope(self, scope):
if not scope:
return ''
allowed = set(self.scope.split())
scopes = scope_to_list(scope)
return list_to_scope([s for s in scopes if s in allowed])
def check_redirect_uri(self, redirect_uri):
return redirect_uri in self.redirect_uris
def check_client_secret(self, client_secret):
return secrets.compare_digest(self.client_secret, client_secret)
def check_endpoint_auth_method(self, method, endpoint):
if endpoint == 'token':
return self.token_endpoint_auth_method == method
# TODO
return True
def check_response_type(self, response_type):
return response_type in self.response_types
def check_grant_type(self, grant_type):
return grant_type in self.grant_types

View File

@@ -0,0 +1,101 @@
import time
def create_query_client_func(session, client_model):
"""Create an ``query_client`` function that can be used in authorization
server.
:param session: SQLAlchemy session
:param client_model: Client model class
"""
def query_client(client_id):
q = session.query(client_model)
return q.filter_by(client_id=client_id).first()
return query_client
def create_save_token_func(session, token_model):
"""Create an ``save_token`` function that can be used in authorization
server.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
def save_token(token, request):
if request.user:
user_id = request.user.get_user_id()
else:
user_id = None
client = request.client
item = token_model(
client_id=client.client_id,
user_id=user_id,
**token
)
session.add(item)
session.commit()
return save_token
def create_query_token_func(session, token_model):
"""Create an ``query_token`` function for revocation, introspection
token endpoints.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
def query_token(token, token_type_hint):
q = session.query(token_model)
if token_type_hint == 'access_token':
return q.filter_by(access_token=token).first()
elif token_type_hint == 'refresh_token':
return q.filter_by(refresh_token=token).first()
# without token_type_hint
item = q.filter_by(access_token=token).first()
if item:
return item
return q.filter_by(refresh_token=token).first()
return query_token
def create_revocation_endpoint(session, token_model):
"""Create a revocation endpoint class with SQLAlchemy session
and token model.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
from authlib.oauth2.rfc7009 import RevocationEndpoint
query_token = create_query_token_func(session, token_model)
class _RevocationEndpoint(RevocationEndpoint):
def query_token(self, token, token_type_hint):
return query_token(token, token_type_hint)
def revoke_token(self, token, request):
now = int(time.time())
hint = request.form.get('token_type_hint')
token.access_token_revoked_at = now
if hint != 'access_token':
token.refresh_token_revoked_at = now
session.add(token)
session.commit()
return _RevocationEndpoint
def create_bearer_token_validator(session, token_model):
"""Create an bearer token validator class with SQLAlchemy session
and token model.
:param session: SQLAlchemy session
:param token_model: Token model class
"""
from authlib.oauth2.rfc6750 import BearerTokenValidator
class _BearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string):
q = session.query(token_model)
return q.filter_by(access_token=token_string).first()
return _BearerTokenValidator

View File

@@ -0,0 +1,70 @@
import time
from sqlalchemy import Column, String, Text, Integer
from authlib.oauth2.rfc6749 import (
TokenMixin,
AuthorizationCodeMixin,
)
class OAuth2AuthorizationCodeMixin(AuthorizationCodeMixin):
code = Column(String(120), unique=True, nullable=False)
client_id = Column(String(48))
redirect_uri = Column(Text, default='')
response_type = Column(Text, default='')
scope = Column(Text, default='')
nonce = Column(Text)
auth_time = Column(
Integer, nullable=False,
default=lambda: int(time.time())
)
code_challenge = Column(Text)
code_challenge_method = Column(String(48))
def is_expired(self):
return self.auth_time + 300 < time.time()
def get_redirect_uri(self):
return self.redirect_uri
def get_scope(self):
return self.scope
def get_auth_time(self):
return self.auth_time
def get_nonce(self):
return self.nonce
class OAuth2TokenMixin(TokenMixin):
client_id = Column(String(48))
token_type = Column(String(40))
access_token = Column(String(255), unique=True, nullable=False)
refresh_token = Column(String(255), index=True)
scope = Column(Text, default='')
issued_at = Column(
Integer, nullable=False, default=lambda: int(time.time())
)
access_token_revoked_at = Column(Integer, nullable=False, default=0)
refresh_token_revoked_at = Column(Integer, nullable=False, default=0)
expires_in = Column(Integer, nullable=False, default=0)
def check_client(self, client):
return self.client_id == client.get_client_id()
def get_scope(self):
return self.scope
def get_expires_in(self):
return self.expires_in
def is_revoked(self):
return self.access_token_revoked_at or self.refresh_token_revoked_at
def is_expired(self):
if not self.expires_in:
return False
expires_at = self.issued_at + self.expires_in
return expires_at < time.time()

View File

@@ -0,0 +1,22 @@
# flake8: noqa
from ..base_client import BaseOAuth, OAuthError
from .integration import StarletteIntegration
from .apps import StarletteOAuth1App, StarletteOAuth2App
class OAuth(BaseOAuth):
oauth1_client_cls = StarletteOAuth1App
oauth2_client_cls = StarletteOAuth2App
framework_integration_cls = StarletteIntegration
def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
super().__init__(
cache=cache, fetch_token=fetch_token, update_token=update_token)
self.config = config
__all__ = [
'OAuth', 'OAuthError',
'StarletteIntegration', 'StarletteOAuth1App', 'StarletteOAuth2App',
]

Some files were not shown because too many files have changed in this diff Show More