venv added, updated

This commit is contained in:
Norbert
2024-09-13 09:46:28 +02:00
parent 577596d9f3
commit 82af8c809a
4812 changed files with 640223 additions and 2 deletions

View File

@@ -0,0 +1,17 @@
"""
authlib
~~~~~~~
The ultimate Python library in building OAuth 1.0, OAuth 2.0 and OpenID
Connect clients and providers. It covers from low level specification
implementation to high level framework integrations.
:copyright: (c) 2017 by Hsiaoming Yang.
:license: BSD, see LICENSE for more details.
"""
from .consts import version, homepage, author
__version__ = version
__homepage__ = homepage
__author__ = author
__license__ = 'BSD-3-Clause'

View File

@@ -0,0 +1,66 @@
import json
import base64
import struct
def to_bytes(x, charset='utf-8', errors='strict'):
if x is None:
return None
if isinstance(x, bytes):
return x
if isinstance(x, str):
return x.encode(charset, errors)
if isinstance(x, (int, float)):
return str(x).encode(charset, errors)
return bytes(x)
def to_unicode(x, charset='utf-8', errors='strict'):
if x is None or isinstance(x, str):
return x
if isinstance(x, bytes):
return x.decode(charset, errors)
return str(x)
def to_native(x, encoding='ascii'):
if isinstance(x, str):
return x
return x.decode(encoding)
def json_loads(s):
return json.loads(s)
def json_dumps(data, ensure_ascii=False):
return json.dumps(data, ensure_ascii=ensure_ascii, separators=(',', ':'))
def urlsafe_b64decode(s):
s += b'=' * (-len(s) % 4)
return base64.urlsafe_b64decode(s)
def urlsafe_b64encode(s):
return base64.urlsafe_b64encode(s).rstrip(b'=')
def base64_to_int(s):
data = urlsafe_b64decode(to_bytes(s, charset='ascii'))
buf = struct.unpack('%sB' % len(data), data)
return int(''.join(["%02x" % byte for byte in buf]), 16)
def int_to_base64(num):
if num < 0:
raise ValueError('Must be a positive integer')
s = num.to_bytes((num.bit_length() + 7) // 8, 'big', signed=False)
return to_unicode(urlsafe_b64encode(s))
def json_b64encode(text):
if isinstance(text, dict):
text = json_dumps(text)
return urlsafe_b64encode(to_bytes(text))

View File

@@ -0,0 +1,63 @@
from authlib.consts import default_json_headers
class AuthlibBaseError(Exception):
"""Base Exception for all errors in Authlib."""
#: short-string error code
error = None
#: long-string to describe this error
description = ''
#: web page that describes this error
uri = None
def __init__(self, error=None, description=None, uri=None):
if error is not None:
self.error = error
if description is not None:
self.description = description
if uri is not None:
self.uri = uri
message = f'{self.error}: {self.description}'
super().__init__(message)
def __repr__(self):
return f'<{self.__class__.__name__} "{self.error}">'
class AuthlibHTTPError(AuthlibBaseError):
#: HTTP status code
status_code = 400
def __init__(self, error=None, description=None, uri=None,
status_code=None):
super().__init__(error, description, uri)
if status_code is not None:
self.status_code = status_code
def get_error_description(self):
return self.description
def get_body(self):
error = [('error', self.error)]
if self.description:
error.append(('error_description', self.description))
if self.uri:
error.append(('error_uri', self.uri))
return error
def get_headers(self):
return default_json_headers[:]
def __call__(self, uri=None):
self.uri = uri
body = dict(self.get_body())
headers = self.get_headers()
return self.status_code, body, headers
class ContinueIteration(AuthlibBaseError):
pass

View File

@@ -0,0 +1,19 @@
import os
import string
import random
UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits
def generate_token(length=30, chars=UNICODE_ASCII_CHARACTER_SET):
rand = random.SystemRandom()
return ''.join(rand.choice(chars) for _ in range(length))
def is_secure_transport(uri):
"""Check if the uri is over ssl."""
if os.getenv('AUTHLIB_INSECURE_TRANSPORT'):
return True
uri = uri.lower()
return uri.startswith(('https://', 'http://localhost:'))

View File

@@ -0,0 +1,146 @@
"""
authlib.util.urls
~~~~~~~~~~~~~~~~~
Wrapper functions for URL encoding and decoding.
"""
import re
from urllib.parse import quote as _quote
from urllib.parse import unquote as _unquote
from urllib.parse import urlencode as _urlencode
import urllib.parse as urlparse
from .encoding import to_unicode, to_bytes
always_safe = (
'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
'abcdefghijklmnopqrstuvwxyz'
'0123456789_.-'
)
urlencoded = set(always_safe) | set('=&;:%+~,*@!()/?')
INVALID_HEX_PATTERN = re.compile(r'%[^0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]')
def url_encode(params):
encoded = []
for k, v in params:
encoded.append((to_bytes(k), to_bytes(v)))
return to_unicode(_urlencode(encoded))
def url_decode(query):
"""Decode a query string in x-www-form-urlencoded format into a sequence
of two-element tuples.
Unlike urlparse.parse_qsl(..., strict_parsing=True) urldecode will enforce
correct formatting of the query string by validation. If validation fails
a ValueError will be raised. urllib.parse_qsl will only raise errors if
any of name-value pairs omits the equals sign.
"""
# Check if query contains invalid characters
if query and not set(query) <= urlencoded:
error = ("Error trying to decode a non urlencoded string. "
"Found invalid characters: %s "
"in the string: '%s'. "
"Please ensure the request/response body is "
"x-www-form-urlencoded.")
raise ValueError(error % (set(query) - urlencoded, query))
# Check for correctly hex encoded values using a regular expression
# All encoded values begin with % followed by two hex characters
# correct = %00, %A0, %0A, %FF
# invalid = %G0, %5H, %PO
if INVALID_HEX_PATTERN.search(query):
raise ValueError('Invalid hex encoding in query string.')
# We encode to utf-8 prior to parsing because parse_qsl behaves
# differently on unicode input in python 2 and 3.
# Python 2.7
# >>> urlparse.parse_qsl(u'%E5%95%A6%E5%95%A6')
# u'\xe5\x95\xa6\xe5\x95\xa6'
# Python 2.7, non unicode input gives the same
# >>> urlparse.parse_qsl('%E5%95%A6%E5%95%A6')
# '\xe5\x95\xa6\xe5\x95\xa6'
# but now we can decode it to unicode
# >>> urlparse.parse_qsl('%E5%95%A6%E5%95%A6').decode('utf-8')
# u'\u5566\u5566'
# Python 3.3 however
# >>> urllib.parse.parse_qsl(u'%E5%95%A6%E5%95%A6')
# u'\u5566\u5566'
# We want to allow queries such as "c2" whereas urlparse.parse_qsl
# with the strict_parsing flag will not.
params = urlparse.parse_qsl(query, keep_blank_values=True)
# unicode all the things
decoded = []
for k, v in params:
decoded.append((to_unicode(k), to_unicode(v)))
return decoded
def add_params_to_qs(query, params):
"""Extend a query with a list of two-tuples."""
if isinstance(params, dict):
params = params.items()
qs = urlparse.parse_qsl(query, keep_blank_values=True)
qs.extend(params)
return url_encode(qs)
def add_params_to_uri(uri, params, fragment=False):
"""Add a list of two-tuples to the uri query components."""
sch, net, path, par, query, fra = urlparse.urlparse(uri)
if fragment:
fra = add_params_to_qs(fra, params)
else:
query = add_params_to_qs(query, params)
return urlparse.urlunparse((sch, net, path, par, query, fra))
def quote(s, safe=b'/'):
return to_unicode(_quote(to_bytes(s), safe))
def unquote(s):
return to_unicode(_unquote(s))
def quote_url(s):
return quote(s, b'~@#$&()*!+=:;,.?/\'')
def extract_params(raw):
"""Extract parameters and return them as a list of 2-tuples.
Will successfully extract parameters from urlencoded query strings,
dicts, or lists of 2-tuples. Empty strings/dicts/lists will return an
empty list of parameters. Any other input will result in a return
value of None.
"""
if isinstance(raw, (list, tuple)):
try:
raw = dict(raw)
except (TypeError, ValueError):
return None
if isinstance(raw, dict):
params = []
for k, v in raw.items():
params.append((to_unicode(k), to_unicode(v)))
return params
if not raw:
return None
try:
return url_decode(raw)
except ValueError:
return None
def is_valid_url(url):
parsed = urlparse.urlparse(url)
return parsed.scheme and parsed.hostname

View File

@@ -0,0 +1,11 @@
name = 'Authlib'
version = '1.3.2'
author = 'Hsiaoming Yang <me@lepture.com>'
homepage = 'https://authlib.org/'
default_user_agent = f'{name}/{version} (+{homepage})'
default_json_headers = [
('Content-Type', 'application/json'),
('Cache-Control', 'no-store'),
('Pragma', 'no-cache'),
]

View File

@@ -0,0 +1,16 @@
import warnings
class AuthlibDeprecationWarning(DeprecationWarning):
pass
warnings.simplefilter('always', AuthlibDeprecationWarning)
def deprecate(message, version=None, link_uid=None, link_file=None):
if version:
message += f'\nIt will be compatible before version {version}.'
if link_uid and link_file:
message += f'\nRead more <https://git.io/{link_uid}#file-{link_file}-md>'
warnings.warn(AuthlibDeprecationWarning(message), stacklevel=2)

View File

@@ -0,0 +1,18 @@
from .registry import BaseOAuth
from .sync_app import BaseApp, OAuth1Mixin, OAuth2Mixin
from .sync_openid import OpenIDMixin
from .framework_integration import FrameworkIntegration
from .errors import (
OAuthError, MissingRequestTokenError, MissingTokenError,
TokenExpiredError, InvalidTokenError, UnsupportedTokenTypeError,
MismatchingStateError,
)
__all__ = [
'BaseOAuth',
'BaseApp', 'OAuth1Mixin', 'OAuth2Mixin',
'OpenIDMixin', 'FrameworkIntegration',
'OAuthError', 'MissingRequestTokenError', 'MissingTokenError',
'TokenExpiredError', 'InvalidTokenError', 'UnsupportedTokenTypeError',
'MismatchingStateError',
]

View File

@@ -0,0 +1,144 @@
import time
import logging
from authlib.common.urls import urlparse
from .errors import (
MissingRequestTokenError,
MissingTokenError,
)
from .sync_app import OAuth1Base, OAuth2Base
log = logging.getLogger(__name__)
__all__ = ['AsyncOAuth1Mixin', 'AsyncOAuth2Mixin']
class AsyncOAuth1Mixin(OAuth1Base):
async def request(self, method, url, token=None, **kwargs):
async with self._get_oauth_client() as session:
return await _http_request(self, session, method, url, token, kwargs)
async def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
if not self.authorize_url:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
async with self._get_oauth_client() as client:
client.redirect_uri = redirect_uri
params = {}
if self.request_token_params:
params.update(self.request_token_params)
request_token = await client.fetch_request_token(self.request_token_url, **params)
log.debug(f'Fetch request token: {request_token!r}')
url = client.create_authorization_url(self.authorize_url, **kwargs)
state = request_token['oauth_token']
return {'url': url, 'request_token': request_token, 'state': state}
async def fetch_access_token(self, request_token=None, **kwargs):
"""Fetch access token in one step.
:param request_token: A previous request token for OAuth 1.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
async with self._get_oauth_client() as client:
if request_token is None:
raise MissingRequestTokenError()
# merge request token with verifier
token = {}
token.update(request_token)
token.update(kwargs)
client.token = token
params = self.access_token_params or {}
token = await client.fetch_access_token(self.access_token_url, **params)
return token
class AsyncOAuth2Mixin(OAuth2Base):
async def _on_update_token(self, token, refresh_token=None, access_token=None):
if self._update_token:
await self._update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
async def load_server_metadata(self):
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
async with self.client_cls(**self.client_kwargs) as client:
resp = await client.request('GET', self._server_metadata_url, withhold_token=True)
resp.raise_for_status()
metadata = resp.json()
metadata['_loaded_at'] = time.time()
self.server_metadata.update(metadata)
return self.server_metadata
async def request(self, method, url, token=None, **kwargs):
metadata = await self.load_server_metadata()
async with self._get_oauth_client(**metadata) as session:
return await _http_request(self, session, method, url, token, kwargs)
async def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
metadata = await self.load_server_metadata()
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
if not authorization_endpoint:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
async with self._get_oauth_client(**metadata) as client:
client.redirect_uri = redirect_uri
return self._create_oauth2_authorization_url(
client, authorization_endpoint, **kwargs)
async def fetch_access_token(self, redirect_uri=None, **kwargs):
"""Fetch access token in the final step.
:param redirect_uri: Callback or Redirect URI that is used in
previous :meth:`authorize_redirect`.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
metadata = await self.load_server_metadata()
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
async with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
params = {}
if self.access_token_params:
params.update(self.access_token_params)
params.update(kwargs)
token = await client.fetch_token(token_endpoint, **params)
return token
async def _http_request(ctx, session, method, url, token, kwargs):
request = kwargs.pop('request', None)
withhold_token = kwargs.get('withhold_token')
if ctx.api_base_url and not url.startswith(('https://', 'http://')):
url = urlparse.urljoin(ctx.api_base_url, url)
if withhold_token:
return await session.request(method, url, **kwargs)
if token is None and ctx._fetch_token and request:
token = await ctx._fetch_token(request)
if token is None:
raise MissingTokenError()
session.token = token
return await session.request(method, url, **kwargs)

View File

@@ -0,0 +1,79 @@
from authlib.jose import JsonWebToken, JsonWebKey
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
__all__ = ['AsyncOpenIDMixin']
class AsyncOpenIDMixin:
async def fetch_jwk_set(self, force=False):
metadata = await self.load_server_metadata()
jwk_set = metadata.get('jwks')
if jwk_set and not force:
return jwk_set
uri = metadata.get('jwks_uri')
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
async with self.client_cls(**self.client_kwargs) as client:
resp = await client.request('GET', uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()
self.server_metadata['jwks'] = jwk_set
return jwk_set
async def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
metadata = await self.load_server_metadata()
resp = await self.get(metadata['userinfo_endpoint'], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
async def parse_id_token(self, token, nonce, claims_options=None):
"""Return an instance of UserInfo from token's ``id_token``."""
claims_params = dict(
nonce=nonce,
client_id=self.client_id,
)
if 'access_token' in token:
claims_params['access_token'] = token['access_token']
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
metadata = await self.load_server_metadata()
if claims_options is None and 'issuer' in metadata:
claims_options = {'iss': {'values': [metadata['issuer']]}}
alg_values = metadata.get('id_token_signing_alg_values_supported')
if not alg_values:
alg_values = ['RS256']
jwt = JsonWebToken(alg_values)
jwk_set = await self.fetch_jwk_set()
try:
claims = jwt.decode(
token['id_token'],
key=JsonWebKey.import_key_set(jwk_set),
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
except ValueError:
jwk_set = await self.fetch_jwk_set(force=True)
claims = jwt.decode(
token['id_token'],
key=JsonWebKey.import_key_set(jwk_set),
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
# https://github.com/lepture/authlib/issues/259
if claims.get('nonce_supported') is False:
claims.params['nonce'] = None
claims.validate(leeway=120)
return UserInfo(claims)

View File

@@ -0,0 +1,30 @@
from authlib.common.errors import AuthlibBaseError
class OAuthError(AuthlibBaseError):
error = 'oauth_error'
class MissingRequestTokenError(OAuthError):
error = 'missing_request_token'
class MissingTokenError(OAuthError):
error = 'missing_token'
class TokenExpiredError(OAuthError):
error = 'token_expired'
class InvalidTokenError(OAuthError):
error = 'token_invalid'
class UnsupportedTokenTypeError(OAuthError):
error = 'unsupported_token_type'
class MismatchingStateError(OAuthError):
error = 'mismatching_state'
description = 'CSRF Warning! State not equal in request and response.'

View File

@@ -0,0 +1,64 @@
import json
import time
class FrameworkIntegration:
expires_in = 3600
def __init__(self, name, cache=None):
self.name = name
self.cache = cache
def _get_cache_data(self, key):
value = self.cache.get(key)
if not value:
return None
try:
return json.loads(value)
except (TypeError, ValueError):
return None
def _clear_session_state(self, session):
now = time.time()
for key in dict(session):
if '_authlib_' in key:
# TODO: remove in future
session.pop(key)
elif key.startswith('_state_'):
value = session[key]
exp = value.get('exp')
if not exp or exp < now:
session.pop(key)
def get_state_data(self, session, state):
key = f'_state_{self.name}_{state}'
if self.cache:
value = self._get_cache_data(key)
else:
value = session.get(key)
if value:
return value.get('data')
return None
def set_state_data(self, session, state, data):
key = f'_state_{self.name}_{state}'
if self.cache:
self.cache.set(key, json.dumps({'data': data}), self.expires_in)
else:
now = time.time()
session[key] = {'data': data, 'exp': now + self.expires_in}
def clear_state_data(self, session, state):
key = f'_state_{self.name}_{state}'
if self.cache:
self.cache.delete(key)
else:
session.pop(key, None)
self._clear_session_state(session)
def update_token(self, token, refresh_token=None, access_token=None):
raise NotImplementedError()
@staticmethod
def load_config(oauth, name, params):
raise NotImplementedError()

View File

@@ -0,0 +1,131 @@
import functools
from .framework_integration import FrameworkIntegration
__all__ = ['BaseOAuth']
OAUTH_CLIENT_PARAMS = (
'client_id', 'client_secret',
'request_token_url', 'request_token_params',
'access_token_url', 'access_token_params',
'refresh_token_url', 'refresh_token_params',
'authorize_url', 'authorize_params',
'api_base_url', 'client_kwargs',
'server_metadata_url',
)
class BaseOAuth:
"""Registry for oauth clients.
Create an instance for registry::
oauth = OAuth()
"""
oauth1_client_cls = None
oauth2_client_cls = None
framework_integration_cls = FrameworkIntegration
def __init__(self, cache=None, fetch_token=None, update_token=None):
self._registry = {}
self._clients = {}
self.cache = cache
self.fetch_token = fetch_token
self.update_token = update_token
def create_client(self, name):
"""Create or get the given named OAuth client. For instance, the
OAuth registry has ``.register`` a twitter client, developers may
access the client with::
client = oauth.create_client('twitter')
:param: name: Name of the remote application
:return: OAuth remote app
"""
if name in self._clients:
return self._clients[name]
if name not in self._registry:
return None
overwrite, config = self._registry[name]
client_cls = config.pop('client_cls', None)
if client_cls and client_cls.OAUTH_APP_CONFIG:
kwargs = client_cls.OAUTH_APP_CONFIG
kwargs.update(config)
else:
kwargs = config
kwargs = self.generate_client_kwargs(name, overwrite, **kwargs)
framework = self.framework_integration_cls(name, self.cache)
if client_cls:
client = client_cls(framework, name, **kwargs)
elif kwargs.get('request_token_url'):
client = self.oauth1_client_cls(framework, name, **kwargs)
else:
client = self.oauth2_client_cls(framework, name, **kwargs)
self._clients[name] = client
return client
def register(self, name, overwrite=False, **kwargs):
"""Registers a new remote application.
:param name: Name of the remote application.
:param overwrite: Overwrite existing config with framework settings.
:param kwargs: Parameters for :class:`RemoteApp`.
Find parameters for the given remote app class. When a remote app is
registered, it can be accessed with *named* attribute::
oauth.register('twitter', client_id='', ...)
oauth.twitter.get('timeline')
"""
self._registry[name] = (overwrite, kwargs)
return self.create_client(name)
def generate_client_kwargs(self, name, overwrite, **kwargs):
fetch_token = kwargs.pop('fetch_token', None)
update_token = kwargs.pop('update_token', None)
config = self.load_config(name, OAUTH_CLIENT_PARAMS)
if config:
kwargs = _config_client(config, kwargs, overwrite)
if not fetch_token and self.fetch_token:
fetch_token = functools.partial(self.fetch_token, name)
kwargs['fetch_token'] = fetch_token
if not kwargs.get('request_token_url'):
if not update_token and self.update_token:
update_token = functools.partial(self.update_token, name)
kwargs['update_token'] = update_token
return kwargs
def load_config(self, name, params):
return self.framework_integration_cls.load_config(self, name, params)
def __getattr__(self, key):
try:
return object.__getattribute__(self, key)
except AttributeError:
if key in self._registry:
return self.create_client(key)
raise AttributeError('No such client: %s' % key)
def _config_client(config, kwargs, overwrite):
for k in OAUTH_CLIENT_PARAMS:
v = config.get(k, None)
if k not in kwargs:
kwargs[k] = v
elif overwrite and v:
if isinstance(kwargs[k], dict):
kwargs[k].update(v)
else:
kwargs[k] = v
return kwargs

View File

@@ -0,0 +1,348 @@
import time
import logging
from authlib.common.urls import urlparse
from authlib.consts import default_user_agent
from authlib.common.security import generate_token
from .errors import (
MismatchingStateError,
MissingRequestTokenError,
MissingTokenError,
)
log = logging.getLogger(__name__)
class BaseApp:
client_cls = None
OAUTH_APP_CONFIG = None
def request(self, method, url, token=None, **kwargs):
raise NotImplementedError()
def get(self, url, **kwargs):
"""Invoke GET http request.
If ``api_base_url`` configured, shortcut is available::
client.get('users/lepture')
"""
return self.request('GET', url, **kwargs)
def post(self, url, **kwargs):
"""Invoke POST http request.
If ``api_base_url`` configured, shortcut is available::
client.post('timeline', json={'text': 'Hi'})
"""
return self.request('POST', url, **kwargs)
def patch(self, url, **kwargs):
"""Invoke PATCH http request.
If ``api_base_url`` configured, shortcut is available::
client.patch('profile', json={'name': 'Hsiaoming Yang'})
"""
return self.request('PATCH', url, **kwargs)
def put(self, url, **kwargs):
"""Invoke PUT http request.
If ``api_base_url`` configured, shortcut is available::
client.put('profile', json={'name': 'Hsiaoming Yang'})
"""
return self.request('PUT', url, **kwargs)
def delete(self, url, **kwargs):
"""Invoke DELETE http request.
If ``api_base_url`` configured, shortcut is available::
client.delete('posts/123')
"""
return self.request('DELETE', url, **kwargs)
class _RequestMixin:
def _get_requested_token(self, request):
if self._fetch_token and request:
return self._fetch_token(request)
def _send_token_request(self, session, method, url, token, kwargs):
request = kwargs.pop('request', None)
withhold_token = kwargs.get('withhold_token')
if self.api_base_url and not url.startswith(('https://', 'http://')):
url = urlparse.urljoin(self.api_base_url, url)
if withhold_token:
return session.request(method, url, **kwargs)
if token is None:
token = self._get_requested_token(request)
if token is None:
raise MissingTokenError()
session.token = token
return session.request(method, url, **kwargs)
class OAuth1Base:
client_cls = None
def __init__(
self, framework, name=None, fetch_token=None,
client_id=None, client_secret=None,
request_token_url=None, request_token_params=None,
access_token_url=None, access_token_params=None,
authorize_url=None, authorize_params=None,
api_base_url=None, client_kwargs=None, user_agent=None, **kwargs):
self.framework = framework
self.name = name
self.client_id = client_id
self.client_secret = client_secret
self.request_token_url = request_token_url
self.request_token_params = request_token_params
self.access_token_url = access_token_url
self.access_token_params = access_token_params
self.authorize_url = authorize_url
self.authorize_params = authorize_params
self.api_base_url = api_base_url
self.client_kwargs = client_kwargs or {}
self._fetch_token = fetch_token
self._user_agent = user_agent or default_user_agent
self._kwargs = kwargs
def _get_oauth_client(self):
session = self.client_cls(self.client_id, self.client_secret, **self.client_kwargs)
session.headers['User-Agent'] = self._user_agent
return session
class OAuth1Mixin(_RequestMixin, OAuth1Base):
def request(self, method, url, token=None, **kwargs):
with self._get_oauth_client() as session:
return self._send_token_request(session, method, url, token, kwargs)
def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
if not self.authorize_url:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
with self._get_oauth_client() as client:
client.redirect_uri = redirect_uri
params = self.request_token_params or {}
request_token = client.fetch_request_token(self.request_token_url, **params)
log.debug(f'Fetch request token: {request_token!r}')
url = client.create_authorization_url(self.authorize_url, **kwargs)
state = request_token['oauth_token']
return {'url': url, 'request_token': request_token, 'state': state}
def fetch_access_token(self, request_token=None, **kwargs):
"""Fetch access token in one step.
:param request_token: A previous request token for OAuth 1.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
with self._get_oauth_client() as client:
if request_token is None:
raise MissingRequestTokenError()
# merge request token with verifier
token = {}
token.update(request_token)
token.update(kwargs)
client.token = token
params = self.access_token_params or {}
token = client.fetch_access_token(self.access_token_url, **params)
return token
class OAuth2Base:
client_cls = None
def __init__(
self, framework, name=None, fetch_token=None, update_token=None,
client_id=None, client_secret=None,
access_token_url=None, access_token_params=None,
authorize_url=None, authorize_params=None,
api_base_url=None, client_kwargs=None, server_metadata_url=None,
compliance_fix=None, client_auth_methods=None, user_agent=None, **kwargs):
self.framework = framework
self.name = name
self.client_id = client_id
self.client_secret = client_secret
self.access_token_url = access_token_url
self.access_token_params = access_token_params
self.authorize_url = authorize_url
self.authorize_params = authorize_params
self.api_base_url = api_base_url
self.client_kwargs = client_kwargs or {}
self.compliance_fix = compliance_fix
self.client_auth_methods = client_auth_methods
self._fetch_token = fetch_token
self._update_token = update_token
self._user_agent = user_agent or default_user_agent
self._server_metadata_url = server_metadata_url
self.server_metadata = kwargs
def _on_update_token(self, token, refresh_token=None, access_token=None):
raise NotImplementedError()
def _get_oauth_client(self, **metadata):
client_kwargs = {}
client_kwargs.update(self.client_kwargs)
client_kwargs.update(metadata)
if self.authorize_url:
client_kwargs['authorization_endpoint'] = self.authorize_url
if self.access_token_url:
client_kwargs['token_endpoint'] = self.access_token_url
session = self.client_cls(
client_id=self.client_id,
client_secret=self.client_secret,
update_token=self._on_update_token,
**client_kwargs
)
if self.client_auth_methods:
for f in self.client_auth_methods:
session.register_client_auth_method(f)
if self.compliance_fix:
self.compliance_fix(session)
session.headers['User-Agent'] = self._user_agent
return session
@staticmethod
def _format_state_params(state_data, params):
if state_data is None:
raise MismatchingStateError()
code_verifier = state_data.get('code_verifier')
if code_verifier:
params['code_verifier'] = code_verifier
redirect_uri = state_data.get('redirect_uri')
if redirect_uri:
params['redirect_uri'] = redirect_uri
return params
@staticmethod
def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs):
rv = {}
if client.code_challenge_method:
code_verifier = kwargs.get('code_verifier')
if not code_verifier:
code_verifier = generate_token(48)
kwargs['code_verifier'] = code_verifier
rv['code_verifier'] = code_verifier
log.debug(f'Using code_verifier: {code_verifier!r}')
scope = kwargs.get('scope', client.scope)
scope = (
(scope if isinstance(scope, (list, tuple)) else scope.split())
if scope
else None
)
if scope and "openid" in scope:
# this is an OpenID Connect service
nonce = kwargs.get('nonce')
if not nonce:
nonce = generate_token(20)
kwargs['nonce'] = nonce
rv['nonce'] = nonce
url, state = client.create_authorization_url(
authorization_endpoint, **kwargs)
rv['url'] = url
rv['state'] = state
return rv
class OAuth2Mixin(_RequestMixin, OAuth2Base):
def _on_update_token(self, token, refresh_token=None, access_token=None):
if callable(self._update_token):
self._update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
self.framework.update_token(
token,
refresh_token=refresh_token,
access_token=access_token,
)
def request(self, method, url, token=None, **kwargs):
metadata = self.load_server_metadata()
with self._get_oauth_client(**metadata) as session:
return self._send_token_request(session, method, url, token, kwargs)
def load_server_metadata(self):
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
with self.client_cls(**self.client_kwargs) as session:
resp = session.request('GET', self._server_metadata_url, withhold_token=True)
resp.raise_for_status()
metadata = resp.json()
metadata['_loaded_at'] = time.time()
self.server_metadata.update(metadata)
return self.server_metadata
def create_authorization_url(self, redirect_uri=None, **kwargs):
"""Generate the authorization url and state for HTTP redirect.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: dict
"""
metadata = self.load_server_metadata()
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
if not authorization_endpoint:
raise RuntimeError('Missing "authorize_url" value')
if self.authorize_params:
kwargs.update(self.authorize_params)
with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
return self._create_oauth2_authorization_url(
client, authorization_endpoint, **kwargs)
def fetch_access_token(self, redirect_uri=None, **kwargs):
"""Fetch access token in the final step.
:param redirect_uri: Callback or Redirect URI that is used in
previous :meth:`authorize_redirect`.
:param kwargs: Extra parameters to fetch access token.
:return: A token dict.
"""
metadata = self.load_server_metadata()
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
with self._get_oauth_client(**metadata) as client:
if redirect_uri is not None:
client.redirect_uri = redirect_uri
params = {}
if self.access_token_params:
params.update(self.access_token_params)
params.update(kwargs)
token = client.fetch_token(token_endpoint, **params)
return token

View File

@@ -0,0 +1,82 @@
from authlib.jose import jwt, JsonWebToken, JsonWebKey
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
class OpenIDMixin:
def fetch_jwk_set(self, force=False):
metadata = self.load_server_metadata()
jwk_set = metadata.get('jwks')
if jwk_set and not force:
return jwk_set
uri = metadata.get('jwks_uri')
if not uri:
raise RuntimeError('Missing "jwks_uri" in metadata')
with self.client_cls(**self.client_kwargs) as session:
resp = session.request('GET', uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()
self.server_metadata['jwks'] = jwk_set
return jwk_set
def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
metadata = self.load_server_metadata()
resp = self.get(metadata['userinfo_endpoint'], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
"""Return an instance of UserInfo from token's ``id_token``."""
if 'id_token' not in token:
return None
load_key = self.create_load_key()
claims_params = dict(
nonce=nonce,
client_id=self.client_id,
)
if 'access_token' in token:
claims_params['access_token'] = token['access_token']
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken
metadata = self.load_server_metadata()
if claims_options is None and 'issuer' in metadata:
claims_options = {'iss': {'values': [metadata['issuer']]}}
alg_values = metadata.get('id_token_signing_alg_values_supported')
if alg_values:
_jwt = JsonWebToken(alg_values)
else:
_jwt = jwt
claims = _jwt.decode(
token['id_token'], key=load_key,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
)
# https://github.com/lepture/authlib/issues/259
if claims.get('nonce_supported') is False:
claims.params['nonce'] = None
claims.validate(leeway=leeway)
return UserInfo(claims)
def create_load_key(self):
def load_key(header, _):
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
try:
return jwk_set.find_by_kid(header.get('kid'))
except ValueError:
# re-try with new jwk set
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
return jwk_set.find_by_kid(header.get('kid'))
return load_key

View File

@@ -0,0 +1,19 @@
# flake8: noqa
from .integration import DjangoIntegration, token_update
from .apps import DjangoOAuth1App, DjangoOAuth2App
from ..base_client import BaseOAuth, OAuthError
class OAuth(BaseOAuth):
oauth1_client_cls = DjangoOAuth1App
oauth2_client_cls = DjangoOAuth2App
framework_integration_cls = DjangoIntegration
__all__ = [
'OAuth',
'DjangoOAuth1App', 'DjangoOAuth2App',
'DjangoIntegration',
'token_update', 'OAuthError',
]

View File

@@ -0,0 +1,87 @@
from django.http import HttpResponseRedirect
from ..requests_client import OAuth1Session, OAuth2Session
from ..base_client import (
BaseApp, OAuthError,
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
)
class DjangoAppMixin:
def save_authorize_data(self, request, **kwargs):
state = kwargs.pop('state', None)
if state:
self.framework.set_state_data(request.session, state, kwargs)
else:
raise RuntimeError('Missing state value')
def authorize_redirect(self, request, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param request: HTTP request instance from Django view.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
rv = self.create_authorization_url(redirect_uri, **kwargs)
self.save_authorize_data(request, redirect_uri=redirect_uri, **rv)
return HttpResponseRedirect(rv['url'])
class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp):
client_cls = OAuth1Session
def authorize_access_token(self, request, **kwargs):
"""Fetch access token in one step.
:param request: HTTP request instance from Django view.
:return: A token dict.
"""
params = request.GET.dict()
state = params.get('oauth_token')
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = self.framework.get_state_data(request.session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params['request_token'] = data['request_token']
params.update(kwargs)
self.framework.clear_state_data(request.session, state)
return self.fetch_access_token(**params)
class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
client_cls = OAuth2Session
def authorize_access_token(self, request, **kwargs):
"""Fetch access token in one step.
:param request: HTTP request instance from Django view.
:return: A token dict.
"""
if request.method == 'GET':
error = request.GET.get('error')
if error:
description = request.GET.get('error_description')
raise OAuthError(error=error, description=description)
params = {
'code': request.GET.get('code'),
'state': request.GET.get('state'),
}
else:
params = {
'code': request.POST.get('code'),
'state': request.POST.get('state'),
}
claims_options = kwargs.pop('claims_options', None)
state_data = self.framework.get_state_data(request.session, params.get('state'))
self.framework.clear_state_data(request.session, params.get('state'))
params = self._format_state_params(state_data, params)
token = self.fetch_access_token(**params, **kwargs)
if 'id_token' in token and 'nonce' in state_data:
userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
token['userinfo'] = userinfo
return token

View File

@@ -0,0 +1,22 @@
from django.conf import settings
from django.dispatch import Signal
from ..base_client import FrameworkIntegration
token_update = Signal()
class DjangoIntegration(FrameworkIntegration):
def update_token(self, token, refresh_token=None, access_token=None):
token_update.send(
sender=self.__class__,
name=self.name,
token=token,
refresh_token=refresh_token,
access_token=access_token,
)
@staticmethod
def load_config(oauth, name, params):
config = getattr(settings, 'AUTHLIB_OAUTH_CLIENTS', None)
if config:
return config.get(name)

View File

@@ -0,0 +1,9 @@
# flake8: noqa
from .authorization_server import (
BaseServer, CacheAuthorizationServer
)
from .resource_protector import ResourceProtector
__all__ = ['BaseServer', 'CacheAuthorizationServer', 'ResourceProtector']

View File

@@ -0,0 +1,125 @@
import logging
from authlib.oauth1 import (
OAuth1Request,
AuthorizationServer as _AuthorizationServer,
)
from authlib.oauth1 import TemporaryCredential
from authlib.common.security import generate_token
from authlib.common.urls import url_encode
from django.core.cache import cache
from django.conf import settings
from django.http import HttpResponse
from .nonce import exists_nonce_in_cache
log = logging.getLogger(__name__)
class BaseServer(_AuthorizationServer):
def __init__(self, client_model, token_model, token_generator=None):
self.client_model = client_model
self.token_model = token_model
if token_generator is None:
def token_generator():
return {
'oauth_token': generate_token(42),
'oauth_token_secret': generate_token(48)
}
self.token_generator = token_generator
self._config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {})
self._nonce_expires_in = self._config.get('nonce_expires_in', 86400)
methods = self._config.get('signature_methods')
if methods:
self.SUPPORTED_SIGNATURE_METHODS = methods
def get_client_by_id(self, client_id):
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def exists_nonce(self, nonce, request):
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
def create_token_credential(self, request):
temporary_credential = request.credential
token = self.token_generator()
item = self.token_model(
oauth_token=token['oauth_token'],
oauth_token_secret=token['oauth_token_secret'],
user_id=temporary_credential.get_user_id(),
client_id=temporary_credential.get_client_id()
)
item.save()
return item
def check_authorization_request(self, request):
req = self.create_oauth1_request(request)
self.validate_authorization_request(req)
return req
def create_oauth1_request(self, request):
if request.method == 'POST':
body = request.POST.dict()
else:
body = None
url = request.build_absolute_uri()
return OAuth1Request(request.method, url, body, request.headers)
def handle_response(self, status_code, payload, headers):
resp = HttpResponse(url_encode(payload), status=status_code)
for k, v in headers:
resp[k] = v
return resp
class CacheAuthorizationServer(BaseServer):
def __init__(self, client_model, token_model, token_generator=None):
super().__init__(
client_model, token_model, token_generator)
self._temporary_expires_in = self._config.get(
'temporary_credential_expires_in', 86400)
self._temporary_credential_key_prefix = self._config.get(
'temporary_credential_key_prefix', 'temporary_credential:')
def create_temporary_credential(self, request):
key_prefix = self._temporary_credential_key_prefix
token = self.token_generator()
client_id = request.client_id
redirect_uri = request.redirect_uri
key = key_prefix + token['oauth_token']
token['client_id'] = client_id
if redirect_uri:
token['oauth_callback'] = redirect_uri
cache.set(key, token, timeout=self._temporary_expires_in)
return TemporaryCredential(token)
def get_temporary_credential(self, request):
if not request.token:
return None
key_prefix = self._temporary_credential_key_prefix
key = key_prefix + request.token
value = cache.get(key)
if value:
return TemporaryCredential(value)
def delete_temporary_credential(self, request):
if request.token:
key_prefix = self._temporary_credential_key_prefix
key = key_prefix + request.token
cache.delete(key)
def create_authorization_verifier(self, request):
key_prefix = self._temporary_credential_key_prefix
verifier = generate_token(36)
credential = request.credential
user = request.user
key = key_prefix + credential.get_oauth_token()
credential['oauth_verifier'] = verifier
credential['user_id'] = user.pk
cache.set(key, credential, timeout=self._temporary_expires_in)
return verifier

View File

@@ -0,0 +1,15 @@
from django.core.cache import cache
def exists_nonce_in_cache(nonce, request, timeout):
key_prefix = 'nonce:'
timestamp = request.timestamp
client_id = request.client_id
token = request.token
key = f'{key_prefix}{nonce}-{timestamp}-{client_id}'
if token:
key = f'{key}-{token}'
rv = bool(cache.get(key))
cache.set(key, 1, timeout=timeout)
return rv

View File

@@ -0,0 +1,64 @@
import functools
from authlib.oauth1.errors import OAuth1Error
from authlib.oauth1 import ResourceProtector as _ResourceProtector
from django.http import JsonResponse
from django.conf import settings
from .nonce import exists_nonce_in_cache
class ResourceProtector(_ResourceProtector):
def __init__(self, client_model, token_model):
self.client_model = client_model
self.token_model = token_model
config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {})
methods = config.get('signature_methods', [])
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self._nonce_expires_in = config.get('nonce_expires_in', 86400)
def get_client_by_id(self, client_id):
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def get_token_credential(self, request):
try:
return self.token_model.objects.get(
client_id=request.client_id,
oauth_token=request.token
)
except self.token_model.DoesNotExist:
return None
def exists_nonce(self, nonce, request):
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
def acquire_credential(self, request):
if request.method in ['POST', 'PUT']:
body = request.POST.dict()
else:
body = None
url = request.build_absolute_uri()
req = self.validate_request(request.method, url, body, request.headers)
return req.credential
def __call__(self, realm=None):
def wrapper(f):
@functools.wraps(f)
def decorated(request, *args, **kwargs):
try:
credential = self.acquire_credential(request)
request.oauth1_credential = credential
except OAuth1Error as error:
body = dict(error.get_body())
resp = JsonResponse(body, status=error.status_code)
resp['Cache-Control'] = 'no-store'
resp['Pragma'] = 'no-cache'
return resp
return f(request, *args, **kwargs)
return decorated
return wrapper

View File

@@ -0,0 +1,10 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import ResourceProtector, BearerTokenValidator
from .endpoints import RevocationEndpoint
from .signals import (
client_authenticated,
token_authenticated,
token_revoked
)

View File

@@ -0,0 +1,118 @@
from django.http import HttpResponse
from django.utils.module_loading import import_string
from django.conf import settings
from authlib.oauth2 import (
AuthorizationServer as _AuthorizationServer,
)
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.common.security import generate_token as _generate_token
from authlib.common.encoding import json_dumps
from .requests import DjangoOAuth2Request, DjangoJsonRequest
from .signals import client_authenticated, token_revoked
class AuthorizationServer(_AuthorizationServer):
"""Django implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
Initialize it with client model and token model::
from authlib.integrations.django_oauth2 import AuthorizationServer
from your_project.models import OAuth2Client, OAuth2Token
server = AuthorizationServer(OAuth2Client, OAuth2Token)
"""
def __init__(self, client_model, token_model):
self.config = getattr(settings, 'AUTHLIB_OAUTH2_PROVIDER', {})
self.client_model = client_model
self.token_model = token_model
scopes_supported = self.config.get('scopes_supported')
super().__init__(scopes_supported=scopes_supported)
# add default token generator
self.register_token_generator('default', self.create_bearer_token_generator())
def query_client(self, client_id):
"""Default method for ``AuthorizationServer.query_client``. Developers MAY
rewrite this function to meet their own needs.
"""
try:
return self.client_model.objects.get(client_id=client_id)
except self.client_model.DoesNotExist:
return None
def save_token(self, token, request):
"""Default method for ``AuthorizationServer.save_token``. Developers MAY
rewrite this function to meet their own needs.
"""
client = request.client
if request.user:
user_id = request.user.pk
else:
user_id = client.user_id
item = self.token_model(
client_id=client.client_id,
user_id=user_id,
**token
)
item.save()
return item
def create_oauth2_request(self, request):
return DjangoOAuth2Request(request)
def create_json_request(self, request):
return DjangoJsonRequest(request)
def handle_response(self, status_code, payload, headers):
if isinstance(payload, dict):
payload = json_dumps(payload)
resp = HttpResponse(payload, status=status_code)
for k, v in headers:
resp[k] = v
return resp
def send_signal(self, name, *args, **kwargs):
if name == 'after_authenticate_client':
client_authenticated.send(sender=self.__class__, *args, **kwargs)
elif name == 'after_revoke_token':
token_revoked.send(sender=self.__class__, *args, **kwargs)
def create_bearer_token_generator(self):
"""Default method to create BearerToken generator."""
conf = self.config.get('access_token_generator', True)
access_token_generator = create_token_generator(conf, 42)
conf = self.config.get('refresh_token_generator', False)
refresh_token_generator = create_token_generator(conf, 48)
conf = self.config.get('token_expires_in')
expires_generator = create_token_expires_in_generator(conf)
return BearerTokenGenerator(
access_token_generator=access_token_generator,
refresh_token_generator=refresh_token_generator,
expires_generator=expires_generator,
)
def create_token_generator(token_generator_conf, length=42):
if callable(token_generator_conf):
return token_generator_conf
if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)
elif token_generator_conf is True:
def token_generator(*args, **kwargs):
return _generate_token(length)
return token_generator
def create_token_expires_in_generator(expires_in_conf=None):
data = {}
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
if expires_in_conf:
data.update(expires_in_conf)
def expires_in(client, grant_type):
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
return expires_in

View File

@@ -0,0 +1,56 @@
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
class RevocationEndpoint(_RevocationEndpoint):
"""The revocation endpoint for OAuth authorization servers allows clients
to notify the authorization server that a previously obtained refresh or
access token is no longer needed.
Register it into authorization server, and create token endpoint response
for token revocation::
from django.views.decorators.http import require_http_methods
# see register into authorization server instance
server.register_endpoint(RevocationEndpoint)
@require_http_methods(["POST"])
def revoke_token(request):
return server.create_endpoint_response(
RevocationEndpoint.ENDPOINT_NAME,
request
)
"""
def query_token(self, token, token_type_hint):
"""Query requested token from database."""
token_model = self.server.token_model
if token_type_hint == 'access_token':
rv = _query_access_token(token_model, token)
elif token_type_hint == 'refresh_token':
rv = _query_refresh_token(token_model, token)
else:
rv = _query_access_token(token_model, token)
if not rv:
rv = _query_refresh_token(token_model, token)
return rv
def revoke_token(self, token, request):
"""Mark the give token as revoked."""
token.revoked = True
token.save()
def _query_access_token(token_model, token):
try:
return token_model.objects.get(access_token=token)
except token_model.DoesNotExist:
return None
def _query_refresh_token(token_model, token):
try:
return token_model.objects.get(refresh_token=token)
except token_model.DoesNotExist:
return None

View File

@@ -0,0 +1,46 @@
from collections import defaultdict
from django.http import HttpRequest
from django.utils.functional import cached_property
from authlib.common.encoding import json_loads
from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest
class DjangoOAuth2Request(OAuth2Request):
def __init__(self, request: HttpRequest):
super().__init__(request.method, request.build_absolute_uri(), None, request.headers)
self._request = request
@property
def args(self):
return self._request.GET
@property
def form(self):
return self._request.POST
@cached_property
def data(self):
data = {}
data.update(self._request.GET.dict())
data.update(self._request.POST.dict())
return data
@cached_property
def datalist(self):
values = defaultdict(list)
for k in self.args:
values[k].extend(self.args.getlist(k))
for k in self.form:
values[k].extend(self.form.getlist(k))
return values
class DjangoJsonRequest(JsonRequest):
def __init__(self, request: HttpRequest):
super().__init__(request.method, request.build_absolute_uri(), None, request.headers)
self._request = request
@cached_property
def data(self):
return json_loads(self._request.body)

View File

@@ -0,0 +1,75 @@
import functools
from django.http import JsonResponse
from authlib.oauth2 import (
OAuth2Error,
ResourceProtector as _ResourceProtector,
)
from authlib.oauth2.rfc6749 import (
MissingAuthorizationError,
)
from authlib.oauth2.rfc6750 import (
BearerTokenValidator as _BearerTokenValidator
)
from .requests import DjangoJsonRequest
from .signals import token_authenticated
class ResourceProtector(_ResourceProtector):
def acquire_token(self, request, scopes=None, **kwargs):
"""A method to acquire current valid token with the given scope.
:param request: Django HTTP request instance
:param scopes: a list of scope values
:return: token object
"""
req = DjangoJsonRequest(request)
# backward compatibility
kwargs['scopes'] = scopes
for claim in kwargs:
if isinstance(kwargs[claim], str):
kwargs[claim] = [kwargs[claim]]
token = self.validate_request(request=req, **kwargs)
token_authenticated.send(sender=self.__class__, token=token)
return token
def __call__(self, scopes=None, optional=False, **kwargs):
claims = kwargs
# backward compatibility
claims['scopes'] = scopes
def wrapper(f):
@functools.wraps(f)
def decorated(request, *args, **kwargs):
try:
token = self.acquire_token(request, **claims)
request.oauth_token = token
except MissingAuthorizationError as error:
if optional:
request.oauth_token = None
return f(request, *args, **kwargs)
return return_error_response(error)
except OAuth2Error as error:
return return_error_response(error)
return f(request, *args, **kwargs)
return decorated
return wrapper
class BearerTokenValidator(_BearerTokenValidator):
def __init__(self, token_model, realm=None, **extra_attributes):
self.token_model = token_model
super().__init__(realm, **extra_attributes)
def authenticate_token(self, token_string):
try:
return self.token_model.objects.get(access_token=token_string)
except self.token_model.DoesNotExist:
return None
def return_error_response(error):
body = dict(error.get_body())
resp = JsonResponse(body, status=error.status_code)
headers = error.get_headers()
for k, v in headers:
resp[k] = v
return resp

View File

@@ -0,0 +1,11 @@
from django.dispatch import Signal
#: signal when client is authenticated
client_authenticated = Signal()
#: signal when token is revoked
token_revoked = Signal()
#: signal when token is authenticated
token_authenticated = Signal()

View File

@@ -0,0 +1,51 @@
from werkzeug.local import LocalProxy
from .integration import FlaskIntegration, token_update
from .apps import FlaskOAuth1App, FlaskOAuth2App
from ..base_client import BaseOAuth, OAuthError
class OAuth(BaseOAuth):
oauth1_client_cls = FlaskOAuth1App
oauth2_client_cls = FlaskOAuth2App
framework_integration_cls = FlaskIntegration
def __init__(self, app=None, cache=None, fetch_token=None, update_token=None):
super().__init__(
cache=cache, fetch_token=fetch_token, update_token=update_token)
self.app = app
if app:
self.init_app(app)
def init_app(self, app, cache=None, fetch_token=None, update_token=None):
"""Initialize lazy for Flask app. This is usually used for Flask application
factory pattern.
"""
self.app = app
if cache is not None:
self.cache = cache
if fetch_token:
self.fetch_token = fetch_token
if update_token:
self.update_token = update_token
app.extensions = getattr(app, 'extensions', {})
app.extensions['authlib.integrations.flask_client'] = self
def create_client(self, name):
if not self.app:
raise RuntimeError('OAuth is not init with Flask app.')
return super().create_client(name)
def register(self, name, overwrite=False, **kwargs):
self._registry[name] = (overwrite, kwargs)
if self.app:
return self.create_client(name)
return LocalProxy(lambda: self.create_client(name))
__all__ = [
'OAuth', 'FlaskIntegration',
'FlaskOAuth1App', 'FlaskOAuth2App',
'token_update', 'OAuthError',
]

View File

@@ -0,0 +1,107 @@
from flask import g, redirect, request, session
from ..requests_client import OAuth1Session, OAuth2Session
from ..base_client import (
BaseApp, OAuthError,
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
)
class FlaskAppMixin:
@property
def token(self):
attr = f'_oauth_token_{self.name}'
token = g.get(attr)
if token:
return token
if self._fetch_token:
token = self._fetch_token()
self.token = token
return token
@token.setter
def token(self, token):
attr = f'_oauth_token_{self.name}'
setattr(g, attr, token)
def _get_requested_token(self, *args, **kwargs):
return self.token
def save_authorize_data(self, **kwargs):
state = kwargs.pop('state', None)
if state:
self.framework.set_state_data(session, state, kwargs)
else:
raise RuntimeError('Missing state value')
def authorize_redirect(self, redirect_uri=None, **kwargs):
"""Create a HTTP Redirect for Authorization Endpoint.
:param redirect_uri: Callback or redirect URI for authorization.
:param kwargs: Extra parameters to include.
:return: A HTTP redirect response.
"""
rv = self.create_authorization_url(redirect_uri, **kwargs)
self.save_authorize_data(redirect_uri=redirect_uri, **rv)
return redirect(rv['url'])
class FlaskOAuth1App(FlaskAppMixin, OAuth1Mixin, BaseApp):
client_cls = OAuth1Session
def authorize_access_token(self, **kwargs):
"""Fetch access token in one step.
:return: A token dict.
"""
params = request.args.to_dict(flat=True)
state = params.get('oauth_token')
if not state:
raise OAuthError(description='Missing "oauth_token" parameter')
data = self.framework.get_state_data(session, state)
if not data:
raise OAuthError(description='Missing "request_token" in temporary data')
params['request_token'] = data['request_token']
params.update(kwargs)
self.framework.clear_state_data(session, state)
token = self.fetch_access_token(**params)
self.token = token
return token
class FlaskOAuth2App(FlaskAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
client_cls = OAuth2Session
def authorize_access_token(self, **kwargs):
"""Fetch access token in one step.
:return: A token dict.
"""
if request.method == 'GET':
error = request.args.get('error')
if error:
description = request.args.get('error_description')
raise OAuthError(error=error, description=description)
params = {
'code': request.args.get('code'),
'state': request.args.get('state'),
}
else:
params = {
'code': request.form.get('code'),
'state': request.form.get('state'),
}
claims_options = kwargs.pop('claims_options', None)
state_data = self.framework.get_state_data(session, params.get('state'))
self.framework.clear_state_data(session, params.get('state'))
params = self._format_state_params(state_data, params)
token = self.fetch_access_token(**params, **kwargs)
self.token = token
if 'id_token' in token and 'nonce' in state_data:
userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
token['userinfo'] = userinfo
return token

View File

@@ -0,0 +1,28 @@
from flask import current_app
from flask.signals import Namespace
from ..base_client import FrameworkIntegration
_signal = Namespace()
#: signal when token is updated
token_update = _signal.signal('token_update')
class FlaskIntegration(FrameworkIntegration):
def update_token(self, token, refresh_token=None, access_token=None):
token_update.send(
current_app,
name=self.name,
token=token,
refresh_token=refresh_token,
access_token=access_token,
)
@staticmethod
def load_config(oauth, name, params):
rv = {}
for k in params:
conf_key = f'{name}_{k}'.upper()
v = oauth.app.config.get(conf_key, None)
if v is not None:
rv[k] = v
return rv

View File

@@ -0,0 +1,9 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import ResourceProtector, current_credential
from .cache import (
register_nonce_hooks,
register_temporary_credential_hooks,
create_exists_nonce_func,
)

View File

@@ -0,0 +1,182 @@
import logging
from werkzeug.utils import import_string
from flask import Response
from flask import request as flask_req
from authlib.oauth1 import (
OAuth1Request,
AuthorizationServer as _AuthorizationServer,
)
from authlib.common.security import generate_token
from authlib.common.urls import url_encode
log = logging.getLogger(__name__)
class AuthorizationServer(_AuthorizationServer):
"""Flask implementation of :class:`authlib.rfc5849.AuthorizationServer`.
Initialize it with Flask app instance, client model class and cache::
server = AuthorizationServer(app=app, query_client=query_client)
# or initialize lazily
server = AuthorizationServer()
server.init_app(app, query_client=query_client)
:param app: A Flask app instance
:param query_client: A function to get client by client_id. The client
model class MUST implement the methods described by
:class:`~authlib.oauth1.rfc5849.ClientMixin`.
:param token_generator: A function to generate token
"""
def __init__(self, app=None, query_client=None, token_generator=None):
self.app = app
self.query_client = query_client
self.token_generator = token_generator
self._hooks = {
'exists_nonce': None,
'create_temporary_credential': None,
'get_temporary_credential': None,
'delete_temporary_credential': None,
'create_authorization_verifier': None,
'create_token_credential': None,
}
if app is not None:
self.init_app(app)
def init_app(self, app, query_client=None, token_generator=None):
if query_client is not None:
self.query_client = query_client
if token_generator is not None:
self.token_generator = token_generator
if self.token_generator is None:
self.token_generator = self.create_token_generator(app)
methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS')
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self.app = app
def register_hook(self, name, func):
if name not in self._hooks:
raise ValueError('Invalid "name" of hook')
self._hooks[name] = func
def create_token_generator(self, app):
token_generator = app.config.get('OAUTH1_TOKEN_GENERATOR')
if isinstance(token_generator, str):
token_generator = import_string(token_generator)
else:
length = app.config.get('OAUTH1_TOKEN_LENGTH', 42)
def token_generator():
return generate_token(length)
secret_generator = app.config.get('OAUTH1_TOKEN_SECRET_GENERATOR')
if isinstance(secret_generator, str):
secret_generator = import_string(secret_generator)
else:
length = app.config.get('OAUTH1_TOKEN_SECRET_LENGTH', 48)
def secret_generator():
return generate_token(length)
def create_token():
return {
'oauth_token': token_generator(),
'oauth_token_secret': secret_generator()
}
return create_token
def get_client_by_id(self, client_id):
return self.query_client(client_id)
def exists_nonce(self, nonce, request):
func = self._hooks['exists_nonce']
if callable(func):
timestamp = request.timestamp
client_id = request.client_id
token = request.token
return func(nonce, timestamp, client_id, token)
raise RuntimeError('"exists_nonce" hook is required.')
def create_temporary_credential(self, request):
func = self._hooks['create_temporary_credential']
if callable(func):
token = self.token_generator()
return func(token, request.client_id, request.redirect_uri)
raise RuntimeError(
'"create_temporary_credential" hook is required.'
)
def get_temporary_credential(self, request):
func = self._hooks['get_temporary_credential']
if callable(func):
return func(request.token)
raise RuntimeError(
'"get_temporary_credential" hook is required.'
)
def delete_temporary_credential(self, request):
func = self._hooks['delete_temporary_credential']
if callable(func):
return func(request.token)
raise RuntimeError(
'"delete_temporary_credential" hook is required.'
)
def create_authorization_verifier(self, request):
func = self._hooks['create_authorization_verifier']
if callable(func):
verifier = generate_token(36)
func(request.credential, request.user, verifier)
return verifier
raise RuntimeError(
'"create_authorization_verifier" hook is required.'
)
def create_token_credential(self, request):
func = self._hooks['create_token_credential']
if callable(func):
temporary_credential = request.credential
token = self.token_generator()
return func(token, temporary_credential)
raise RuntimeError(
'"create_token_credential" hook is required.'
)
def check_authorization_request(self):
req = self.create_oauth1_request(None)
self.validate_authorization_request(req)
return req
def create_authorization_response(self, request=None, grant_user=None):
return super()\
.create_authorization_response(request, grant_user)
def create_token_response(self, request=None):
return super().create_token_response(request)
def create_oauth1_request(self, request):
if request is None:
request = flask_req
if request.method in ('POST', 'PUT'):
body = request.form.to_dict(flat=True)
else:
body = None
return OAuth1Request(request.method, request.url, body, request.headers)
def handle_response(self, status_code, payload, headers):
return Response(
url_encode(payload),
status=status_code,
headers=headers
)

View File

@@ -0,0 +1,80 @@
from authlib.oauth1 import TemporaryCredential
def register_temporary_credential_hooks(
authorization_server, cache, key_prefix='temporary_credential:'):
"""Register temporary credential related hooks to authorization server.
:param authorization_server: AuthorizationServer instance
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
"""
def create_temporary_credential(token, client_id, redirect_uri):
key = key_prefix + token['oauth_token']
token['client_id'] = client_id
if redirect_uri:
token['oauth_callback'] = redirect_uri
cache.set(key, token, timeout=86400) # cache for one day
return TemporaryCredential(token)
def get_temporary_credential(oauth_token):
if not oauth_token:
return None
key = key_prefix + oauth_token
value = cache.get(key)
if value:
return TemporaryCredential(value)
def delete_temporary_credential(oauth_token):
if oauth_token:
key = key_prefix + oauth_token
cache.delete(key)
def create_authorization_verifier(credential, grant_user, verifier):
key = key_prefix + credential.get_oauth_token()
credential['oauth_verifier'] = verifier
credential['user_id'] = grant_user.get_user_id()
cache.set(key, credential, timeout=86400)
return credential
authorization_server.register_hook(
'create_temporary_credential', create_temporary_credential)
authorization_server.register_hook(
'get_temporary_credential', get_temporary_credential)
authorization_server.register_hook(
'delete_temporary_credential', delete_temporary_credential)
authorization_server.register_hook(
'create_authorization_verifier', create_authorization_verifier)
def create_exists_nonce_func(cache, key_prefix='nonce:', expires=86400):
"""Create an ``exists_nonce`` function that can be used in hooks and
resource protector.
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
:param expires: Expire time for nonce
"""
def exists_nonce(nonce, timestamp, client_id, oauth_token):
key = f'{key_prefix}{nonce}-{timestamp}-{client_id}'
if oauth_token:
key = f'{key}-{oauth_token}'
rv = cache.has(key)
cache.set(key, 1, timeout=expires)
return rv
return exists_nonce
def register_nonce_hooks(
authorization_server, cache, key_prefix='nonce:', expires=86400):
"""Register nonce related hooks to authorization server.
:param authorization_server: AuthorizationServer instance
:param cache: Cache instance
:param key_prefix: key prefix for temporary credential
:param expires: Expire time for nonce
"""
exists_nonce = create_exists_nonce_func(cache, key_prefix, expires)
authorization_server.register_hook('exists_nonce', exists_nonce)

View File

@@ -0,0 +1,113 @@
import functools
from flask import g, json, Response
from flask import request as _req
from werkzeug.local import LocalProxy
from authlib.consts import default_json_headers
from authlib.oauth1 import ResourceProtector as _ResourceProtector
from authlib.oauth1.errors import OAuth1Error
class ResourceProtector(_ResourceProtector):
"""A protecting method for resource servers. Initialize a resource
protector with the these method:
1. query_client
2. query_token,
3. exists_nonce
Usually, a ``query_client`` method would look like (if using SQLAlchemy)::
def query_client(client_id):
return Client.query.filter_by(client_id=client_id).first()
A ``query_token`` method accept two parameters, ``client_id`` and ``oauth_token``::
def query_token(client_id, oauth_token):
return Token.query.filter_by(client_id=client_id, oauth_token=oauth_token).first()
And for ``exists_nonce``, if using cache, we have a built-in hook to create this method::
from authlib.integrations.flask_oauth1 import create_exists_nonce_func
exists_nonce = create_exists_nonce_func(cache)
Then initialize the resource protector with those methods::
require_oauth = ResourceProtector(
app, query_client=query_client,
query_token=query_token, exists_nonce=exists_nonce,
)
"""
def __init__(self, app=None, query_client=None,
query_token=None, exists_nonce=None):
self.query_client = query_client
self.query_token = query_token
self._exists_nonce = exists_nonce
self.app = app
if app:
self.init_app(app)
def init_app(self, app, query_client=None, query_token=None,
exists_nonce=None):
if query_client is not None:
self.query_client = query_client
if query_token is not None:
self.query_token = query_token
if exists_nonce is not None:
self._exists_nonce = exists_nonce
methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS')
if methods and isinstance(methods, (list, tuple)):
self.SUPPORTED_SIGNATURE_METHODS = methods
self.app = app
def get_client_by_id(self, client_id):
return self.query_client(client_id)
def get_token_credential(self, request):
return self.query_token(request.client_id, request.token)
def exists_nonce(self, nonce, request):
if not self._exists_nonce:
raise RuntimeError('"exists_nonce" function is required.')
timestamp = request.timestamp
client_id = request.client_id
token = request.token
return self._exists_nonce(nonce, timestamp, client_id, token)
def acquire_credential(self):
req = self.validate_request(
_req.method,
_req.url,
_req.form.to_dict(flat=True),
_req.headers
)
g.authlib_server_oauth1_credential = req.credential
return req.credential
def __call__(self, scope=None):
def wrapper(f):
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
self.acquire_credential()
except OAuth1Error as error:
body = dict(error.get_body())
return Response(
json.dumps(body),
status=error.status_code,
headers=default_json_headers,
)
return f(*args, **kwargs)
return decorated
return wrapper
def _get_current_credential():
return g.get('authlib_server_oauth1_credential')
current_credential = LocalProxy(_get_current_credential)

View File

@@ -0,0 +1,12 @@
# flake8: noqa
from .authorization_server import AuthorizationServer
from .resource_protector import (
ResourceProtector,
current_token,
)
from .signals import (
client_authenticated,
token_authenticated,
token_revoked,
)

View File

@@ -0,0 +1,159 @@
from werkzeug.utils import import_string
from flask import Response, json
from flask import request as flask_req
from authlib.oauth2 import (
AuthorizationServer as _AuthorizationServer,
)
from authlib.oauth2.rfc6750 import BearerTokenGenerator
from authlib.common.security import generate_token
from .requests import FlaskOAuth2Request, FlaskJsonRequest
from .signals import client_authenticated, token_revoked
class AuthorizationServer(_AuthorizationServer):
"""Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
Initialize it with ``query_client``, ``save_token`` methods and Flask
app instance::
def query_client(client_id):
return Client.query.filter_by(client_id=client_id).first()
def save_token(token, request):
if request.user:
user_id = request.user.id
else:
user_id = None
client = request.client
tok = Token(
client_id=client.client_id,
user_id=user.id,
**token
)
db.session.add(tok)
db.session.commit()
server = AuthorizationServer(app, query_client, save_token)
# or initialize lazily
server = AuthorizationServer()
server.init_app(app, query_client, save_token)
"""
def __init__(self, app=None, query_client=None, save_token=None):
super().__init__()
self._query_client = query_client
self._save_token = save_token
self._error_uris = None
if app is not None:
self.init_app(app)
def init_app(self, app, query_client=None, save_token=None):
"""Initialize later with Flask app instance."""
if query_client is not None:
self._query_client = query_client
if save_token is not None:
self._save_token = save_token
self.register_token_generator('default', self.create_bearer_token_generator(app.config))
self.scopes_supported = app.config.get('OAUTH2_SCOPES_SUPPORTED')
self._error_uris = app.config.get('OAUTH2_ERROR_URIS')
def query_client(self, client_id):
return self._query_client(client_id)
def save_token(self, token, request):
return self._save_token(token, request)
def get_error_uri(self, request, error):
if self._error_uris:
uris = dict(self._error_uris)
return uris.get(error.error)
def create_oauth2_request(self, request):
return FlaskOAuth2Request(flask_req)
def create_json_request(self, request):
return FlaskJsonRequest(flask_req)
def handle_response(self, status_code, payload, headers):
if isinstance(payload, dict):
payload = json.dumps(payload)
return Response(payload, status=status_code, headers=headers)
def send_signal(self, name, *args, **kwargs):
if name == 'after_authenticate_client':
client_authenticated.send(self, *args, **kwargs)
elif name == 'after_revoke_token':
token_revoked.send(self, *args, **kwargs)
def create_bearer_token_generator(self, config):
"""Create a generator function for generating ``token`` value. This
method will create a Bearer Token generator with
:class:`authlib.oauth2.rfc6750.BearerToken`.
Configurable settings:
1. OAUTH2_ACCESS_TOKEN_GENERATOR: Boolean or import string, default is True.
2. OAUTH2_REFRESH_TOKEN_GENERATOR: Boolean or import string, default is False.
3. OAUTH2_TOKEN_EXPIRES_IN: Dict or import string, default is None.
By default, it will not generate ``refresh_token``, which can be turn on by
configure ``OAUTH2_REFRESH_TOKEN_GENERATOR``.
Here are some examples of the token generator::
OAUTH2_ACCESS_TOKEN_GENERATOR = 'your_project.generators.gen_token'
# and in module `your_project.generators`, you can define:
def gen_token(client, grant_type, user, scope):
# generate token according to these parameters
token = create_random_token()
return f'{client.id}-{user.id}-{token}'
Here is an example of ``OAUTH2_TOKEN_EXPIRES_IN``::
OAUTH2_TOKEN_EXPIRES_IN = {
'authorization_code': 864000,
'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600,
}
"""
conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True)
access_token_generator = create_token_generator(conf, 42)
conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False)
refresh_token_generator = create_token_generator(conf, 48)
expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN')
expires_generator = create_token_expires_in_generator(expires_conf)
return BearerTokenGenerator(
access_token_generator,
refresh_token_generator,
expires_generator
)
def create_token_expires_in_generator(expires_in_conf=None):
if isinstance(expires_in_conf, str):
return import_string(expires_in_conf)
data = {}
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
if isinstance(expires_in_conf, dict):
data.update(expires_in_conf)
def expires_in(client, grant_type):
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
return expires_in
def create_token_generator(token_generator_conf, length=42):
if callable(token_generator_conf):
return token_generator_conf
if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)
elif token_generator_conf is True:
def token_generator(*args, **kwargs):
return generate_token(length)
return token_generator

View File

@@ -0,0 +1,39 @@
import importlib.metadata
import werkzeug
from werkzeug.exceptions import HTTPException
_version = importlib.metadata.version('werkzeug').split('.')[0]
if _version in ('0', '1'):
class _HTTPException(HTTPException):
def __init__(self, code, body, headers, response=None):
super().__init__(None, response)
self.code = code
self.body = body
self.headers = headers
def get_body(self, environ=None):
return self.body
def get_headers(self, environ=None):
return self.headers
else:
class _HTTPException(HTTPException):
def __init__(self, code, body, headers, response=None):
super().__init__(None, response)
self.code = code
self.body = body
self.headers = headers
def get_body(self, environ=None, scope=None):
return self.body
def get_headers(self, environ=None, scope=None):
return self.headers
def raise_http_exception(status, body, headers):
raise _HTTPException(status, body, headers)

View File

@@ -0,0 +1,40 @@
from collections import defaultdict
from functools import cached_property
from flask.wrappers import Request
from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest
class FlaskOAuth2Request(OAuth2Request):
def __init__(self, request: Request):
super().__init__(request.method, request.url, None, request.headers)
self._request = request
@property
def args(self):
return self._request.args
@property
def form(self):
return self._request.form
@property
def data(self):
return self._request.values
@cached_property
def datalist(self):
values = defaultdict(list)
for k in self.data:
values[k].extend(self.data.getlist(k))
return values
class FlaskJsonRequest(JsonRequest):
def __init__(self, request: Request):
super().__init__(request.method, request.url, None, request.headers)
self._request = request
@property
def data(self):
return self._request.get_json()

View File

@@ -0,0 +1,114 @@
import functools
from contextlib import contextmanager
from flask import g, json
from flask import request as _req
from werkzeug.local import LocalProxy
from authlib.oauth2 import (
OAuth2Error,
ResourceProtector as _ResourceProtector
)
from authlib.oauth2.rfc6749 import (
MissingAuthorizationError,
)
from .requests import FlaskJsonRequest
from .signals import token_authenticated
from .errors import raise_http_exception
class ResourceProtector(_ResourceProtector):
"""A protecting method for resource servers. Creating a ``require_oauth``
decorator easily with ResourceProtector::
from authlib.integrations.flask_oauth2 import ResourceProtector
require_oauth = ResourceProtector()
# add bearer token validator
from authlib.oauth2.rfc6750 import BearerTokenValidator
from project.models import Token
class MyBearerTokenValidator(BearerTokenValidator):
def authenticate_token(self, token_string):
return Token.query.filter_by(access_token=token_string).first()
require_oauth.register_token_validator(MyBearerTokenValidator())
# protect resource with require_oauth
@app.route('/user')
@require_oauth(['profile'])
def user_profile():
user = User.get(current_token.user_id)
return jsonify(user.to_dict())
"""
def raise_error_response(self, error):
"""Raise HTTPException for OAuth2Error. Developers can re-implement
this method to customize the error response.
:param error: OAuth2Error
:raise: HTTPException
"""
status = error.status_code
body = json.dumps(dict(error.get_body()))
headers = error.get_headers()
raise_http_exception(status, body, headers)
def acquire_token(self, scopes=None, **kwargs):
"""A method to acquire current valid token with the given scope.
:param scopes: a list of scope values
:return: token object
"""
request = FlaskJsonRequest(_req)
# backward compatibility
kwargs['scopes'] = scopes
for claim in kwargs:
if isinstance(kwargs[claim], str):
kwargs[claim] = [kwargs[claim]]
token = self.validate_request(request=request, **kwargs)
token_authenticated.send(self, token=token)
g.authlib_server_oauth2_token = token
return token
@contextmanager
def acquire(self, scopes=None):
"""The with statement of ``require_oauth``. Instead of using a
decorator, you can use a with statement instead::
@app.route('/api/user')
def user_api():
with require_oauth.acquire('profile') as token:
user = User.get(token.user_id)
return jsonify(user.to_dict())
"""
try:
yield self.acquire_token(scopes)
except OAuth2Error as error:
self.raise_error_response(error)
def __call__(self, scopes=None, optional=False, **kwargs):
claims = kwargs
# backward compatibility
claims['scopes'] = scopes
def wrapper(f):
@functools.wraps(f)
def decorated(*args, **kwargs):
try:
self.acquire_token(**claims)
except MissingAuthorizationError as error:
if optional:
return f(*args, **kwargs)
self.raise_error_response(error)
except OAuth2Error as error:
self.raise_error_response(error)
return f(*args, **kwargs)
return decorated
return wrapper
def _get_current_token():
return g.get('authlib_server_oauth2_token')
current_token = LocalProxy(_get_current_token)

View File

@@ -0,0 +1,12 @@
from flask.signals import Namespace
_signal = Namespace()
#: signal when client is authenticated
client_authenticated = _signal.signal('client_authenticated')
#: signal when token is revoked
token_revoked = _signal.signal('token_revoked')
#: signal when token is authenticated
token_authenticated = _signal.signal('token_authenticated')

View File

@@ -0,0 +1,25 @@
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_RSA_SHA1,
SIGNATURE_PLAINTEXT,
SIGNATURE_TYPE_HEADER,
SIGNATURE_TYPE_QUERY,
SIGNATURE_TYPE_BODY,
)
from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client
from .oauth2_client import (
OAuth2Auth, OAuth2Client, OAuth2ClientAuth,
AsyncOAuth2Client,
)
from .assertion_client import AssertionClient, AsyncAssertionClient
from ..base_client import OAuthError
__all__ = [
'OAuthError',
'OAuth1Auth', 'AsyncOAuth1Client',
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
'OAuth2Auth', 'OAuth2ClientAuth', 'OAuth2Client', 'AsyncOAuth2Client',
'AssertionClient', 'AsyncAssertionClient',
]

View File

@@ -0,0 +1,81 @@
import httpx
from httpx import Response, USE_CLIENT_DEFAULT
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
from authlib.oauth2.rfc7523 import JWTBearerGrant
from .utils import extract_client_kwargs
from .oauth2_client import OAuth2Auth
from ..base_client import OAuthError
__all__ = ['AsyncAssertionClient']
class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self, session=None,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
)
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs) -> Response:
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
await self.refresh_token()
auth = self.token_auth
return await super().request(
method, url, auth=auth, **kwargs)
async def _refresh_token(self, data):
resp = await self.request(
'POST', self.token_endpoint, data=data, withhold_token=True)
return self.parse_response_token(resp)
class AssertionClient(_AssertionClient, httpx.Client):
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
ASSERTION_METHODS = {
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
}
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
claims=None, token_placement='header', scope=None, **kwargs):
client_kwargs = extract_client_kwargs(kwargs)
httpx.Client.__init__(self, **client_kwargs)
_AssertionClient.__init__(
self, session=self,
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
audience=audience, grant_type=grant_type, claims=claims,
token_placement=token_placement, scope=scope, **kwargs
)
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
"""Send request with auto refresh token feature."""
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token or self.token.is_expired():
self.refresh_token()
auth = self.token_auth
return super().request(
method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,102 @@
import typing
import httpx
from httpx import Auth, Request, Response
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_TYPE_HEADER,
)
from authlib.common.encoding import to_unicode
from authlib.oauth1 import ClientAuth
from authlib.oauth1.client import OAuth1Client as _OAuth1Client
from .utils import build_request, extract_client_kwargs
from ..base_client import OAuthError
class OAuth1Auth(Auth, ClientAuth):
"""Signs the httpx request using OAuth 1 (RFC5849)"""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield build_request(url=url, headers=headers, body=body, initial_request=request)
class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient):
auth_class = OAuth1Auth
def __init__(self, client_id, client_secret=None,
token=None, token_secret=None,
redirect_uri=None, rsa_key=None, verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False, **kwargs):
_client_kwargs = extract_client_kwargs(kwargs)
httpx.AsyncClient.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self, None,
client_id=client_id, client_secret=client_secret,
token=token, token_secret=token_secret,
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
signature_method=signature_method, signature_type=signature_type,
force_include_body=force_include_body, **kwargs)
async def fetch_access_token(self, url, verifier=None, **kwargs):
"""Method for fetching an access token from the token endpoint.
This is the final step in the OAuth 1 workflow. An access token is
obtained using all previously obtained credentials, including the
verifier from the authorization step.
:param url: Access Token endpoint.
:param verifier: A verifier string to prove authorization was granted.
:param kwargs: Extra parameters to include for fetching access token.
:return: A token dict.
"""
if verifier:
self.auth.verifier = verifier
if not self.auth.verifier:
self.handle_error('missing_verifier', 'Missing "verifier" value')
token = await self._fetch_token(url, **kwargs)
self.auth.verifier = None
return token
async def _fetch_token(self, url, **kwargs):
resp = await self.post(url, **kwargs)
text = await resp.aread()
token = self.parse_response_token(resp.status_code, to_unicode(text))
self.token = token
return token
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
class OAuth1Client(_OAuth1Client, httpx.Client):
auth_class = OAuth1Auth
def __init__(self, client_id, client_secret=None,
token=None, token_secret=None,
redirect_uri=None, rsa_key=None, verifier=None,
signature_method=SIGNATURE_HMAC_SHA1,
signature_type=SIGNATURE_TYPE_HEADER,
force_include_body=False, **kwargs):
_client_kwargs = extract_client_kwargs(kwargs)
httpx.Client.__init__(self, **_client_kwargs)
_OAuth1Client.__init__(
self, self,
client_id=client_id, client_secret=client_secret,
token=token, token_secret=token_secret,
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
signature_method=signature_method, signature_type=signature_type,
force_include_body=force_include_body, **kwargs)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)

View File

@@ -0,0 +1,220 @@
import typing
from contextlib import asynccontextmanager
import httpx
from httpx import Auth, Request, Response, USE_CLIENT_DEFAULT
from anyio import Lock # Import after httpx so import errors refer to httpx
from authlib.common.urls import url_decode
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
from authlib.oauth2.auth import ClientAuth, TokenAuth
from .utils import HTTPX_CLIENT_KWARGS, build_request
from ..base_client import (
OAuthError,
InvalidTokenError,
MissingTokenError,
UnsupportedTokenTypeError,
)
__all__ = [
'OAuth2Auth', 'OAuth2ClientAuth',
'AsyncOAuth2Client', 'OAuth2Client',
]
class OAuth2Auth(Auth, TokenAuth):
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
try:
url, headers, body = self.prepare(
str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield build_request(url=url, headers=headers, body=body, initial_request=request)
except KeyError as error:
description = f'Unsupported token_type: {str(error)}'
raise UnsupportedTokenTypeError(description=description)
class OAuth2ClientAuth(Auth, ClientAuth):
requires_request_body = True
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
url, headers, body = self.prepare(
request.method, str(request.url), request.headers, request.content)
headers['Content-Length'] = str(len(body))
yield build_request(url=url, headers=headers, body=body, initial_request=request)
class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(self, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, leeway=60, **kwargs):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
httpx.AsyncClient.__init__(self, **client_kwargs)
# We use a Lock to synchronize coroutines to prevent
# multiple concurrent attempts to refresh the same token
self._token_refresh_lock = Lock()
_OAuth2Client.__init__(
self, session=None,
client_id=client_id, client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, leeway=leeway, **kwargs
)
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
return await super().request(
method, url, auth=auth, **kwargs)
@asynccontextmanager
async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
await self.ensure_active_token(self.token)
auth = self.token_auth
async with super().stream(
method, url, auth=auth, **kwargs) as resp:
yield resp
async def ensure_active_token(self, token):
async with self._token_refresh_lock:
if self.token.is_expired(leeway=self.leeway):
refresh_token = token.get('refresh_token')
url = self.metadata.get('token_endpoint')
if refresh_token and url:
await self.refresh_token(url, refresh_token=refresh_token)
elif self.metadata.get('grant_type') == 'client_credentials':
access_token = token['access_token']
new_token = await self.fetch_token(url, grant_type='client_credentials')
if self.update_token:
await self.update_token(new_token, access_token=access_token)
else:
raise InvalidTokenError()
async def _fetch_token(self, url, body='', headers=None, auth=USE_CLIENT_DEFAULT,
method='POST', **kwargs):
if method.upper() == 'POST':
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers,
auth=auth, **kwargs)
else:
if '?' in url:
url = '&'.join([url, body])
else:
url = '?'.join([url, body])
resp = await self.get(url, headers=headers, auth=auth, **kwargs)
for hook in self.compliance_hook['access_token_response']:
resp = hook(resp)
return self.parse_response_token(resp)
async def _refresh_token(self, url, refresh_token=None, body='',
headers=None, auth=USE_CLIENT_DEFAULT, **kwargs):
resp = await self.post(
url, data=dict(url_decode(body)), headers=headers,
auth=auth, **kwargs)
for hook in self.compliance_hook['refresh_token_response']:
resp = hook(resp)
token = self.parse_response_token(resp)
if 'refresh_token' not in token:
self.token['refresh_token'] = refresh_token
if self.update_token:
await self.update_token(self.token, refresh_token=refresh_token)
return self.token
def _http_post(self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs):
return self.post(
url, data=dict(url_decode(body)),
headers=headers, auth=auth, **kwargs)
class OAuth2Client(_OAuth2Client, httpx.Client):
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
client_auth_class = OAuth2ClientAuth
token_auth_class = OAuth2Auth
oauth_error_class = OAuthError
def __init__(self, client_id=None, client_secret=None,
token_endpoint_auth_method=None,
revocation_endpoint_auth_method=None,
scope=None, redirect_uri=None,
token=None, token_placement='header',
update_token=None, **kwargs):
# extract httpx.Client kwargs
client_kwargs = self._extract_session_request_params(kwargs)
httpx.Client.__init__(self, **client_kwargs)
_OAuth2Client.__init__(
self, session=self,
client_id=client_id, client_secret=client_secret,
token_endpoint_auth_method=token_endpoint_auth_method,
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
scope=scope, redirect_uri=redirect_uri,
token=token, token_placement=token_placement,
update_token=update_token, **kwargs
)
@staticmethod
def handle_error(error_type, error_description):
raise OAuthError(error_type, error_description)
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().request(
method, url, auth=auth, **kwargs)
def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
if not withhold_token and auth is USE_CLIENT_DEFAULT:
if not self.token:
raise MissingTokenError()
if not self.ensure_active_token(self.token):
raise InvalidTokenError()
auth = self.token_auth
return super().stream(
method, url, auth=auth, **kwargs)

View File

@@ -0,0 +1,30 @@
from httpx import Request
HTTPX_CLIENT_KWARGS = [
'headers', 'cookies', 'verify', 'cert', 'http1', 'http2',
'proxies', 'timeout', 'follow_redirects', 'limits', 'max_redirects',
'event_hooks', 'base_url', 'transport', 'app', 'trust_env',
]
def extract_client_kwargs(kwargs):
client_kwargs = {}
for k in HTTPX_CLIENT_KWARGS:
if k in kwargs:
client_kwargs[k] = kwargs.pop(k)
return client_kwargs
def build_request(url, headers, body, initial_request: Request) -> Request:
"""Make sure that all the data from initial request is passed to the updated object"""
updated_request = Request(
method=initial_request.method,
url=url,
headers=headers,
content=body
)
if hasattr(initial_request, 'extensions'):
updated_request.extensions = initial_request.extensions
return updated_request

View File

@@ -0,0 +1,22 @@
from .oauth1_session import OAuth1Session, OAuth1Auth
from .oauth2_session import OAuth2Session, OAuth2Auth
from .assertion_session import AssertionSession
from ..base_client import OAuthError
from authlib.oauth1 import (
SIGNATURE_HMAC_SHA1,
SIGNATURE_RSA_SHA1,
SIGNATURE_PLAINTEXT,
SIGNATURE_TYPE_HEADER,
SIGNATURE_TYPE_QUERY,
SIGNATURE_TYPE_BODY,
)
__all__ = [
'OAuthError',
'OAuth1Session', 'OAuth1Auth',
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
'OAuth2Session', 'OAuth2Auth',
'AssertionSession',
]

Some files were not shown because too many files have changed in this diff Show More