venv added, updated
This commit is contained in:
17
myenv/lib/python3.12/site-packages/authlib/__init__.py
Normal file
17
myenv/lib/python3.12/site-packages/authlib/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
authlib
|
||||
~~~~~~~
|
||||
|
||||
The ultimate Python library in building OAuth 1.0, OAuth 2.0 and OpenID
|
||||
Connect clients and providers. It covers from low level specification
|
||||
implementation to high level framework integrations.
|
||||
|
||||
:copyright: (c) 2017 by Hsiaoming Yang.
|
||||
:license: BSD, see LICENSE for more details.
|
||||
"""
|
||||
from .consts import version, homepage, author
|
||||
|
||||
__version__ = version
|
||||
__homepage__ = homepage
|
||||
__author__ = author
|
||||
__license__ = 'BSD-3-Clause'
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,66 @@
|
||||
import json
|
||||
import base64
|
||||
import struct
|
||||
|
||||
|
||||
def to_bytes(x, charset='utf-8', errors='strict'):
|
||||
if x is None:
|
||||
return None
|
||||
if isinstance(x, bytes):
|
||||
return x
|
||||
if isinstance(x, str):
|
||||
return x.encode(charset, errors)
|
||||
if isinstance(x, (int, float)):
|
||||
return str(x).encode(charset, errors)
|
||||
return bytes(x)
|
||||
|
||||
|
||||
def to_unicode(x, charset='utf-8', errors='strict'):
|
||||
if x is None or isinstance(x, str):
|
||||
return x
|
||||
if isinstance(x, bytes):
|
||||
return x.decode(charset, errors)
|
||||
return str(x)
|
||||
|
||||
|
||||
def to_native(x, encoding='ascii'):
|
||||
if isinstance(x, str):
|
||||
return x
|
||||
return x.decode(encoding)
|
||||
|
||||
|
||||
def json_loads(s):
|
||||
return json.loads(s)
|
||||
|
||||
|
||||
def json_dumps(data, ensure_ascii=False):
|
||||
return json.dumps(data, ensure_ascii=ensure_ascii, separators=(',', ':'))
|
||||
|
||||
|
||||
def urlsafe_b64decode(s):
|
||||
s += b'=' * (-len(s) % 4)
|
||||
return base64.urlsafe_b64decode(s)
|
||||
|
||||
|
||||
def urlsafe_b64encode(s):
|
||||
return base64.urlsafe_b64encode(s).rstrip(b'=')
|
||||
|
||||
|
||||
def base64_to_int(s):
|
||||
data = urlsafe_b64decode(to_bytes(s, charset='ascii'))
|
||||
buf = struct.unpack('%sB' % len(data), data)
|
||||
return int(''.join(["%02x" % byte for byte in buf]), 16)
|
||||
|
||||
|
||||
def int_to_base64(num):
|
||||
if num < 0:
|
||||
raise ValueError('Must be a positive integer')
|
||||
|
||||
s = num.to_bytes((num.bit_length() + 7) // 8, 'big', signed=False)
|
||||
return to_unicode(urlsafe_b64encode(s))
|
||||
|
||||
|
||||
def json_b64encode(text):
|
||||
if isinstance(text, dict):
|
||||
text = json_dumps(text)
|
||||
return urlsafe_b64encode(to_bytes(text))
|
||||
63
myenv/lib/python3.12/site-packages/authlib/common/errors.py
Normal file
63
myenv/lib/python3.12/site-packages/authlib/common/errors.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from authlib.consts import default_json_headers
|
||||
|
||||
|
||||
class AuthlibBaseError(Exception):
|
||||
"""Base Exception for all errors in Authlib."""
|
||||
|
||||
#: short-string error code
|
||||
error = None
|
||||
#: long-string to describe this error
|
||||
description = ''
|
||||
#: web page that describes this error
|
||||
uri = None
|
||||
|
||||
def __init__(self, error=None, description=None, uri=None):
|
||||
if error is not None:
|
||||
self.error = error
|
||||
if description is not None:
|
||||
self.description = description
|
||||
if uri is not None:
|
||||
self.uri = uri
|
||||
|
||||
message = f'{self.error}: {self.description}'
|
||||
super().__init__(message)
|
||||
|
||||
def __repr__(self):
|
||||
return f'<{self.__class__.__name__} "{self.error}">'
|
||||
|
||||
|
||||
class AuthlibHTTPError(AuthlibBaseError):
|
||||
#: HTTP status code
|
||||
status_code = 400
|
||||
|
||||
def __init__(self, error=None, description=None, uri=None,
|
||||
status_code=None):
|
||||
super().__init__(error, description, uri)
|
||||
if status_code is not None:
|
||||
self.status_code = status_code
|
||||
|
||||
def get_error_description(self):
|
||||
return self.description
|
||||
|
||||
def get_body(self):
|
||||
error = [('error', self.error)]
|
||||
|
||||
if self.description:
|
||||
error.append(('error_description', self.description))
|
||||
|
||||
if self.uri:
|
||||
error.append(('error_uri', self.uri))
|
||||
return error
|
||||
|
||||
def get_headers(self):
|
||||
return default_json_headers[:]
|
||||
|
||||
def __call__(self, uri=None):
|
||||
self.uri = uri
|
||||
body = dict(self.get_body())
|
||||
headers = self.get_headers()
|
||||
return self.status_code, body, headers
|
||||
|
||||
|
||||
class ContinueIteration(AuthlibBaseError):
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
import os
|
||||
import string
|
||||
import random
|
||||
|
||||
UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def generate_token(length=30, chars=UNICODE_ASCII_CHARACTER_SET):
|
||||
rand = random.SystemRandom()
|
||||
return ''.join(rand.choice(chars) for _ in range(length))
|
||||
|
||||
|
||||
def is_secure_transport(uri):
|
||||
"""Check if the uri is over ssl."""
|
||||
if os.getenv('AUTHLIB_INSECURE_TRANSPORT'):
|
||||
return True
|
||||
|
||||
uri = uri.lower()
|
||||
return uri.startswith(('https://', 'http://localhost:'))
|
||||
146
myenv/lib/python3.12/site-packages/authlib/common/urls.py
Normal file
146
myenv/lib/python3.12/site-packages/authlib/common/urls.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
authlib.util.urls
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Wrapper functions for URL encoding and decoding.
|
||||
"""
|
||||
|
||||
import re
|
||||
from urllib.parse import quote as _quote
|
||||
from urllib.parse import unquote as _unquote
|
||||
from urllib.parse import urlencode as _urlencode
|
||||
import urllib.parse as urlparse
|
||||
|
||||
from .encoding import to_unicode, to_bytes
|
||||
|
||||
always_safe = (
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
|
||||
'abcdefghijklmnopqrstuvwxyz'
|
||||
'0123456789_.-'
|
||||
)
|
||||
urlencoded = set(always_safe) | set('=&;:%+~,*@!()/?')
|
||||
INVALID_HEX_PATTERN = re.compile(r'%[^0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]')
|
||||
|
||||
|
||||
def url_encode(params):
|
||||
encoded = []
|
||||
for k, v in params:
|
||||
encoded.append((to_bytes(k), to_bytes(v)))
|
||||
return to_unicode(_urlencode(encoded))
|
||||
|
||||
|
||||
def url_decode(query):
|
||||
"""Decode a query string in x-www-form-urlencoded format into a sequence
|
||||
of two-element tuples.
|
||||
|
||||
Unlike urlparse.parse_qsl(..., strict_parsing=True) urldecode will enforce
|
||||
correct formatting of the query string by validation. If validation fails
|
||||
a ValueError will be raised. urllib.parse_qsl will only raise errors if
|
||||
any of name-value pairs omits the equals sign.
|
||||
"""
|
||||
# Check if query contains invalid characters
|
||||
if query and not set(query) <= urlencoded:
|
||||
error = ("Error trying to decode a non urlencoded string. "
|
||||
"Found invalid characters: %s "
|
||||
"in the string: '%s'. "
|
||||
"Please ensure the request/response body is "
|
||||
"x-www-form-urlencoded.")
|
||||
raise ValueError(error % (set(query) - urlencoded, query))
|
||||
|
||||
# Check for correctly hex encoded values using a regular expression
|
||||
# All encoded values begin with % followed by two hex characters
|
||||
# correct = %00, %A0, %0A, %FF
|
||||
# invalid = %G0, %5H, %PO
|
||||
if INVALID_HEX_PATTERN.search(query):
|
||||
raise ValueError('Invalid hex encoding in query string.')
|
||||
|
||||
# We encode to utf-8 prior to parsing because parse_qsl behaves
|
||||
# differently on unicode input in python 2 and 3.
|
||||
# Python 2.7
|
||||
# >>> urlparse.parse_qsl(u'%E5%95%A6%E5%95%A6')
|
||||
# u'\xe5\x95\xa6\xe5\x95\xa6'
|
||||
# Python 2.7, non unicode input gives the same
|
||||
# >>> urlparse.parse_qsl('%E5%95%A6%E5%95%A6')
|
||||
# '\xe5\x95\xa6\xe5\x95\xa6'
|
||||
# but now we can decode it to unicode
|
||||
# >>> urlparse.parse_qsl('%E5%95%A6%E5%95%A6').decode('utf-8')
|
||||
# u'\u5566\u5566'
|
||||
# Python 3.3 however
|
||||
# >>> urllib.parse.parse_qsl(u'%E5%95%A6%E5%95%A6')
|
||||
# u'\u5566\u5566'
|
||||
|
||||
# We want to allow queries such as "c2" whereas urlparse.parse_qsl
|
||||
# with the strict_parsing flag will not.
|
||||
params = urlparse.parse_qsl(query, keep_blank_values=True)
|
||||
|
||||
# unicode all the things
|
||||
decoded = []
|
||||
for k, v in params:
|
||||
decoded.append((to_unicode(k), to_unicode(v)))
|
||||
return decoded
|
||||
|
||||
|
||||
def add_params_to_qs(query, params):
|
||||
"""Extend a query with a list of two-tuples."""
|
||||
if isinstance(params, dict):
|
||||
params = params.items()
|
||||
|
||||
qs = urlparse.parse_qsl(query, keep_blank_values=True)
|
||||
qs.extend(params)
|
||||
return url_encode(qs)
|
||||
|
||||
|
||||
def add_params_to_uri(uri, params, fragment=False):
|
||||
"""Add a list of two-tuples to the uri query components."""
|
||||
sch, net, path, par, query, fra = urlparse.urlparse(uri)
|
||||
if fragment:
|
||||
fra = add_params_to_qs(fra, params)
|
||||
else:
|
||||
query = add_params_to_qs(query, params)
|
||||
return urlparse.urlunparse((sch, net, path, par, query, fra))
|
||||
|
||||
|
||||
def quote(s, safe=b'/'):
|
||||
return to_unicode(_quote(to_bytes(s), safe))
|
||||
|
||||
|
||||
def unquote(s):
|
||||
return to_unicode(_unquote(s))
|
||||
|
||||
|
||||
def quote_url(s):
|
||||
return quote(s, b'~@#$&()*!+=:;,.?/\'')
|
||||
|
||||
|
||||
def extract_params(raw):
|
||||
"""Extract parameters and return them as a list of 2-tuples.
|
||||
|
||||
Will successfully extract parameters from urlencoded query strings,
|
||||
dicts, or lists of 2-tuples. Empty strings/dicts/lists will return an
|
||||
empty list of parameters. Any other input will result in a return
|
||||
value of None.
|
||||
"""
|
||||
if isinstance(raw, (list, tuple)):
|
||||
try:
|
||||
raw = dict(raw)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if isinstance(raw, dict):
|
||||
params = []
|
||||
for k, v in raw.items():
|
||||
params.append((to_unicode(k), to_unicode(v)))
|
||||
return params
|
||||
|
||||
if not raw:
|
||||
return None
|
||||
|
||||
try:
|
||||
return url_decode(raw)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def is_valid_url(url):
|
||||
parsed = urlparse.urlparse(url)
|
||||
return parsed.scheme and parsed.hostname
|
||||
11
myenv/lib/python3.12/site-packages/authlib/consts.py
Normal file
11
myenv/lib/python3.12/site-packages/authlib/consts.py
Normal file
@@ -0,0 +1,11 @@
|
||||
name = 'Authlib'
|
||||
version = '1.3.2'
|
||||
author = 'Hsiaoming Yang <me@lepture.com>'
|
||||
homepage = 'https://authlib.org/'
|
||||
default_user_agent = f'{name}/{version} (+{homepage})'
|
||||
|
||||
default_json_headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Cache-Control', 'no-store'),
|
||||
('Pragma', 'no-cache'),
|
||||
]
|
||||
16
myenv/lib/python3.12/site-packages/authlib/deprecate.py
Normal file
16
myenv/lib/python3.12/site-packages/authlib/deprecate.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import warnings
|
||||
|
||||
|
||||
class AuthlibDeprecationWarning(DeprecationWarning):
|
||||
pass
|
||||
|
||||
|
||||
warnings.simplefilter('always', AuthlibDeprecationWarning)
|
||||
|
||||
|
||||
def deprecate(message, version=None, link_uid=None, link_file=None):
|
||||
if version:
|
||||
message += f'\nIt will be compatible before version {version}.'
|
||||
if link_uid and link_file:
|
||||
message += f'\nRead more <https://git.io/{link_uid}#file-{link_file}-md>'
|
||||
warnings.warn(AuthlibDeprecationWarning(message), stacklevel=2)
|
||||
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
from .registry import BaseOAuth
|
||||
from .sync_app import BaseApp, OAuth1Mixin, OAuth2Mixin
|
||||
from .sync_openid import OpenIDMixin
|
||||
from .framework_integration import FrameworkIntegration
|
||||
from .errors import (
|
||||
OAuthError, MissingRequestTokenError, MissingTokenError,
|
||||
TokenExpiredError, InvalidTokenError, UnsupportedTokenTypeError,
|
||||
MismatchingStateError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BaseOAuth',
|
||||
'BaseApp', 'OAuth1Mixin', 'OAuth2Mixin',
|
||||
'OpenIDMixin', 'FrameworkIntegration',
|
||||
'OAuthError', 'MissingRequestTokenError', 'MissingTokenError',
|
||||
'TokenExpiredError', 'InvalidTokenError', 'UnsupportedTokenTypeError',
|
||||
'MismatchingStateError',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,144 @@
|
||||
import time
|
||||
import logging
|
||||
from authlib.common.urls import urlparse
|
||||
from .errors import (
|
||||
MissingRequestTokenError,
|
||||
MissingTokenError,
|
||||
)
|
||||
from .sync_app import OAuth1Base, OAuth2Base
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['AsyncOAuth1Mixin', 'AsyncOAuth2Mixin']
|
||||
|
||||
|
||||
class AsyncOAuth1Mixin(OAuth1Base):
|
||||
async def request(self, method, url, token=None, **kwargs):
|
||||
async with self._get_oauth_client() as session:
|
||||
return await _http_request(self, session, method, url, token, kwargs)
|
||||
|
||||
async def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
if not self.authorize_url:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
async with self._get_oauth_client() as client:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = {}
|
||||
if self.request_token_params:
|
||||
params.update(self.request_token_params)
|
||||
request_token = await client.fetch_request_token(self.request_token_url, **params)
|
||||
log.debug(f'Fetch request token: {request_token!r}')
|
||||
url = client.create_authorization_url(self.authorize_url, **kwargs)
|
||||
state = request_token['oauth_token']
|
||||
return {'url': url, 'request_token': request_token, 'state': state}
|
||||
|
||||
async def fetch_access_token(self, request_token=None, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:param request_token: A previous request token for OAuth 1.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
async with self._get_oauth_client() as client:
|
||||
if request_token is None:
|
||||
raise MissingRequestTokenError()
|
||||
# merge request token with verifier
|
||||
token = {}
|
||||
token.update(request_token)
|
||||
token.update(kwargs)
|
||||
client.token = token
|
||||
params = self.access_token_params or {}
|
||||
token = await client.fetch_access_token(self.access_token_url, **params)
|
||||
return token
|
||||
|
||||
|
||||
class AsyncOAuth2Mixin(OAuth2Base):
|
||||
async def _on_update_token(self, token, refresh_token=None, access_token=None):
|
||||
if self._update_token:
|
||||
await self._update_token(
|
||||
token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
async def load_server_metadata(self):
|
||||
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
|
||||
async with self.client_cls(**self.client_kwargs) as client:
|
||||
resp = await client.request('GET', self._server_metadata_url, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
metadata = resp.json()
|
||||
metadata['_loaded_at'] = time.time()
|
||||
self.server_metadata.update(metadata)
|
||||
return self.server_metadata
|
||||
|
||||
async def request(self, method, url, token=None, **kwargs):
|
||||
metadata = await self.load_server_metadata()
|
||||
async with self._get_oauth_client(**metadata) as session:
|
||||
return await _http_request(self, session, method, url, token, kwargs)
|
||||
|
||||
async def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
metadata = await self.load_server_metadata()
|
||||
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
|
||||
if not authorization_endpoint:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
async with self._get_oauth_client(**metadata) as client:
|
||||
client.redirect_uri = redirect_uri
|
||||
return self._create_oauth2_authorization_url(
|
||||
client, authorization_endpoint, **kwargs)
|
||||
|
||||
async def fetch_access_token(self, redirect_uri=None, **kwargs):
|
||||
"""Fetch access token in the final step.
|
||||
|
||||
:param redirect_uri: Callback or Redirect URI that is used in
|
||||
previous :meth:`authorize_redirect`.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
metadata = await self.load_server_metadata()
|
||||
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
|
||||
async with self._get_oauth_client(**metadata) as client:
|
||||
if redirect_uri is not None:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = {}
|
||||
if self.access_token_params:
|
||||
params.update(self.access_token_params)
|
||||
params.update(kwargs)
|
||||
token = await client.fetch_token(token_endpoint, **params)
|
||||
return token
|
||||
|
||||
|
||||
async def _http_request(ctx, session, method, url, token, kwargs):
|
||||
request = kwargs.pop('request', None)
|
||||
withhold_token = kwargs.get('withhold_token')
|
||||
if ctx.api_base_url and not url.startswith(('https://', 'http://')):
|
||||
url = urlparse.urljoin(ctx.api_base_url, url)
|
||||
|
||||
if withhold_token:
|
||||
return await session.request(method, url, **kwargs)
|
||||
|
||||
if token is None and ctx._fetch_token and request:
|
||||
token = await ctx._fetch_token(request)
|
||||
if token is None:
|
||||
raise MissingTokenError()
|
||||
|
||||
session.token = token
|
||||
return await session.request(method, url, **kwargs)
|
||||
@@ -0,0 +1,79 @@
|
||||
from authlib.jose import JsonWebToken, JsonWebKey
|
||||
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
|
||||
|
||||
__all__ = ['AsyncOpenIDMixin']
|
||||
|
||||
|
||||
class AsyncOpenIDMixin:
|
||||
async def fetch_jwk_set(self, force=False):
|
||||
metadata = await self.load_server_metadata()
|
||||
jwk_set = metadata.get('jwks')
|
||||
if jwk_set and not force:
|
||||
return jwk_set
|
||||
|
||||
uri = metadata.get('jwks_uri')
|
||||
if not uri:
|
||||
raise RuntimeError('Missing "jwks_uri" in metadata')
|
||||
|
||||
async with self.client_cls(**self.client_kwargs) as client:
|
||||
resp = await client.request('GET', uri, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
jwk_set = resp.json()
|
||||
|
||||
self.server_metadata['jwks'] = jwk_set
|
||||
return jwk_set
|
||||
|
||||
async def userinfo(self, **kwargs):
|
||||
"""Fetch user info from ``userinfo_endpoint``."""
|
||||
metadata = await self.load_server_metadata()
|
||||
resp = await self.get(metadata['userinfo_endpoint'], **kwargs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return UserInfo(data)
|
||||
|
||||
async def parse_id_token(self, token, nonce, claims_options=None):
|
||||
"""Return an instance of UserInfo from token's ``id_token``."""
|
||||
claims_params = dict(
|
||||
nonce=nonce,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
if 'access_token' in token:
|
||||
claims_params['access_token'] = token['access_token']
|
||||
claims_cls = CodeIDToken
|
||||
else:
|
||||
claims_cls = ImplicitIDToken
|
||||
|
||||
metadata = await self.load_server_metadata()
|
||||
if claims_options is None and 'issuer' in metadata:
|
||||
claims_options = {'iss': {'values': [metadata['issuer']]}}
|
||||
|
||||
alg_values = metadata.get('id_token_signing_alg_values_supported')
|
||||
if not alg_values:
|
||||
alg_values = ['RS256']
|
||||
|
||||
jwt = JsonWebToken(alg_values)
|
||||
|
||||
jwk_set = await self.fetch_jwk_set()
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token['id_token'],
|
||||
key=JsonWebKey.import_key_set(jwk_set),
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
except ValueError:
|
||||
jwk_set = await self.fetch_jwk_set(force=True)
|
||||
claims = jwt.decode(
|
||||
token['id_token'],
|
||||
key=JsonWebKey.import_key_set(jwk_set),
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
||||
# https://github.com/lepture/authlib/issues/259
|
||||
if claims.get('nonce_supported') is False:
|
||||
claims.params['nonce'] = None
|
||||
claims.validate(leeway=120)
|
||||
return UserInfo(claims)
|
||||
@@ -0,0 +1,30 @@
|
||||
from authlib.common.errors import AuthlibBaseError
|
||||
|
||||
|
||||
class OAuthError(AuthlibBaseError):
|
||||
error = 'oauth_error'
|
||||
|
||||
|
||||
class MissingRequestTokenError(OAuthError):
|
||||
error = 'missing_request_token'
|
||||
|
||||
|
||||
class MissingTokenError(OAuthError):
|
||||
error = 'missing_token'
|
||||
|
||||
|
||||
class TokenExpiredError(OAuthError):
|
||||
error = 'token_expired'
|
||||
|
||||
|
||||
class InvalidTokenError(OAuthError):
|
||||
error = 'token_invalid'
|
||||
|
||||
|
||||
class UnsupportedTokenTypeError(OAuthError):
|
||||
error = 'unsupported_token_type'
|
||||
|
||||
|
||||
class MismatchingStateError(OAuthError):
|
||||
error = 'mismatching_state'
|
||||
description = 'CSRF Warning! State not equal in request and response.'
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
class FrameworkIntegration:
|
||||
expires_in = 3600
|
||||
|
||||
def __init__(self, name, cache=None):
|
||||
self.name = name
|
||||
self.cache = cache
|
||||
|
||||
def _get_cache_data(self, key):
|
||||
value = self.cache.get(key)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _clear_session_state(self, session):
|
||||
now = time.time()
|
||||
for key in dict(session):
|
||||
if '_authlib_' in key:
|
||||
# TODO: remove in future
|
||||
session.pop(key)
|
||||
elif key.startswith('_state_'):
|
||||
value = session[key]
|
||||
exp = value.get('exp')
|
||||
if not exp or exp < now:
|
||||
session.pop(key)
|
||||
|
||||
def get_state_data(self, session, state):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
value = self._get_cache_data(key)
|
||||
else:
|
||||
value = session.get(key)
|
||||
if value:
|
||||
return value.get('data')
|
||||
return None
|
||||
|
||||
def set_state_data(self, session, state, data):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
self.cache.set(key, json.dumps({'data': data}), self.expires_in)
|
||||
else:
|
||||
now = time.time()
|
||||
session[key] = {'data': data, 'exp': now + self.expires_in}
|
||||
|
||||
def clear_state_data(self, session, state):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
self.cache.delete(key)
|
||||
else:
|
||||
session.pop(key, None)
|
||||
self._clear_session_state(session)
|
||||
|
||||
def update_token(self, token, refresh_token=None, access_token=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def load_config(oauth, name, params):
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,131 @@
|
||||
import functools
|
||||
from .framework_integration import FrameworkIntegration
|
||||
|
||||
__all__ = ['BaseOAuth']
|
||||
|
||||
|
||||
OAUTH_CLIENT_PARAMS = (
|
||||
'client_id', 'client_secret',
|
||||
'request_token_url', 'request_token_params',
|
||||
'access_token_url', 'access_token_params',
|
||||
'refresh_token_url', 'refresh_token_params',
|
||||
'authorize_url', 'authorize_params',
|
||||
'api_base_url', 'client_kwargs',
|
||||
'server_metadata_url',
|
||||
)
|
||||
|
||||
|
||||
class BaseOAuth:
|
||||
"""Registry for oauth clients.
|
||||
|
||||
Create an instance for registry::
|
||||
|
||||
oauth = OAuth()
|
||||
"""
|
||||
oauth1_client_cls = None
|
||||
oauth2_client_cls = None
|
||||
framework_integration_cls = FrameworkIntegration
|
||||
|
||||
def __init__(self, cache=None, fetch_token=None, update_token=None):
|
||||
self._registry = {}
|
||||
self._clients = {}
|
||||
self.cache = cache
|
||||
self.fetch_token = fetch_token
|
||||
self.update_token = update_token
|
||||
|
||||
def create_client(self, name):
|
||||
"""Create or get the given named OAuth client. For instance, the
|
||||
OAuth registry has ``.register`` a twitter client, developers may
|
||||
access the client with::
|
||||
|
||||
client = oauth.create_client('twitter')
|
||||
|
||||
:param: name: Name of the remote application
|
||||
:return: OAuth remote app
|
||||
"""
|
||||
if name in self._clients:
|
||||
return self._clients[name]
|
||||
|
||||
if name not in self._registry:
|
||||
return None
|
||||
|
||||
overwrite, config = self._registry[name]
|
||||
client_cls = config.pop('client_cls', None)
|
||||
|
||||
if client_cls and client_cls.OAUTH_APP_CONFIG:
|
||||
kwargs = client_cls.OAUTH_APP_CONFIG
|
||||
kwargs.update(config)
|
||||
else:
|
||||
kwargs = config
|
||||
|
||||
kwargs = self.generate_client_kwargs(name, overwrite, **kwargs)
|
||||
framework = self.framework_integration_cls(name, self.cache)
|
||||
if client_cls:
|
||||
client = client_cls(framework, name, **kwargs)
|
||||
elif kwargs.get('request_token_url'):
|
||||
client = self.oauth1_client_cls(framework, name, **kwargs)
|
||||
else:
|
||||
client = self.oauth2_client_cls(framework, name, **kwargs)
|
||||
|
||||
self._clients[name] = client
|
||||
return client
|
||||
|
||||
def register(self, name, overwrite=False, **kwargs):
|
||||
"""Registers a new remote application.
|
||||
|
||||
:param name: Name of the remote application.
|
||||
:param overwrite: Overwrite existing config with framework settings.
|
||||
:param kwargs: Parameters for :class:`RemoteApp`.
|
||||
|
||||
Find parameters for the given remote app class. When a remote app is
|
||||
registered, it can be accessed with *named* attribute::
|
||||
|
||||
oauth.register('twitter', client_id='', ...)
|
||||
oauth.twitter.get('timeline')
|
||||
"""
|
||||
self._registry[name] = (overwrite, kwargs)
|
||||
return self.create_client(name)
|
||||
|
||||
def generate_client_kwargs(self, name, overwrite, **kwargs):
|
||||
fetch_token = kwargs.pop('fetch_token', None)
|
||||
update_token = kwargs.pop('update_token', None)
|
||||
|
||||
config = self.load_config(name, OAUTH_CLIENT_PARAMS)
|
||||
if config:
|
||||
kwargs = _config_client(config, kwargs, overwrite)
|
||||
|
||||
if not fetch_token and self.fetch_token:
|
||||
fetch_token = functools.partial(self.fetch_token, name)
|
||||
|
||||
kwargs['fetch_token'] = fetch_token
|
||||
|
||||
if not kwargs.get('request_token_url'):
|
||||
if not update_token and self.update_token:
|
||||
update_token = functools.partial(self.update_token, name)
|
||||
|
||||
kwargs['update_token'] = update_token
|
||||
return kwargs
|
||||
|
||||
def load_config(self, name, params):
|
||||
return self.framework_integration_cls.load_config(self, name, params)
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
if key in self._registry:
|
||||
return self.create_client(key)
|
||||
raise AttributeError('No such client: %s' % key)
|
||||
|
||||
|
||||
def _config_client(config, kwargs, overwrite):
|
||||
for k in OAUTH_CLIENT_PARAMS:
|
||||
v = config.get(k, None)
|
||||
if k not in kwargs:
|
||||
kwargs[k] = v
|
||||
elif overwrite and v:
|
||||
if isinstance(kwargs[k], dict):
|
||||
kwargs[k].update(v)
|
||||
else:
|
||||
kwargs[k] = v
|
||||
return kwargs
|
||||
@@ -0,0 +1,348 @@
|
||||
import time
|
||||
import logging
|
||||
from authlib.common.urls import urlparse
|
||||
from authlib.consts import default_user_agent
|
||||
from authlib.common.security import generate_token
|
||||
from .errors import (
|
||||
MismatchingStateError,
|
||||
MissingRequestTokenError,
|
||||
MissingTokenError,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseApp:
|
||||
client_cls = None
|
||||
OAUTH_APP_CONFIG = None
|
||||
|
||||
def request(self, method, url, token=None, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
"""Invoke GET http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.get('users/lepture')
|
||||
"""
|
||||
return self.request('GET', url, **kwargs)
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
"""Invoke POST http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.post('timeline', json={'text': 'Hi'})
|
||||
"""
|
||||
return self.request('POST', url, **kwargs)
|
||||
|
||||
def patch(self, url, **kwargs):
|
||||
"""Invoke PATCH http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.patch('profile', json={'name': 'Hsiaoming Yang'})
|
||||
"""
|
||||
return self.request('PATCH', url, **kwargs)
|
||||
|
||||
def put(self, url, **kwargs):
|
||||
"""Invoke PUT http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.put('profile', json={'name': 'Hsiaoming Yang'})
|
||||
"""
|
||||
return self.request('PUT', url, **kwargs)
|
||||
|
||||
def delete(self, url, **kwargs):
|
||||
"""Invoke DELETE http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.delete('posts/123')
|
||||
"""
|
||||
return self.request('DELETE', url, **kwargs)
|
||||
|
||||
|
||||
class _RequestMixin:
|
||||
def _get_requested_token(self, request):
|
||||
if self._fetch_token and request:
|
||||
return self._fetch_token(request)
|
||||
|
||||
def _send_token_request(self, session, method, url, token, kwargs):
|
||||
request = kwargs.pop('request', None)
|
||||
withhold_token = kwargs.get('withhold_token')
|
||||
if self.api_base_url and not url.startswith(('https://', 'http://')):
|
||||
url = urlparse.urljoin(self.api_base_url, url)
|
||||
|
||||
if withhold_token:
|
||||
return session.request(method, url, **kwargs)
|
||||
|
||||
if token is None:
|
||||
token = self._get_requested_token(request)
|
||||
|
||||
if token is None:
|
||||
raise MissingTokenError()
|
||||
|
||||
session.token = token
|
||||
return session.request(method, url, **kwargs)
|
||||
|
||||
|
||||
class OAuth1Base:
|
||||
client_cls = None
|
||||
|
||||
def __init__(
|
||||
self, framework, name=None, fetch_token=None,
|
||||
client_id=None, client_secret=None,
|
||||
request_token_url=None, request_token_params=None,
|
||||
access_token_url=None, access_token_params=None,
|
||||
authorize_url=None, authorize_params=None,
|
||||
api_base_url=None, client_kwargs=None, user_agent=None, **kwargs):
|
||||
self.framework = framework
|
||||
self.name = name
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.request_token_url = request_token_url
|
||||
self.request_token_params = request_token_params
|
||||
self.access_token_url = access_token_url
|
||||
self.access_token_params = access_token_params
|
||||
self.authorize_url = authorize_url
|
||||
self.authorize_params = authorize_params
|
||||
self.api_base_url = api_base_url
|
||||
self.client_kwargs = client_kwargs or {}
|
||||
|
||||
self._fetch_token = fetch_token
|
||||
self._user_agent = user_agent or default_user_agent
|
||||
self._kwargs = kwargs
|
||||
|
||||
def _get_oauth_client(self):
|
||||
session = self.client_cls(self.client_id, self.client_secret, **self.client_kwargs)
|
||||
session.headers['User-Agent'] = self._user_agent
|
||||
return session
|
||||
|
||||
|
||||
class OAuth1Mixin(_RequestMixin, OAuth1Base):
|
||||
def request(self, method, url, token=None, **kwargs):
|
||||
with self._get_oauth_client() as session:
|
||||
return self._send_token_request(session, method, url, token, kwargs)
|
||||
|
||||
def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
if not self.authorize_url:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
with self._get_oauth_client() as client:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = self.request_token_params or {}
|
||||
request_token = client.fetch_request_token(self.request_token_url, **params)
|
||||
log.debug(f'Fetch request token: {request_token!r}')
|
||||
url = client.create_authorization_url(self.authorize_url, **kwargs)
|
||||
state = request_token['oauth_token']
|
||||
return {'url': url, 'request_token': request_token, 'state': state}
|
||||
|
||||
def fetch_access_token(self, request_token=None, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:param request_token: A previous request token for OAuth 1.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
with self._get_oauth_client() as client:
|
||||
if request_token is None:
|
||||
raise MissingRequestTokenError()
|
||||
# merge request token with verifier
|
||||
token = {}
|
||||
token.update(request_token)
|
||||
token.update(kwargs)
|
||||
client.token = token
|
||||
params = self.access_token_params or {}
|
||||
token = client.fetch_access_token(self.access_token_url, **params)
|
||||
return token
|
||||
|
||||
|
||||
class OAuth2Base:
|
||||
client_cls = None
|
||||
|
||||
def __init__(
|
||||
self, framework, name=None, fetch_token=None, update_token=None,
|
||||
client_id=None, client_secret=None,
|
||||
access_token_url=None, access_token_params=None,
|
||||
authorize_url=None, authorize_params=None,
|
||||
api_base_url=None, client_kwargs=None, server_metadata_url=None,
|
||||
compliance_fix=None, client_auth_methods=None, user_agent=None, **kwargs):
|
||||
self.framework = framework
|
||||
self.name = name
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.access_token_url = access_token_url
|
||||
self.access_token_params = access_token_params
|
||||
self.authorize_url = authorize_url
|
||||
self.authorize_params = authorize_params
|
||||
self.api_base_url = api_base_url
|
||||
self.client_kwargs = client_kwargs or {}
|
||||
|
||||
self.compliance_fix = compliance_fix
|
||||
self.client_auth_methods = client_auth_methods
|
||||
self._fetch_token = fetch_token
|
||||
self._update_token = update_token
|
||||
self._user_agent = user_agent or default_user_agent
|
||||
|
||||
self._server_metadata_url = server_metadata_url
|
||||
self.server_metadata = kwargs
|
||||
|
||||
def _on_update_token(self, token, refresh_token=None, access_token=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_oauth_client(self, **metadata):
|
||||
client_kwargs = {}
|
||||
client_kwargs.update(self.client_kwargs)
|
||||
client_kwargs.update(metadata)
|
||||
|
||||
if self.authorize_url:
|
||||
client_kwargs['authorization_endpoint'] = self.authorize_url
|
||||
if self.access_token_url:
|
||||
client_kwargs['token_endpoint'] = self.access_token_url
|
||||
|
||||
session = self.client_cls(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
update_token=self._on_update_token,
|
||||
**client_kwargs
|
||||
)
|
||||
if self.client_auth_methods:
|
||||
for f in self.client_auth_methods:
|
||||
session.register_client_auth_method(f)
|
||||
|
||||
if self.compliance_fix:
|
||||
self.compliance_fix(session)
|
||||
|
||||
session.headers['User-Agent'] = self._user_agent
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def _format_state_params(state_data, params):
|
||||
if state_data is None:
|
||||
raise MismatchingStateError()
|
||||
|
||||
code_verifier = state_data.get('code_verifier')
|
||||
if code_verifier:
|
||||
params['code_verifier'] = code_verifier
|
||||
|
||||
redirect_uri = state_data.get('redirect_uri')
|
||||
if redirect_uri:
|
||||
params['redirect_uri'] = redirect_uri
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs):
|
||||
rv = {}
|
||||
if client.code_challenge_method:
|
||||
code_verifier = kwargs.get('code_verifier')
|
||||
if not code_verifier:
|
||||
code_verifier = generate_token(48)
|
||||
kwargs['code_verifier'] = code_verifier
|
||||
rv['code_verifier'] = code_verifier
|
||||
log.debug(f'Using code_verifier: {code_verifier!r}')
|
||||
|
||||
scope = kwargs.get('scope', client.scope)
|
||||
scope = (
|
||||
(scope if isinstance(scope, (list, tuple)) else scope.split())
|
||||
if scope
|
||||
else None
|
||||
)
|
||||
if scope and "openid" in scope:
|
||||
# this is an OpenID Connect service
|
||||
nonce = kwargs.get('nonce')
|
||||
if not nonce:
|
||||
nonce = generate_token(20)
|
||||
kwargs['nonce'] = nonce
|
||||
rv['nonce'] = nonce
|
||||
|
||||
url, state = client.create_authorization_url(
|
||||
authorization_endpoint, **kwargs)
|
||||
rv['url'] = url
|
||||
rv['state'] = state
|
||||
return rv
|
||||
|
||||
|
||||
class OAuth2Mixin(_RequestMixin, OAuth2Base):
|
||||
def _on_update_token(self, token, refresh_token=None, access_token=None):
|
||||
if callable(self._update_token):
|
||||
self._update_token(
|
||||
token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.framework.update_token(
|
||||
token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
def request(self, method, url, token=None, **kwargs):
|
||||
metadata = self.load_server_metadata()
|
||||
with self._get_oauth_client(**metadata) as session:
|
||||
return self._send_token_request(session, method, url, token, kwargs)
|
||||
|
||||
def load_server_metadata(self):
|
||||
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
|
||||
with self.client_cls(**self.client_kwargs) as session:
|
||||
resp = session.request('GET', self._server_metadata_url, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
metadata = resp.json()
|
||||
|
||||
metadata['_loaded_at'] = time.time()
|
||||
self.server_metadata.update(metadata)
|
||||
return self.server_metadata
|
||||
|
||||
def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
metadata = self.load_server_metadata()
|
||||
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
|
||||
|
||||
if not authorization_endpoint:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
|
||||
with self._get_oauth_client(**metadata) as client:
|
||||
if redirect_uri is not None:
|
||||
client.redirect_uri = redirect_uri
|
||||
return self._create_oauth2_authorization_url(
|
||||
client, authorization_endpoint, **kwargs)
|
||||
|
||||
def fetch_access_token(self, redirect_uri=None, **kwargs):
|
||||
"""Fetch access token in the final step.
|
||||
|
||||
:param redirect_uri: Callback or Redirect URI that is used in
|
||||
previous :meth:`authorize_redirect`.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
metadata = self.load_server_metadata()
|
||||
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
|
||||
with self._get_oauth_client(**metadata) as client:
|
||||
if redirect_uri is not None:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = {}
|
||||
if self.access_token_params:
|
||||
params.update(self.access_token_params)
|
||||
params.update(kwargs)
|
||||
token = client.fetch_token(token_endpoint, **params)
|
||||
return token
|
||||
@@ -0,0 +1,82 @@
|
||||
from authlib.jose import jwt, JsonWebToken, JsonWebKey
|
||||
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
|
||||
|
||||
|
||||
class OpenIDMixin:
|
||||
def fetch_jwk_set(self, force=False):
|
||||
metadata = self.load_server_metadata()
|
||||
jwk_set = metadata.get('jwks')
|
||||
if jwk_set and not force:
|
||||
return jwk_set
|
||||
|
||||
uri = metadata.get('jwks_uri')
|
||||
if not uri:
|
||||
raise RuntimeError('Missing "jwks_uri" in metadata')
|
||||
|
||||
with self.client_cls(**self.client_kwargs) as session:
|
||||
resp = session.request('GET', uri, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
jwk_set = resp.json()
|
||||
|
||||
self.server_metadata['jwks'] = jwk_set
|
||||
return jwk_set
|
||||
|
||||
def userinfo(self, **kwargs):
|
||||
"""Fetch user info from ``userinfo_endpoint``."""
|
||||
metadata = self.load_server_metadata()
|
||||
resp = self.get(metadata['userinfo_endpoint'], **kwargs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return UserInfo(data)
|
||||
|
||||
def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
|
||||
"""Return an instance of UserInfo from token's ``id_token``."""
|
||||
if 'id_token' not in token:
|
||||
return None
|
||||
|
||||
load_key = self.create_load_key()
|
||||
|
||||
claims_params = dict(
|
||||
nonce=nonce,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
if 'access_token' in token:
|
||||
claims_params['access_token'] = token['access_token']
|
||||
claims_cls = CodeIDToken
|
||||
else:
|
||||
claims_cls = ImplicitIDToken
|
||||
|
||||
metadata = self.load_server_metadata()
|
||||
if claims_options is None and 'issuer' in metadata:
|
||||
claims_options = {'iss': {'values': [metadata['issuer']]}}
|
||||
|
||||
alg_values = metadata.get('id_token_signing_alg_values_supported')
|
||||
if alg_values:
|
||||
_jwt = JsonWebToken(alg_values)
|
||||
else:
|
||||
_jwt = jwt
|
||||
|
||||
claims = _jwt.decode(
|
||||
token['id_token'], key=load_key,
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
# https://github.com/lepture/authlib/issues/259
|
||||
if claims.get('nonce_supported') is False:
|
||||
claims.params['nonce'] = None
|
||||
|
||||
claims.validate(leeway=leeway)
|
||||
return UserInfo(claims)
|
||||
|
||||
def create_load_key(self):
|
||||
def load_key(header, _):
|
||||
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
|
||||
try:
|
||||
return jwk_set.find_by_kid(header.get('kid'))
|
||||
except ValueError:
|
||||
# re-try with new jwk set
|
||||
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
|
||||
return jwk_set.find_by_kid(header.get('kid'))
|
||||
|
||||
return load_key
|
||||
@@ -0,0 +1,19 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .integration import DjangoIntegration, token_update
|
||||
from .apps import DjangoOAuth1App, DjangoOAuth2App
|
||||
from ..base_client import BaseOAuth, OAuthError
|
||||
|
||||
|
||||
class OAuth(BaseOAuth):
|
||||
oauth1_client_cls = DjangoOAuth1App
|
||||
oauth2_client_cls = DjangoOAuth2App
|
||||
framework_integration_cls = DjangoIntegration
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuth',
|
||||
'DjangoOAuth1App', 'DjangoOAuth2App',
|
||||
'DjangoIntegration',
|
||||
'token_update', 'OAuthError',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,87 @@
|
||||
from django.http import HttpResponseRedirect
|
||||
from ..requests_client import OAuth1Session, OAuth2Session
|
||||
from ..base_client import (
|
||||
BaseApp, OAuthError,
|
||||
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
|
||||
)
|
||||
|
||||
|
||||
class DjangoAppMixin:
|
||||
def save_authorize_data(self, request, **kwargs):
|
||||
state = kwargs.pop('state', None)
|
||||
if state:
|
||||
self.framework.set_state_data(request.session, state, kwargs)
|
||||
else:
|
||||
raise RuntimeError('Missing state value')
|
||||
|
||||
def authorize_redirect(self, request, redirect_uri=None, **kwargs):
|
||||
"""Create a HTTP Redirect for Authorization Endpoint.
|
||||
|
||||
:param request: HTTP request instance from Django view.
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: A HTTP redirect response.
|
||||
"""
|
||||
rv = self.create_authorization_url(redirect_uri, **kwargs)
|
||||
self.save_authorize_data(request, redirect_uri=redirect_uri, **rv)
|
||||
return HttpResponseRedirect(rv['url'])
|
||||
|
||||
|
||||
class DjangoOAuth1App(DjangoAppMixin, OAuth1Mixin, BaseApp):
|
||||
client_cls = OAuth1Session
|
||||
|
||||
def authorize_access_token(self, request, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:param request: HTTP request instance from Django view.
|
||||
:return: A token dict.
|
||||
"""
|
||||
params = request.GET.dict()
|
||||
state = params.get('oauth_token')
|
||||
if not state:
|
||||
raise OAuthError(description='Missing "oauth_token" parameter')
|
||||
|
||||
data = self.framework.get_state_data(request.session, state)
|
||||
if not data:
|
||||
raise OAuthError(description='Missing "request_token" in temporary data')
|
||||
|
||||
params['request_token'] = data['request_token']
|
||||
params.update(kwargs)
|
||||
self.framework.clear_state_data(request.session, state)
|
||||
return self.fetch_access_token(**params)
|
||||
|
||||
|
||||
class DjangoOAuth2App(DjangoAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
|
||||
client_cls = OAuth2Session
|
||||
|
||||
def authorize_access_token(self, request, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:param request: HTTP request instance from Django view.
|
||||
:return: A token dict.
|
||||
"""
|
||||
if request.method == 'GET':
|
||||
error = request.GET.get('error')
|
||||
if error:
|
||||
description = request.GET.get('error_description')
|
||||
raise OAuthError(error=error, description=description)
|
||||
params = {
|
||||
'code': request.GET.get('code'),
|
||||
'state': request.GET.get('state'),
|
||||
}
|
||||
else:
|
||||
params = {
|
||||
'code': request.POST.get('code'),
|
||||
'state': request.POST.get('state'),
|
||||
}
|
||||
|
||||
claims_options = kwargs.pop('claims_options', None)
|
||||
state_data = self.framework.get_state_data(request.session, params.get('state'))
|
||||
self.framework.clear_state_data(request.session, params.get('state'))
|
||||
params = self._format_state_params(state_data, params)
|
||||
token = self.fetch_access_token(**params, **kwargs)
|
||||
|
||||
if 'id_token' in token and 'nonce' in state_data:
|
||||
userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
|
||||
token['userinfo'] = userinfo
|
||||
return token
|
||||
@@ -0,0 +1,22 @@
|
||||
from django.conf import settings
|
||||
from django.dispatch import Signal
|
||||
from ..base_client import FrameworkIntegration
|
||||
|
||||
token_update = Signal()
|
||||
|
||||
|
||||
class DjangoIntegration(FrameworkIntegration):
|
||||
def update_token(self, token, refresh_token=None, access_token=None):
|
||||
token_update.send(
|
||||
sender=self.__class__,
|
||||
name=self.name,
|
||||
token=token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_config(oauth, name, params):
|
||||
config = getattr(settings, 'AUTHLIB_OAUTH_CLIENTS', None)
|
||||
if config:
|
||||
return config.get(name)
|
||||
@@ -0,0 +1,9 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .authorization_server import (
|
||||
BaseServer, CacheAuthorizationServer
|
||||
)
|
||||
from .resource_protector import ResourceProtector
|
||||
|
||||
|
||||
__all__ = ['BaseServer', 'CacheAuthorizationServer', 'ResourceProtector']
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,125 @@
|
||||
import logging
|
||||
from authlib.oauth1 import (
|
||||
OAuth1Request,
|
||||
AuthorizationServer as _AuthorizationServer,
|
||||
)
|
||||
from authlib.oauth1 import TemporaryCredential
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.common.urls import url_encode
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
from django.http import HttpResponse
|
||||
from .nonce import exists_nonce_in_cache
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseServer(_AuthorizationServer):
|
||||
def __init__(self, client_model, token_model, token_generator=None):
|
||||
self.client_model = client_model
|
||||
self.token_model = token_model
|
||||
|
||||
if token_generator is None:
|
||||
def token_generator():
|
||||
return {
|
||||
'oauth_token': generate_token(42),
|
||||
'oauth_token_secret': generate_token(48)
|
||||
}
|
||||
|
||||
self.token_generator = token_generator
|
||||
self._config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {})
|
||||
self._nonce_expires_in = self._config.get('nonce_expires_in', 86400)
|
||||
methods = self._config.get('signature_methods')
|
||||
if methods:
|
||||
self.SUPPORTED_SIGNATURE_METHODS = methods
|
||||
|
||||
def get_client_by_id(self, client_id):
|
||||
try:
|
||||
return self.client_model.objects.get(client_id=client_id)
|
||||
except self.client_model.DoesNotExist:
|
||||
return None
|
||||
|
||||
def exists_nonce(self, nonce, request):
|
||||
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
|
||||
|
||||
def create_token_credential(self, request):
|
||||
temporary_credential = request.credential
|
||||
token = self.token_generator()
|
||||
item = self.token_model(
|
||||
oauth_token=token['oauth_token'],
|
||||
oauth_token_secret=token['oauth_token_secret'],
|
||||
user_id=temporary_credential.get_user_id(),
|
||||
client_id=temporary_credential.get_client_id()
|
||||
)
|
||||
item.save()
|
||||
return item
|
||||
|
||||
def check_authorization_request(self, request):
|
||||
req = self.create_oauth1_request(request)
|
||||
self.validate_authorization_request(req)
|
||||
return req
|
||||
|
||||
def create_oauth1_request(self, request):
|
||||
if request.method == 'POST':
|
||||
body = request.POST.dict()
|
||||
else:
|
||||
body = None
|
||||
url = request.build_absolute_uri()
|
||||
return OAuth1Request(request.method, url, body, request.headers)
|
||||
|
||||
def handle_response(self, status_code, payload, headers):
|
||||
resp = HttpResponse(url_encode(payload), status=status_code)
|
||||
for k, v in headers:
|
||||
resp[k] = v
|
||||
return resp
|
||||
|
||||
|
||||
class CacheAuthorizationServer(BaseServer):
|
||||
def __init__(self, client_model, token_model, token_generator=None):
|
||||
super().__init__(
|
||||
client_model, token_model, token_generator)
|
||||
self._temporary_expires_in = self._config.get(
|
||||
'temporary_credential_expires_in', 86400)
|
||||
self._temporary_credential_key_prefix = self._config.get(
|
||||
'temporary_credential_key_prefix', 'temporary_credential:')
|
||||
|
||||
def create_temporary_credential(self, request):
|
||||
key_prefix = self._temporary_credential_key_prefix
|
||||
token = self.token_generator()
|
||||
|
||||
client_id = request.client_id
|
||||
redirect_uri = request.redirect_uri
|
||||
key = key_prefix + token['oauth_token']
|
||||
token['client_id'] = client_id
|
||||
if redirect_uri:
|
||||
token['oauth_callback'] = redirect_uri
|
||||
|
||||
cache.set(key, token, timeout=self._temporary_expires_in)
|
||||
return TemporaryCredential(token)
|
||||
|
||||
def get_temporary_credential(self, request):
|
||||
if not request.token:
|
||||
return None
|
||||
|
||||
key_prefix = self._temporary_credential_key_prefix
|
||||
key = key_prefix + request.token
|
||||
value = cache.get(key)
|
||||
if value:
|
||||
return TemporaryCredential(value)
|
||||
|
||||
def delete_temporary_credential(self, request):
|
||||
if request.token:
|
||||
key_prefix = self._temporary_credential_key_prefix
|
||||
key = key_prefix + request.token
|
||||
cache.delete(key)
|
||||
|
||||
def create_authorization_verifier(self, request):
|
||||
key_prefix = self._temporary_credential_key_prefix
|
||||
verifier = generate_token(36)
|
||||
credential = request.credential
|
||||
user = request.user
|
||||
key = key_prefix + credential.get_oauth_token()
|
||||
credential['oauth_verifier'] = verifier
|
||||
credential['user_id'] = user.pk
|
||||
cache.set(key, credential, timeout=self._temporary_expires_in)
|
||||
return verifier
|
||||
@@ -0,0 +1,15 @@
|
||||
from django.core.cache import cache
|
||||
|
||||
|
||||
def exists_nonce_in_cache(nonce, request, timeout):
|
||||
key_prefix = 'nonce:'
|
||||
timestamp = request.timestamp
|
||||
client_id = request.client_id
|
||||
token = request.token
|
||||
key = f'{key_prefix}{nonce}-{timestamp}-{client_id}'
|
||||
if token:
|
||||
key = f'{key}-{token}'
|
||||
|
||||
rv = bool(cache.get(key))
|
||||
cache.set(key, 1, timeout=timeout)
|
||||
return rv
|
||||
@@ -0,0 +1,64 @@
|
||||
import functools
|
||||
from authlib.oauth1.errors import OAuth1Error
|
||||
from authlib.oauth1 import ResourceProtector as _ResourceProtector
|
||||
from django.http import JsonResponse
|
||||
from django.conf import settings
|
||||
from .nonce import exists_nonce_in_cache
|
||||
|
||||
|
||||
class ResourceProtector(_ResourceProtector):
|
||||
def __init__(self, client_model, token_model):
|
||||
self.client_model = client_model
|
||||
self.token_model = token_model
|
||||
|
||||
config = getattr(settings, 'AUTHLIB_OAUTH1_PROVIDER', {})
|
||||
methods = config.get('signature_methods', [])
|
||||
if methods and isinstance(methods, (list, tuple)):
|
||||
self.SUPPORTED_SIGNATURE_METHODS = methods
|
||||
|
||||
self._nonce_expires_in = config.get('nonce_expires_in', 86400)
|
||||
|
||||
def get_client_by_id(self, client_id):
|
||||
try:
|
||||
return self.client_model.objects.get(client_id=client_id)
|
||||
except self.client_model.DoesNotExist:
|
||||
return None
|
||||
|
||||
def get_token_credential(self, request):
|
||||
try:
|
||||
return self.token_model.objects.get(
|
||||
client_id=request.client_id,
|
||||
oauth_token=request.token
|
||||
)
|
||||
except self.token_model.DoesNotExist:
|
||||
return None
|
||||
|
||||
def exists_nonce(self, nonce, request):
|
||||
return exists_nonce_in_cache(nonce, request, self._nonce_expires_in)
|
||||
|
||||
def acquire_credential(self, request):
|
||||
if request.method in ['POST', 'PUT']:
|
||||
body = request.POST.dict()
|
||||
else:
|
||||
body = None
|
||||
|
||||
url = request.build_absolute_uri()
|
||||
req = self.validate_request(request.method, url, body, request.headers)
|
||||
return req.credential
|
||||
|
||||
def __call__(self, realm=None):
|
||||
def wrapper(f):
|
||||
@functools.wraps(f)
|
||||
def decorated(request, *args, **kwargs):
|
||||
try:
|
||||
credential = self.acquire_credential(request)
|
||||
request.oauth1_credential = credential
|
||||
except OAuth1Error as error:
|
||||
body = dict(error.get_body())
|
||||
resp = JsonResponse(body, status=error.status_code)
|
||||
resp['Cache-Control'] = 'no-store'
|
||||
resp['Pragma'] = 'no-cache'
|
||||
return resp
|
||||
return f(request, *args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
@@ -0,0 +1,10 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .authorization_server import AuthorizationServer
|
||||
from .resource_protector import ResourceProtector, BearerTokenValidator
|
||||
from .endpoints import RevocationEndpoint
|
||||
from .signals import (
|
||||
client_authenticated,
|
||||
token_authenticated,
|
||||
token_revoked
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,118 @@
|
||||
from django.http import HttpResponse
|
||||
from django.utils.module_loading import import_string
|
||||
from django.conf import settings
|
||||
from authlib.oauth2 import (
|
||||
AuthorizationServer as _AuthorizationServer,
|
||||
)
|
||||
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
||||
from authlib.common.security import generate_token as _generate_token
|
||||
from authlib.common.encoding import json_dumps
|
||||
from .requests import DjangoOAuth2Request, DjangoJsonRequest
|
||||
from .signals import client_authenticated, token_revoked
|
||||
|
||||
|
||||
class AuthorizationServer(_AuthorizationServer):
|
||||
"""Django implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
|
||||
Initialize it with client model and token model::
|
||||
|
||||
from authlib.integrations.django_oauth2 import AuthorizationServer
|
||||
from your_project.models import OAuth2Client, OAuth2Token
|
||||
|
||||
server = AuthorizationServer(OAuth2Client, OAuth2Token)
|
||||
"""
|
||||
|
||||
def __init__(self, client_model, token_model):
|
||||
self.config = getattr(settings, 'AUTHLIB_OAUTH2_PROVIDER', {})
|
||||
self.client_model = client_model
|
||||
self.token_model = token_model
|
||||
scopes_supported = self.config.get('scopes_supported')
|
||||
super().__init__(scopes_supported=scopes_supported)
|
||||
# add default token generator
|
||||
self.register_token_generator('default', self.create_bearer_token_generator())
|
||||
|
||||
def query_client(self, client_id):
|
||||
"""Default method for ``AuthorizationServer.query_client``. Developers MAY
|
||||
rewrite this function to meet their own needs.
|
||||
"""
|
||||
try:
|
||||
return self.client_model.objects.get(client_id=client_id)
|
||||
except self.client_model.DoesNotExist:
|
||||
return None
|
||||
|
||||
def save_token(self, token, request):
|
||||
"""Default method for ``AuthorizationServer.save_token``. Developers MAY
|
||||
rewrite this function to meet their own needs.
|
||||
"""
|
||||
client = request.client
|
||||
if request.user:
|
||||
user_id = request.user.pk
|
||||
else:
|
||||
user_id = client.user_id
|
||||
item = self.token_model(
|
||||
client_id=client.client_id,
|
||||
user_id=user_id,
|
||||
**token
|
||||
)
|
||||
item.save()
|
||||
return item
|
||||
|
||||
def create_oauth2_request(self, request):
|
||||
return DjangoOAuth2Request(request)
|
||||
|
||||
def create_json_request(self, request):
|
||||
return DjangoJsonRequest(request)
|
||||
|
||||
def handle_response(self, status_code, payload, headers):
|
||||
if isinstance(payload, dict):
|
||||
payload = json_dumps(payload)
|
||||
resp = HttpResponse(payload, status=status_code)
|
||||
for k, v in headers:
|
||||
resp[k] = v
|
||||
return resp
|
||||
|
||||
def send_signal(self, name, *args, **kwargs):
|
||||
if name == 'after_authenticate_client':
|
||||
client_authenticated.send(sender=self.__class__, *args, **kwargs)
|
||||
elif name == 'after_revoke_token':
|
||||
token_revoked.send(sender=self.__class__, *args, **kwargs)
|
||||
|
||||
def create_bearer_token_generator(self):
|
||||
"""Default method to create BearerToken generator."""
|
||||
conf = self.config.get('access_token_generator', True)
|
||||
access_token_generator = create_token_generator(conf, 42)
|
||||
|
||||
conf = self.config.get('refresh_token_generator', False)
|
||||
refresh_token_generator = create_token_generator(conf, 48)
|
||||
|
||||
conf = self.config.get('token_expires_in')
|
||||
expires_generator = create_token_expires_in_generator(conf)
|
||||
|
||||
return BearerTokenGenerator(
|
||||
access_token_generator=access_token_generator,
|
||||
refresh_token_generator=refresh_token_generator,
|
||||
expires_generator=expires_generator,
|
||||
)
|
||||
|
||||
|
||||
def create_token_generator(token_generator_conf, length=42):
|
||||
if callable(token_generator_conf):
|
||||
return token_generator_conf
|
||||
|
||||
if isinstance(token_generator_conf, str):
|
||||
return import_string(token_generator_conf)
|
||||
elif token_generator_conf is True:
|
||||
def token_generator(*args, **kwargs):
|
||||
return _generate_token(length)
|
||||
return token_generator
|
||||
|
||||
|
||||
def create_token_expires_in_generator(expires_in_conf=None):
|
||||
data = {}
|
||||
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
|
||||
if expires_in_conf:
|
||||
data.update(expires_in_conf)
|
||||
|
||||
def expires_in(client, grant_type):
|
||||
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
|
||||
|
||||
return expires_in
|
||||
@@ -0,0 +1,56 @@
|
||||
from authlib.oauth2.rfc7009 import RevocationEndpoint as _RevocationEndpoint
|
||||
|
||||
|
||||
class RevocationEndpoint(_RevocationEndpoint):
|
||||
"""The revocation endpoint for OAuth authorization servers allows clients
|
||||
to notify the authorization server that a previously obtained refresh or
|
||||
access token is no longer needed.
|
||||
|
||||
Register it into authorization server, and create token endpoint response
|
||||
for token revocation::
|
||||
|
||||
from django.views.decorators.http import require_http_methods
|
||||
|
||||
# see register into authorization server instance
|
||||
server.register_endpoint(RevocationEndpoint)
|
||||
|
||||
@require_http_methods(["POST"])
|
||||
def revoke_token(request):
|
||||
return server.create_endpoint_response(
|
||||
RevocationEndpoint.ENDPOINT_NAME,
|
||||
request
|
||||
)
|
||||
"""
|
||||
|
||||
def query_token(self, token, token_type_hint):
|
||||
"""Query requested token from database."""
|
||||
token_model = self.server.token_model
|
||||
if token_type_hint == 'access_token':
|
||||
rv = _query_access_token(token_model, token)
|
||||
elif token_type_hint == 'refresh_token':
|
||||
rv = _query_refresh_token(token_model, token)
|
||||
else:
|
||||
rv = _query_access_token(token_model, token)
|
||||
if not rv:
|
||||
rv = _query_refresh_token(token_model, token)
|
||||
|
||||
return rv
|
||||
|
||||
def revoke_token(self, token, request):
|
||||
"""Mark the give token as revoked."""
|
||||
token.revoked = True
|
||||
token.save()
|
||||
|
||||
|
||||
def _query_access_token(token_model, token):
|
||||
try:
|
||||
return token_model.objects.get(access_token=token)
|
||||
except token_model.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def _query_refresh_token(token_model, token):
|
||||
try:
|
||||
return token_model.objects.get(refresh_token=token)
|
||||
except token_model.DoesNotExist:
|
||||
return None
|
||||
@@ -0,0 +1,46 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from django.http import HttpRequest
|
||||
from django.utils.functional import cached_property
|
||||
from authlib.common.encoding import json_loads
|
||||
from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest
|
||||
|
||||
|
||||
class DjangoOAuth2Request(OAuth2Request):
|
||||
def __init__(self, request: HttpRequest):
|
||||
super().__init__(request.method, request.build_absolute_uri(), None, request.headers)
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._request.GET
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
return self._request.POST
|
||||
|
||||
@cached_property
|
||||
def data(self):
|
||||
data = {}
|
||||
data.update(self._request.GET.dict())
|
||||
data.update(self._request.POST.dict())
|
||||
return data
|
||||
|
||||
@cached_property
|
||||
def datalist(self):
|
||||
values = defaultdict(list)
|
||||
for k in self.args:
|
||||
values[k].extend(self.args.getlist(k))
|
||||
for k in self.form:
|
||||
values[k].extend(self.form.getlist(k))
|
||||
return values
|
||||
|
||||
|
||||
class DjangoJsonRequest(JsonRequest):
|
||||
def __init__(self, request: HttpRequest):
|
||||
super().__init__(request.method, request.build_absolute_uri(), None, request.headers)
|
||||
self._request = request
|
||||
|
||||
@cached_property
|
||||
def data(self):
|
||||
return json_loads(self._request.body)
|
||||
@@ -0,0 +1,75 @@
|
||||
import functools
|
||||
from django.http import JsonResponse
|
||||
from authlib.oauth2 import (
|
||||
OAuth2Error,
|
||||
ResourceProtector as _ResourceProtector,
|
||||
)
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
MissingAuthorizationError,
|
||||
)
|
||||
from authlib.oauth2.rfc6750 import (
|
||||
BearerTokenValidator as _BearerTokenValidator
|
||||
)
|
||||
from .requests import DjangoJsonRequest
|
||||
from .signals import token_authenticated
|
||||
|
||||
|
||||
class ResourceProtector(_ResourceProtector):
|
||||
def acquire_token(self, request, scopes=None, **kwargs):
|
||||
"""A method to acquire current valid token with the given scope.
|
||||
|
||||
:param request: Django HTTP request instance
|
||||
:param scopes: a list of scope values
|
||||
:return: token object
|
||||
"""
|
||||
req = DjangoJsonRequest(request)
|
||||
# backward compatibility
|
||||
kwargs['scopes'] = scopes
|
||||
for claim in kwargs:
|
||||
if isinstance(kwargs[claim], str):
|
||||
kwargs[claim] = [kwargs[claim]]
|
||||
token = self.validate_request(request=req, **kwargs)
|
||||
token_authenticated.send(sender=self.__class__, token=token)
|
||||
return token
|
||||
|
||||
def __call__(self, scopes=None, optional=False, **kwargs):
|
||||
claims = kwargs
|
||||
# backward compatibility
|
||||
claims['scopes'] = scopes
|
||||
def wrapper(f):
|
||||
@functools.wraps(f)
|
||||
def decorated(request, *args, **kwargs):
|
||||
try:
|
||||
token = self.acquire_token(request, **claims)
|
||||
request.oauth_token = token
|
||||
except MissingAuthorizationError as error:
|
||||
if optional:
|
||||
request.oauth_token = None
|
||||
return f(request, *args, **kwargs)
|
||||
return return_error_response(error)
|
||||
except OAuth2Error as error:
|
||||
return return_error_response(error)
|
||||
return f(request, *args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
|
||||
|
||||
class BearerTokenValidator(_BearerTokenValidator):
|
||||
def __init__(self, token_model, realm=None, **extra_attributes):
|
||||
self.token_model = token_model
|
||||
super().__init__(realm, **extra_attributes)
|
||||
|
||||
def authenticate_token(self, token_string):
|
||||
try:
|
||||
return self.token_model.objects.get(access_token=token_string)
|
||||
except self.token_model.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def return_error_response(error):
|
||||
body = dict(error.get_body())
|
||||
resp = JsonResponse(body, status=error.status_code)
|
||||
headers = error.get_headers()
|
||||
for k, v in headers:
|
||||
resp[k] = v
|
||||
return resp
|
||||
@@ -0,0 +1,11 @@
|
||||
from django.dispatch import Signal
|
||||
|
||||
|
||||
#: signal when client is authenticated
|
||||
client_authenticated = Signal()
|
||||
|
||||
#: signal when token is revoked
|
||||
token_revoked = Signal()
|
||||
|
||||
#: signal when token is authenticated
|
||||
token_authenticated = Signal()
|
||||
@@ -0,0 +1,51 @@
|
||||
from werkzeug.local import LocalProxy
|
||||
from .integration import FlaskIntegration, token_update
|
||||
from .apps import FlaskOAuth1App, FlaskOAuth2App
|
||||
from ..base_client import BaseOAuth, OAuthError
|
||||
|
||||
|
||||
class OAuth(BaseOAuth):
|
||||
oauth1_client_cls = FlaskOAuth1App
|
||||
oauth2_client_cls = FlaskOAuth2App
|
||||
framework_integration_cls = FlaskIntegration
|
||||
|
||||
def __init__(self, app=None, cache=None, fetch_token=None, update_token=None):
|
||||
super().__init__(
|
||||
cache=cache, fetch_token=fetch_token, update_token=update_token)
|
||||
self.app = app
|
||||
if app:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app, cache=None, fetch_token=None, update_token=None):
|
||||
"""Initialize lazy for Flask app. This is usually used for Flask application
|
||||
factory pattern.
|
||||
"""
|
||||
self.app = app
|
||||
if cache is not None:
|
||||
self.cache = cache
|
||||
|
||||
if fetch_token:
|
||||
self.fetch_token = fetch_token
|
||||
if update_token:
|
||||
self.update_token = update_token
|
||||
|
||||
app.extensions = getattr(app, 'extensions', {})
|
||||
app.extensions['authlib.integrations.flask_client'] = self
|
||||
|
||||
def create_client(self, name):
|
||||
if not self.app:
|
||||
raise RuntimeError('OAuth is not init with Flask app.')
|
||||
return super().create_client(name)
|
||||
|
||||
def register(self, name, overwrite=False, **kwargs):
|
||||
self._registry[name] = (overwrite, kwargs)
|
||||
if self.app:
|
||||
return self.create_client(name)
|
||||
return LocalProxy(lambda: self.create_client(name))
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuth', 'FlaskIntegration',
|
||||
'FlaskOAuth1App', 'FlaskOAuth2App',
|
||||
'token_update', 'OAuthError',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,107 @@
|
||||
from flask import g, redirect, request, session
|
||||
from ..requests_client import OAuth1Session, OAuth2Session
|
||||
from ..base_client import (
|
||||
BaseApp, OAuthError,
|
||||
OAuth1Mixin, OAuth2Mixin, OpenIDMixin,
|
||||
)
|
||||
|
||||
|
||||
class FlaskAppMixin:
|
||||
@property
|
||||
def token(self):
|
||||
attr = f'_oauth_token_{self.name}'
|
||||
token = g.get(attr)
|
||||
if token:
|
||||
return token
|
||||
if self._fetch_token:
|
||||
token = self._fetch_token()
|
||||
self.token = token
|
||||
return token
|
||||
|
||||
@token.setter
|
||||
def token(self, token):
|
||||
attr = f'_oauth_token_{self.name}'
|
||||
setattr(g, attr, token)
|
||||
|
||||
def _get_requested_token(self, *args, **kwargs):
|
||||
return self.token
|
||||
|
||||
def save_authorize_data(self, **kwargs):
|
||||
state = kwargs.pop('state', None)
|
||||
if state:
|
||||
self.framework.set_state_data(session, state, kwargs)
|
||||
else:
|
||||
raise RuntimeError('Missing state value')
|
||||
|
||||
def authorize_redirect(self, redirect_uri=None, **kwargs):
|
||||
"""Create a HTTP Redirect for Authorization Endpoint.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: A HTTP redirect response.
|
||||
"""
|
||||
rv = self.create_authorization_url(redirect_uri, **kwargs)
|
||||
self.save_authorize_data(redirect_uri=redirect_uri, **rv)
|
||||
return redirect(rv['url'])
|
||||
|
||||
|
||||
class FlaskOAuth1App(FlaskAppMixin, OAuth1Mixin, BaseApp):
|
||||
client_cls = OAuth1Session
|
||||
|
||||
def authorize_access_token(self, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:return: A token dict.
|
||||
"""
|
||||
params = request.args.to_dict(flat=True)
|
||||
state = params.get('oauth_token')
|
||||
if not state:
|
||||
raise OAuthError(description='Missing "oauth_token" parameter')
|
||||
|
||||
data = self.framework.get_state_data(session, state)
|
||||
if not data:
|
||||
raise OAuthError(description='Missing "request_token" in temporary data')
|
||||
|
||||
params['request_token'] = data['request_token']
|
||||
params.update(kwargs)
|
||||
self.framework.clear_state_data(session, state)
|
||||
token = self.fetch_access_token(**params)
|
||||
self.token = token
|
||||
return token
|
||||
|
||||
|
||||
class FlaskOAuth2App(FlaskAppMixin, OAuth2Mixin, OpenIDMixin, BaseApp):
|
||||
client_cls = OAuth2Session
|
||||
|
||||
def authorize_access_token(self, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:return: A token dict.
|
||||
"""
|
||||
if request.method == 'GET':
|
||||
error = request.args.get('error')
|
||||
if error:
|
||||
description = request.args.get('error_description')
|
||||
raise OAuthError(error=error, description=description)
|
||||
|
||||
params = {
|
||||
'code': request.args.get('code'),
|
||||
'state': request.args.get('state'),
|
||||
}
|
||||
else:
|
||||
params = {
|
||||
'code': request.form.get('code'),
|
||||
'state': request.form.get('state'),
|
||||
}
|
||||
|
||||
claims_options = kwargs.pop('claims_options', None)
|
||||
state_data = self.framework.get_state_data(session, params.get('state'))
|
||||
self.framework.clear_state_data(session, params.get('state'))
|
||||
params = self._format_state_params(state_data, params)
|
||||
token = self.fetch_access_token(**params, **kwargs)
|
||||
self.token = token
|
||||
|
||||
if 'id_token' in token and 'nonce' in state_data:
|
||||
userinfo = self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
|
||||
token['userinfo'] = userinfo
|
||||
return token
|
||||
@@ -0,0 +1,28 @@
|
||||
from flask import current_app
|
||||
from flask.signals import Namespace
|
||||
from ..base_client import FrameworkIntegration
|
||||
|
||||
_signal = Namespace()
|
||||
#: signal when token is updated
|
||||
token_update = _signal.signal('token_update')
|
||||
|
||||
|
||||
class FlaskIntegration(FrameworkIntegration):
|
||||
def update_token(self, token, refresh_token=None, access_token=None):
|
||||
token_update.send(
|
||||
current_app,
|
||||
name=self.name,
|
||||
token=token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def load_config(oauth, name, params):
|
||||
rv = {}
|
||||
for k in params:
|
||||
conf_key = f'{name}_{k}'.upper()
|
||||
v = oauth.app.config.get(conf_key, None)
|
||||
if v is not None:
|
||||
rv[k] = v
|
||||
return rv
|
||||
@@ -0,0 +1,9 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .authorization_server import AuthorizationServer
|
||||
from .resource_protector import ResourceProtector, current_credential
|
||||
from .cache import (
|
||||
register_nonce_hooks,
|
||||
register_temporary_credential_hooks,
|
||||
create_exists_nonce_func,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,182 @@
|
||||
import logging
|
||||
from werkzeug.utils import import_string
|
||||
from flask import Response
|
||||
from flask import request as flask_req
|
||||
from authlib.oauth1 import (
|
||||
OAuth1Request,
|
||||
AuthorizationServer as _AuthorizationServer,
|
||||
)
|
||||
from authlib.common.security import generate_token
|
||||
from authlib.common.urls import url_encode
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationServer(_AuthorizationServer):
|
||||
"""Flask implementation of :class:`authlib.rfc5849.AuthorizationServer`.
|
||||
Initialize it with Flask app instance, client model class and cache::
|
||||
|
||||
server = AuthorizationServer(app=app, query_client=query_client)
|
||||
# or initialize lazily
|
||||
server = AuthorizationServer()
|
||||
server.init_app(app, query_client=query_client)
|
||||
|
||||
:param app: A Flask app instance
|
||||
:param query_client: A function to get client by client_id. The client
|
||||
model class MUST implement the methods described by
|
||||
:class:`~authlib.oauth1.rfc5849.ClientMixin`.
|
||||
:param token_generator: A function to generate token
|
||||
"""
|
||||
|
||||
def __init__(self, app=None, query_client=None, token_generator=None):
|
||||
self.app = app
|
||||
self.query_client = query_client
|
||||
self.token_generator = token_generator
|
||||
|
||||
self._hooks = {
|
||||
'exists_nonce': None,
|
||||
'create_temporary_credential': None,
|
||||
'get_temporary_credential': None,
|
||||
'delete_temporary_credential': None,
|
||||
'create_authorization_verifier': None,
|
||||
'create_token_credential': None,
|
||||
}
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app, query_client=None, token_generator=None):
|
||||
if query_client is not None:
|
||||
self.query_client = query_client
|
||||
if token_generator is not None:
|
||||
self.token_generator = token_generator
|
||||
|
||||
if self.token_generator is None:
|
||||
self.token_generator = self.create_token_generator(app)
|
||||
|
||||
methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS')
|
||||
if methods and isinstance(methods, (list, tuple)):
|
||||
self.SUPPORTED_SIGNATURE_METHODS = methods
|
||||
|
||||
self.app = app
|
||||
|
||||
def register_hook(self, name, func):
|
||||
if name not in self._hooks:
|
||||
raise ValueError('Invalid "name" of hook')
|
||||
self._hooks[name] = func
|
||||
|
||||
def create_token_generator(self, app):
|
||||
token_generator = app.config.get('OAUTH1_TOKEN_GENERATOR')
|
||||
|
||||
if isinstance(token_generator, str):
|
||||
token_generator = import_string(token_generator)
|
||||
else:
|
||||
length = app.config.get('OAUTH1_TOKEN_LENGTH', 42)
|
||||
|
||||
def token_generator():
|
||||
return generate_token(length)
|
||||
|
||||
secret_generator = app.config.get('OAUTH1_TOKEN_SECRET_GENERATOR')
|
||||
if isinstance(secret_generator, str):
|
||||
secret_generator = import_string(secret_generator)
|
||||
else:
|
||||
length = app.config.get('OAUTH1_TOKEN_SECRET_LENGTH', 48)
|
||||
|
||||
def secret_generator():
|
||||
return generate_token(length)
|
||||
|
||||
def create_token():
|
||||
return {
|
||||
'oauth_token': token_generator(),
|
||||
'oauth_token_secret': secret_generator()
|
||||
}
|
||||
return create_token
|
||||
|
||||
def get_client_by_id(self, client_id):
|
||||
return self.query_client(client_id)
|
||||
|
||||
def exists_nonce(self, nonce, request):
|
||||
func = self._hooks['exists_nonce']
|
||||
if callable(func):
|
||||
timestamp = request.timestamp
|
||||
client_id = request.client_id
|
||||
token = request.token
|
||||
return func(nonce, timestamp, client_id, token)
|
||||
|
||||
raise RuntimeError('"exists_nonce" hook is required.')
|
||||
|
||||
def create_temporary_credential(self, request):
|
||||
func = self._hooks['create_temporary_credential']
|
||||
if callable(func):
|
||||
token = self.token_generator()
|
||||
return func(token, request.client_id, request.redirect_uri)
|
||||
raise RuntimeError(
|
||||
'"create_temporary_credential" hook is required.'
|
||||
)
|
||||
|
||||
def get_temporary_credential(self, request):
|
||||
func = self._hooks['get_temporary_credential']
|
||||
if callable(func):
|
||||
return func(request.token)
|
||||
|
||||
raise RuntimeError(
|
||||
'"get_temporary_credential" hook is required.'
|
||||
)
|
||||
|
||||
def delete_temporary_credential(self, request):
|
||||
func = self._hooks['delete_temporary_credential']
|
||||
if callable(func):
|
||||
return func(request.token)
|
||||
|
||||
raise RuntimeError(
|
||||
'"delete_temporary_credential" hook is required.'
|
||||
)
|
||||
|
||||
def create_authorization_verifier(self, request):
|
||||
func = self._hooks['create_authorization_verifier']
|
||||
if callable(func):
|
||||
verifier = generate_token(36)
|
||||
func(request.credential, request.user, verifier)
|
||||
return verifier
|
||||
|
||||
raise RuntimeError(
|
||||
'"create_authorization_verifier" hook is required.'
|
||||
)
|
||||
|
||||
def create_token_credential(self, request):
|
||||
func = self._hooks['create_token_credential']
|
||||
if callable(func):
|
||||
temporary_credential = request.credential
|
||||
token = self.token_generator()
|
||||
return func(token, temporary_credential)
|
||||
|
||||
raise RuntimeError(
|
||||
'"create_token_credential" hook is required.'
|
||||
)
|
||||
|
||||
def check_authorization_request(self):
|
||||
req = self.create_oauth1_request(None)
|
||||
self.validate_authorization_request(req)
|
||||
return req
|
||||
|
||||
def create_authorization_response(self, request=None, grant_user=None):
|
||||
return super()\
|
||||
.create_authorization_response(request, grant_user)
|
||||
|
||||
def create_token_response(self, request=None):
|
||||
return super().create_token_response(request)
|
||||
|
||||
def create_oauth1_request(self, request):
|
||||
if request is None:
|
||||
request = flask_req
|
||||
if request.method in ('POST', 'PUT'):
|
||||
body = request.form.to_dict(flat=True)
|
||||
else:
|
||||
body = None
|
||||
return OAuth1Request(request.method, request.url, body, request.headers)
|
||||
|
||||
def handle_response(self, status_code, payload, headers):
|
||||
return Response(
|
||||
url_encode(payload),
|
||||
status=status_code,
|
||||
headers=headers
|
||||
)
|
||||
@@ -0,0 +1,80 @@
|
||||
from authlib.oauth1 import TemporaryCredential
|
||||
|
||||
|
||||
def register_temporary_credential_hooks(
|
||||
authorization_server, cache, key_prefix='temporary_credential:'):
|
||||
"""Register temporary credential related hooks to authorization server.
|
||||
|
||||
:param authorization_server: AuthorizationServer instance
|
||||
:param cache: Cache instance
|
||||
:param key_prefix: key prefix for temporary credential
|
||||
"""
|
||||
|
||||
def create_temporary_credential(token, client_id, redirect_uri):
|
||||
key = key_prefix + token['oauth_token']
|
||||
token['client_id'] = client_id
|
||||
if redirect_uri:
|
||||
token['oauth_callback'] = redirect_uri
|
||||
|
||||
cache.set(key, token, timeout=86400) # cache for one day
|
||||
return TemporaryCredential(token)
|
||||
|
||||
def get_temporary_credential(oauth_token):
|
||||
if not oauth_token:
|
||||
return None
|
||||
key = key_prefix + oauth_token
|
||||
value = cache.get(key)
|
||||
if value:
|
||||
return TemporaryCredential(value)
|
||||
|
||||
def delete_temporary_credential(oauth_token):
|
||||
if oauth_token:
|
||||
key = key_prefix + oauth_token
|
||||
cache.delete(key)
|
||||
|
||||
def create_authorization_verifier(credential, grant_user, verifier):
|
||||
key = key_prefix + credential.get_oauth_token()
|
||||
credential['oauth_verifier'] = verifier
|
||||
credential['user_id'] = grant_user.get_user_id()
|
||||
cache.set(key, credential, timeout=86400)
|
||||
return credential
|
||||
|
||||
authorization_server.register_hook(
|
||||
'create_temporary_credential', create_temporary_credential)
|
||||
authorization_server.register_hook(
|
||||
'get_temporary_credential', get_temporary_credential)
|
||||
authorization_server.register_hook(
|
||||
'delete_temporary_credential', delete_temporary_credential)
|
||||
authorization_server.register_hook(
|
||||
'create_authorization_verifier', create_authorization_verifier)
|
||||
|
||||
|
||||
def create_exists_nonce_func(cache, key_prefix='nonce:', expires=86400):
|
||||
"""Create an ``exists_nonce`` function that can be used in hooks and
|
||||
resource protector.
|
||||
|
||||
:param cache: Cache instance
|
||||
:param key_prefix: key prefix for temporary credential
|
||||
:param expires: Expire time for nonce
|
||||
"""
|
||||
def exists_nonce(nonce, timestamp, client_id, oauth_token):
|
||||
key = f'{key_prefix}{nonce}-{timestamp}-{client_id}'
|
||||
if oauth_token:
|
||||
key = f'{key}-{oauth_token}'
|
||||
rv = cache.has(key)
|
||||
cache.set(key, 1, timeout=expires)
|
||||
return rv
|
||||
return exists_nonce
|
||||
|
||||
|
||||
def register_nonce_hooks(
|
||||
authorization_server, cache, key_prefix='nonce:', expires=86400):
|
||||
"""Register nonce related hooks to authorization server.
|
||||
|
||||
:param authorization_server: AuthorizationServer instance
|
||||
:param cache: Cache instance
|
||||
:param key_prefix: key prefix for temporary credential
|
||||
:param expires: Expire time for nonce
|
||||
"""
|
||||
exists_nonce = create_exists_nonce_func(cache, key_prefix, expires)
|
||||
authorization_server.register_hook('exists_nonce', exists_nonce)
|
||||
@@ -0,0 +1,113 @@
|
||||
import functools
|
||||
from flask import g, json, Response
|
||||
from flask import request as _req
|
||||
from werkzeug.local import LocalProxy
|
||||
from authlib.consts import default_json_headers
|
||||
from authlib.oauth1 import ResourceProtector as _ResourceProtector
|
||||
from authlib.oauth1.errors import OAuth1Error
|
||||
|
||||
|
||||
class ResourceProtector(_ResourceProtector):
|
||||
"""A protecting method for resource servers. Initialize a resource
|
||||
protector with the these method:
|
||||
|
||||
1. query_client
|
||||
2. query_token,
|
||||
3. exists_nonce
|
||||
|
||||
Usually, a ``query_client`` method would look like (if using SQLAlchemy)::
|
||||
|
||||
def query_client(client_id):
|
||||
return Client.query.filter_by(client_id=client_id).first()
|
||||
|
||||
A ``query_token`` method accept two parameters, ``client_id`` and ``oauth_token``::
|
||||
|
||||
def query_token(client_id, oauth_token):
|
||||
return Token.query.filter_by(client_id=client_id, oauth_token=oauth_token).first()
|
||||
|
||||
And for ``exists_nonce``, if using cache, we have a built-in hook to create this method::
|
||||
|
||||
from authlib.integrations.flask_oauth1 import create_exists_nonce_func
|
||||
|
||||
exists_nonce = create_exists_nonce_func(cache)
|
||||
|
||||
Then initialize the resource protector with those methods::
|
||||
|
||||
require_oauth = ResourceProtector(
|
||||
app, query_client=query_client,
|
||||
query_token=query_token, exists_nonce=exists_nonce,
|
||||
)
|
||||
"""
|
||||
def __init__(self, app=None, query_client=None,
|
||||
query_token=None, exists_nonce=None):
|
||||
self.query_client = query_client
|
||||
self.query_token = query_token
|
||||
self._exists_nonce = exists_nonce
|
||||
|
||||
self.app = app
|
||||
if app:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app, query_client=None, query_token=None,
|
||||
exists_nonce=None):
|
||||
if query_client is not None:
|
||||
self.query_client = query_client
|
||||
if query_token is not None:
|
||||
self.query_token = query_token
|
||||
if exists_nonce is not None:
|
||||
self._exists_nonce = exists_nonce
|
||||
|
||||
methods = app.config.get('OAUTH1_SUPPORTED_SIGNATURE_METHODS')
|
||||
if methods and isinstance(methods, (list, tuple)):
|
||||
self.SUPPORTED_SIGNATURE_METHODS = methods
|
||||
|
||||
self.app = app
|
||||
|
||||
def get_client_by_id(self, client_id):
|
||||
return self.query_client(client_id)
|
||||
|
||||
def get_token_credential(self, request):
|
||||
return self.query_token(request.client_id, request.token)
|
||||
|
||||
def exists_nonce(self, nonce, request):
|
||||
if not self._exists_nonce:
|
||||
raise RuntimeError('"exists_nonce" function is required.')
|
||||
|
||||
timestamp = request.timestamp
|
||||
client_id = request.client_id
|
||||
token = request.token
|
||||
return self._exists_nonce(nonce, timestamp, client_id, token)
|
||||
|
||||
def acquire_credential(self):
|
||||
req = self.validate_request(
|
||||
_req.method,
|
||||
_req.url,
|
||||
_req.form.to_dict(flat=True),
|
||||
_req.headers
|
||||
)
|
||||
g.authlib_server_oauth1_credential = req.credential
|
||||
return req.credential
|
||||
|
||||
def __call__(self, scope=None):
|
||||
def wrapper(f):
|
||||
@functools.wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
try:
|
||||
self.acquire_credential()
|
||||
except OAuth1Error as error:
|
||||
body = dict(error.get_body())
|
||||
return Response(
|
||||
json.dumps(body),
|
||||
status=error.status_code,
|
||||
headers=default_json_headers,
|
||||
)
|
||||
return f(*args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get_current_credential():
|
||||
return g.get('authlib_server_oauth1_credential')
|
||||
|
||||
|
||||
current_credential = LocalProxy(_get_current_credential)
|
||||
@@ -0,0 +1,12 @@
|
||||
# flake8: noqa
|
||||
|
||||
from .authorization_server import AuthorizationServer
|
||||
from .resource_protector import (
|
||||
ResourceProtector,
|
||||
current_token,
|
||||
)
|
||||
from .signals import (
|
||||
client_authenticated,
|
||||
token_authenticated,
|
||||
token_revoked,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,159 @@
|
||||
from werkzeug.utils import import_string
|
||||
from flask import Response, json
|
||||
from flask import request as flask_req
|
||||
from authlib.oauth2 import (
|
||||
AuthorizationServer as _AuthorizationServer,
|
||||
)
|
||||
from authlib.oauth2.rfc6750 import BearerTokenGenerator
|
||||
from authlib.common.security import generate_token
|
||||
from .requests import FlaskOAuth2Request, FlaskJsonRequest
|
||||
from .signals import client_authenticated, token_revoked
|
||||
|
||||
|
||||
class AuthorizationServer(_AuthorizationServer):
|
||||
"""Flask implementation of :class:`authlib.oauth2.rfc6749.AuthorizationServer`.
|
||||
Initialize it with ``query_client``, ``save_token`` methods and Flask
|
||||
app instance::
|
||||
|
||||
def query_client(client_id):
|
||||
return Client.query.filter_by(client_id=client_id).first()
|
||||
|
||||
def save_token(token, request):
|
||||
if request.user:
|
||||
user_id = request.user.id
|
||||
else:
|
||||
user_id = None
|
||||
client = request.client
|
||||
tok = Token(
|
||||
client_id=client.client_id,
|
||||
user_id=user.id,
|
||||
**token
|
||||
)
|
||||
db.session.add(tok)
|
||||
db.session.commit()
|
||||
|
||||
server = AuthorizationServer(app, query_client, save_token)
|
||||
# or initialize lazily
|
||||
server = AuthorizationServer()
|
||||
server.init_app(app, query_client, save_token)
|
||||
"""
|
||||
|
||||
def __init__(self, app=None, query_client=None, save_token=None):
|
||||
super().__init__()
|
||||
self._query_client = query_client
|
||||
self._save_token = save_token
|
||||
self._error_uris = None
|
||||
if app is not None:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app, query_client=None, save_token=None):
|
||||
"""Initialize later with Flask app instance."""
|
||||
if query_client is not None:
|
||||
self._query_client = query_client
|
||||
if save_token is not None:
|
||||
self._save_token = save_token
|
||||
|
||||
self.register_token_generator('default', self.create_bearer_token_generator(app.config))
|
||||
self.scopes_supported = app.config.get('OAUTH2_SCOPES_SUPPORTED')
|
||||
self._error_uris = app.config.get('OAUTH2_ERROR_URIS')
|
||||
|
||||
def query_client(self, client_id):
|
||||
return self._query_client(client_id)
|
||||
|
||||
def save_token(self, token, request):
|
||||
return self._save_token(token, request)
|
||||
|
||||
def get_error_uri(self, request, error):
|
||||
if self._error_uris:
|
||||
uris = dict(self._error_uris)
|
||||
return uris.get(error.error)
|
||||
|
||||
def create_oauth2_request(self, request):
|
||||
return FlaskOAuth2Request(flask_req)
|
||||
|
||||
def create_json_request(self, request):
|
||||
return FlaskJsonRequest(flask_req)
|
||||
|
||||
def handle_response(self, status_code, payload, headers):
|
||||
if isinstance(payload, dict):
|
||||
payload = json.dumps(payload)
|
||||
return Response(payload, status=status_code, headers=headers)
|
||||
|
||||
def send_signal(self, name, *args, **kwargs):
|
||||
if name == 'after_authenticate_client':
|
||||
client_authenticated.send(self, *args, **kwargs)
|
||||
elif name == 'after_revoke_token':
|
||||
token_revoked.send(self, *args, **kwargs)
|
||||
|
||||
def create_bearer_token_generator(self, config):
|
||||
"""Create a generator function for generating ``token`` value. This
|
||||
method will create a Bearer Token generator with
|
||||
:class:`authlib.oauth2.rfc6750.BearerToken`.
|
||||
|
||||
Configurable settings:
|
||||
|
||||
1. OAUTH2_ACCESS_TOKEN_GENERATOR: Boolean or import string, default is True.
|
||||
2. OAUTH2_REFRESH_TOKEN_GENERATOR: Boolean or import string, default is False.
|
||||
3. OAUTH2_TOKEN_EXPIRES_IN: Dict or import string, default is None.
|
||||
|
||||
By default, it will not generate ``refresh_token``, which can be turn on by
|
||||
configure ``OAUTH2_REFRESH_TOKEN_GENERATOR``.
|
||||
|
||||
Here are some examples of the token generator::
|
||||
|
||||
OAUTH2_ACCESS_TOKEN_GENERATOR = 'your_project.generators.gen_token'
|
||||
|
||||
# and in module `your_project.generators`, you can define:
|
||||
|
||||
def gen_token(client, grant_type, user, scope):
|
||||
# generate token according to these parameters
|
||||
token = create_random_token()
|
||||
return f'{client.id}-{user.id}-{token}'
|
||||
|
||||
Here is an example of ``OAUTH2_TOKEN_EXPIRES_IN``::
|
||||
|
||||
OAUTH2_TOKEN_EXPIRES_IN = {
|
||||
'authorization_code': 864000,
|
||||
'urn:ietf:params:oauth:grant-type:jwt-bearer': 3600,
|
||||
}
|
||||
"""
|
||||
conf = config.get('OAUTH2_ACCESS_TOKEN_GENERATOR', True)
|
||||
access_token_generator = create_token_generator(conf, 42)
|
||||
|
||||
conf = config.get('OAUTH2_REFRESH_TOKEN_GENERATOR', False)
|
||||
refresh_token_generator = create_token_generator(conf, 48)
|
||||
|
||||
expires_conf = config.get('OAUTH2_TOKEN_EXPIRES_IN')
|
||||
expires_generator = create_token_expires_in_generator(expires_conf)
|
||||
return BearerTokenGenerator(
|
||||
access_token_generator,
|
||||
refresh_token_generator,
|
||||
expires_generator
|
||||
)
|
||||
|
||||
|
||||
def create_token_expires_in_generator(expires_in_conf=None):
|
||||
if isinstance(expires_in_conf, str):
|
||||
return import_string(expires_in_conf)
|
||||
|
||||
data = {}
|
||||
data.update(BearerTokenGenerator.GRANT_TYPES_EXPIRES_IN)
|
||||
if isinstance(expires_in_conf, dict):
|
||||
data.update(expires_in_conf)
|
||||
|
||||
def expires_in(client, grant_type):
|
||||
return data.get(grant_type, BearerTokenGenerator.DEFAULT_EXPIRES_IN)
|
||||
|
||||
return expires_in
|
||||
|
||||
|
||||
def create_token_generator(token_generator_conf, length=42):
|
||||
if callable(token_generator_conf):
|
||||
return token_generator_conf
|
||||
|
||||
if isinstance(token_generator_conf, str):
|
||||
return import_string(token_generator_conf)
|
||||
elif token_generator_conf is True:
|
||||
def token_generator(*args, **kwargs):
|
||||
return generate_token(length)
|
||||
return token_generator
|
||||
@@ -0,0 +1,39 @@
|
||||
import importlib.metadata
|
||||
|
||||
import werkzeug
|
||||
from werkzeug.exceptions import HTTPException
|
||||
|
||||
_version = importlib.metadata.version('werkzeug').split('.')[0]
|
||||
|
||||
if _version in ('0', '1'):
|
||||
class _HTTPException(HTTPException):
|
||||
def __init__(self, code, body, headers, response=None):
|
||||
super().__init__(None, response)
|
||||
self.code = code
|
||||
|
||||
self.body = body
|
||||
self.headers = headers
|
||||
|
||||
def get_body(self, environ=None):
|
||||
return self.body
|
||||
|
||||
def get_headers(self, environ=None):
|
||||
return self.headers
|
||||
else:
|
||||
class _HTTPException(HTTPException):
|
||||
def __init__(self, code, body, headers, response=None):
|
||||
super().__init__(None, response)
|
||||
self.code = code
|
||||
|
||||
self.body = body
|
||||
self.headers = headers
|
||||
|
||||
def get_body(self, environ=None, scope=None):
|
||||
return self.body
|
||||
|
||||
def get_headers(self, environ=None, scope=None):
|
||||
return self.headers
|
||||
|
||||
|
||||
def raise_http_exception(status, body, headers):
|
||||
raise _HTTPException(status, body, headers)
|
||||
@@ -0,0 +1,40 @@
|
||||
from collections import defaultdict
|
||||
from functools import cached_property
|
||||
|
||||
from flask.wrappers import Request
|
||||
from authlib.oauth2.rfc6749 import OAuth2Request, JsonRequest
|
||||
|
||||
|
||||
class FlaskOAuth2Request(OAuth2Request):
|
||||
def __init__(self, request: Request):
|
||||
super().__init__(request.method, request.url, None, request.headers)
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def args(self):
|
||||
return self._request.args
|
||||
|
||||
@property
|
||||
def form(self):
|
||||
return self._request.form
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._request.values
|
||||
|
||||
@cached_property
|
||||
def datalist(self):
|
||||
values = defaultdict(list)
|
||||
for k in self.data:
|
||||
values[k].extend(self.data.getlist(k))
|
||||
return values
|
||||
|
||||
|
||||
class FlaskJsonRequest(JsonRequest):
|
||||
def __init__(self, request: Request):
|
||||
super().__init__(request.method, request.url, None, request.headers)
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._request.get_json()
|
||||
@@ -0,0 +1,114 @@
|
||||
import functools
|
||||
from contextlib import contextmanager
|
||||
from flask import g, json
|
||||
from flask import request as _req
|
||||
from werkzeug.local import LocalProxy
|
||||
from authlib.oauth2 import (
|
||||
OAuth2Error,
|
||||
ResourceProtector as _ResourceProtector
|
||||
)
|
||||
from authlib.oauth2.rfc6749 import (
|
||||
MissingAuthorizationError,
|
||||
)
|
||||
from .requests import FlaskJsonRequest
|
||||
from .signals import token_authenticated
|
||||
from .errors import raise_http_exception
|
||||
|
||||
|
||||
class ResourceProtector(_ResourceProtector):
|
||||
"""A protecting method for resource servers. Creating a ``require_oauth``
|
||||
decorator easily with ResourceProtector::
|
||||
|
||||
from authlib.integrations.flask_oauth2 import ResourceProtector
|
||||
|
||||
require_oauth = ResourceProtector()
|
||||
|
||||
# add bearer token validator
|
||||
from authlib.oauth2.rfc6750 import BearerTokenValidator
|
||||
from project.models import Token
|
||||
|
||||
class MyBearerTokenValidator(BearerTokenValidator):
|
||||
def authenticate_token(self, token_string):
|
||||
return Token.query.filter_by(access_token=token_string).first()
|
||||
|
||||
require_oauth.register_token_validator(MyBearerTokenValidator())
|
||||
|
||||
# protect resource with require_oauth
|
||||
|
||||
@app.route('/user')
|
||||
@require_oauth(['profile'])
|
||||
def user_profile():
|
||||
user = User.get(current_token.user_id)
|
||||
return jsonify(user.to_dict())
|
||||
|
||||
"""
|
||||
def raise_error_response(self, error):
|
||||
"""Raise HTTPException for OAuth2Error. Developers can re-implement
|
||||
this method to customize the error response.
|
||||
|
||||
:param error: OAuth2Error
|
||||
:raise: HTTPException
|
||||
"""
|
||||
status = error.status_code
|
||||
body = json.dumps(dict(error.get_body()))
|
||||
headers = error.get_headers()
|
||||
raise_http_exception(status, body, headers)
|
||||
|
||||
def acquire_token(self, scopes=None, **kwargs):
|
||||
"""A method to acquire current valid token with the given scope.
|
||||
|
||||
:param scopes: a list of scope values
|
||||
:return: token object
|
||||
"""
|
||||
request = FlaskJsonRequest(_req)
|
||||
# backward compatibility
|
||||
kwargs['scopes'] = scopes
|
||||
for claim in kwargs:
|
||||
if isinstance(kwargs[claim], str):
|
||||
kwargs[claim] = [kwargs[claim]]
|
||||
token = self.validate_request(request=request, **kwargs)
|
||||
token_authenticated.send(self, token=token)
|
||||
g.authlib_server_oauth2_token = token
|
||||
return token
|
||||
|
||||
@contextmanager
|
||||
def acquire(self, scopes=None):
|
||||
"""The with statement of ``require_oauth``. Instead of using a
|
||||
decorator, you can use a with statement instead::
|
||||
|
||||
@app.route('/api/user')
|
||||
def user_api():
|
||||
with require_oauth.acquire('profile') as token:
|
||||
user = User.get(token.user_id)
|
||||
return jsonify(user.to_dict())
|
||||
"""
|
||||
try:
|
||||
yield self.acquire_token(scopes)
|
||||
except OAuth2Error as error:
|
||||
self.raise_error_response(error)
|
||||
|
||||
def __call__(self, scopes=None, optional=False, **kwargs):
|
||||
claims = kwargs
|
||||
# backward compatibility
|
||||
claims['scopes'] = scopes
|
||||
def wrapper(f):
|
||||
@functools.wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
try:
|
||||
self.acquire_token(**claims)
|
||||
except MissingAuthorizationError as error:
|
||||
if optional:
|
||||
return f(*args, **kwargs)
|
||||
self.raise_error_response(error)
|
||||
except OAuth2Error as error:
|
||||
self.raise_error_response(error)
|
||||
return f(*args, **kwargs)
|
||||
return decorated
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get_current_token():
|
||||
return g.get('authlib_server_oauth2_token')
|
||||
|
||||
|
||||
current_token = LocalProxy(_get_current_token)
|
||||
@@ -0,0 +1,12 @@
|
||||
from flask.signals import Namespace
|
||||
|
||||
_signal = Namespace()
|
||||
|
||||
#: signal when client is authenticated
|
||||
client_authenticated = _signal.signal('client_authenticated')
|
||||
|
||||
#: signal when token is revoked
|
||||
token_revoked = _signal.signal('token_revoked')
|
||||
|
||||
#: signal when token is authenticated
|
||||
token_authenticated = _signal.signal('token_authenticated')
|
||||
@@ -0,0 +1,25 @@
|
||||
from authlib.oauth1 import (
|
||||
SIGNATURE_HMAC_SHA1,
|
||||
SIGNATURE_RSA_SHA1,
|
||||
SIGNATURE_PLAINTEXT,
|
||||
SIGNATURE_TYPE_HEADER,
|
||||
SIGNATURE_TYPE_QUERY,
|
||||
SIGNATURE_TYPE_BODY,
|
||||
)
|
||||
from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client
|
||||
from .oauth2_client import (
|
||||
OAuth2Auth, OAuth2Client, OAuth2ClientAuth,
|
||||
AsyncOAuth2Client,
|
||||
)
|
||||
from .assertion_client import AssertionClient, AsyncAssertionClient
|
||||
from ..base_client import OAuthError
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuthError',
|
||||
'OAuth1Auth', 'AsyncOAuth1Client',
|
||||
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
|
||||
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
|
||||
'OAuth2Auth', 'OAuth2ClientAuth', 'OAuth2Client', 'AsyncOAuth2Client',
|
||||
'AssertionClient', 'AsyncAssertionClient',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,81 @@
|
||||
import httpx
|
||||
from httpx import Response, USE_CLIENT_DEFAULT
|
||||
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
|
||||
from authlib.oauth2.rfc7523 import JWTBearerGrant
|
||||
from .utils import extract_client_kwargs
|
||||
from .oauth2_client import OAuth2Auth
|
||||
from ..base_client import OAuthError
|
||||
|
||||
__all__ = ['AsyncAssertionClient']
|
||||
|
||||
|
||||
class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient):
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
|
||||
ASSERTION_METHODS = {
|
||||
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
|
||||
}
|
||||
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
|
||||
|
||||
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
|
||||
claims=None, token_placement='header', scope=None, **kwargs):
|
||||
|
||||
client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.AsyncClient.__init__(self, **client_kwargs)
|
||||
|
||||
_AssertionClient.__init__(
|
||||
self, session=None,
|
||||
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
|
||||
audience=audience, grant_type=grant_type, claims=claims,
|
||||
token_placement=token_placement, scope=scope, **kwargs
|
||||
)
|
||||
|
||||
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs) -> Response:
|
||||
"""Send request with auto refresh token feature."""
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token or self.token.is_expired():
|
||||
await self.refresh_token()
|
||||
|
||||
auth = self.token_auth
|
||||
return await super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
|
||||
async def _refresh_token(self, data):
|
||||
resp = await self.request(
|
||||
'POST', self.token_endpoint, data=data, withhold_token=True)
|
||||
|
||||
return self.parse_response_token(resp)
|
||||
|
||||
|
||||
class AssertionClient(_AssertionClient, httpx.Client):
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
|
||||
ASSERTION_METHODS = {
|
||||
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
|
||||
}
|
||||
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
|
||||
|
||||
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
|
||||
claims=None, token_placement='header', scope=None, **kwargs):
|
||||
|
||||
client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.Client.__init__(self, **client_kwargs)
|
||||
|
||||
_AssertionClient.__init__(
|
||||
self, session=self,
|
||||
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
|
||||
audience=audience, grant_type=grant_type, claims=claims,
|
||||
token_placement=token_placement, scope=scope, **kwargs
|
||||
)
|
||||
|
||||
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
"""Send request with auto refresh token feature."""
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token or self.token.is_expired():
|
||||
self.refresh_token()
|
||||
|
||||
auth = self.token_auth
|
||||
return super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
@@ -0,0 +1,102 @@
|
||||
import typing
|
||||
import httpx
|
||||
from httpx import Auth, Request, Response
|
||||
from authlib.oauth1 import (
|
||||
SIGNATURE_HMAC_SHA1,
|
||||
SIGNATURE_TYPE_HEADER,
|
||||
)
|
||||
from authlib.common.encoding import to_unicode
|
||||
from authlib.oauth1 import ClientAuth
|
||||
from authlib.oauth1.client import OAuth1Client as _OAuth1Client
|
||||
from .utils import build_request, extract_client_kwargs
|
||||
from ..base_client import OAuthError
|
||||
|
||||
|
||||
class OAuth1Auth(Auth, ClientAuth):
|
||||
"""Signs the httpx request using OAuth 1 (RFC5849)"""
|
||||
requires_request_body = True
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
url, headers, body = self.prepare(
|
||||
request.method, str(request.url), request.headers, request.content)
|
||||
headers['Content-Length'] = str(len(body))
|
||||
yield build_request(url=url, headers=headers, body=body, initial_request=request)
|
||||
|
||||
|
||||
class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient):
|
||||
auth_class = OAuth1Auth
|
||||
|
||||
def __init__(self, client_id, client_secret=None,
|
||||
token=None, token_secret=None,
|
||||
redirect_uri=None, rsa_key=None, verifier=None,
|
||||
signature_method=SIGNATURE_HMAC_SHA1,
|
||||
signature_type=SIGNATURE_TYPE_HEADER,
|
||||
force_include_body=False, **kwargs):
|
||||
|
||||
_client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.AsyncClient.__init__(self, **_client_kwargs)
|
||||
|
||||
_OAuth1Client.__init__(
|
||||
self, None,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token=token, token_secret=token_secret,
|
||||
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
|
||||
signature_method=signature_method, signature_type=signature_type,
|
||||
force_include_body=force_include_body, **kwargs)
|
||||
|
||||
async def fetch_access_token(self, url, verifier=None, **kwargs):
|
||||
"""Method for fetching an access token from the token endpoint.
|
||||
|
||||
This is the final step in the OAuth 1 workflow. An access token is
|
||||
obtained using all previously obtained credentials, including the
|
||||
verifier from the authorization step.
|
||||
|
||||
:param url: Access Token endpoint.
|
||||
:param verifier: A verifier string to prove authorization was granted.
|
||||
:param kwargs: Extra parameters to include for fetching access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
if verifier:
|
||||
self.auth.verifier = verifier
|
||||
if not self.auth.verifier:
|
||||
self.handle_error('missing_verifier', 'Missing "verifier" value')
|
||||
token = await self._fetch_token(url, **kwargs)
|
||||
self.auth.verifier = None
|
||||
return token
|
||||
|
||||
async def _fetch_token(self, url, **kwargs):
|
||||
resp = await self.post(url, **kwargs)
|
||||
text = await resp.aread()
|
||||
token = self.parse_response_token(resp.status_code, to_unicode(text))
|
||||
self.token = token
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def handle_error(error_type, error_description):
|
||||
raise OAuthError(error_type, error_description)
|
||||
|
||||
|
||||
class OAuth1Client(_OAuth1Client, httpx.Client):
|
||||
auth_class = OAuth1Auth
|
||||
|
||||
def __init__(self, client_id, client_secret=None,
|
||||
token=None, token_secret=None,
|
||||
redirect_uri=None, rsa_key=None, verifier=None,
|
||||
signature_method=SIGNATURE_HMAC_SHA1,
|
||||
signature_type=SIGNATURE_TYPE_HEADER,
|
||||
force_include_body=False, **kwargs):
|
||||
|
||||
_client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.Client.__init__(self, **_client_kwargs)
|
||||
|
||||
_OAuth1Client.__init__(
|
||||
self, self,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token=token, token_secret=token_secret,
|
||||
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
|
||||
signature_method=signature_method, signature_type=signature_type,
|
||||
force_include_body=force_include_body, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def handle_error(error_type, error_description):
|
||||
raise OAuthError(error_type, error_description)
|
||||
@@ -0,0 +1,220 @@
|
||||
import typing
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
from httpx import Auth, Request, Response, USE_CLIENT_DEFAULT
|
||||
from anyio import Lock # Import after httpx so import errors refer to httpx
|
||||
from authlib.common.urls import url_decode
|
||||
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
|
||||
from authlib.oauth2.auth import ClientAuth, TokenAuth
|
||||
from .utils import HTTPX_CLIENT_KWARGS, build_request
|
||||
from ..base_client import (
|
||||
OAuthError,
|
||||
InvalidTokenError,
|
||||
MissingTokenError,
|
||||
UnsupportedTokenTypeError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'OAuth2Auth', 'OAuth2ClientAuth',
|
||||
'AsyncOAuth2Client', 'OAuth2Client',
|
||||
]
|
||||
|
||||
|
||||
class OAuth2Auth(Auth, TokenAuth):
|
||||
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
|
||||
requires_request_body = True
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
try:
|
||||
url, headers, body = self.prepare(
|
||||
str(request.url), request.headers, request.content)
|
||||
headers['Content-Length'] = str(len(body))
|
||||
yield build_request(url=url, headers=headers, body=body, initial_request=request)
|
||||
except KeyError as error:
|
||||
description = f'Unsupported token_type: {str(error)}'
|
||||
raise UnsupportedTokenTypeError(description=description)
|
||||
|
||||
|
||||
class OAuth2ClientAuth(Auth, ClientAuth):
|
||||
requires_request_body = True
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
url, headers, body = self.prepare(
|
||||
request.method, str(request.url), request.headers, request.content)
|
||||
headers['Content-Length'] = str(len(body))
|
||||
yield build_request(url=url, headers=headers, body=body, initial_request=request)
|
||||
|
||||
|
||||
class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient):
|
||||
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
|
||||
|
||||
client_auth_class = OAuth2ClientAuth
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
|
||||
def __init__(self, client_id=None, client_secret=None,
|
||||
token_endpoint_auth_method=None,
|
||||
revocation_endpoint_auth_method=None,
|
||||
scope=None, redirect_uri=None,
|
||||
token=None, token_placement='header',
|
||||
update_token=None, leeway=60, **kwargs):
|
||||
|
||||
# extract httpx.Client kwargs
|
||||
client_kwargs = self._extract_session_request_params(kwargs)
|
||||
httpx.AsyncClient.__init__(self, **client_kwargs)
|
||||
|
||||
# We use a Lock to synchronize coroutines to prevent
|
||||
# multiple concurrent attempts to refresh the same token
|
||||
self._token_refresh_lock = Lock()
|
||||
|
||||
_OAuth2Client.__init__(
|
||||
self, session=None,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token_endpoint_auth_method=token_endpoint_auth_method,
|
||||
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
|
||||
scope=scope, redirect_uri=redirect_uri,
|
||||
token=token, token_placement=token_placement,
|
||||
update_token=update_token, leeway=leeway, **kwargs
|
||||
)
|
||||
|
||||
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
await self.ensure_active_token(self.token)
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
return await super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
await self.ensure_active_token(self.token)
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
async with super().stream(
|
||||
method, url, auth=auth, **kwargs) as resp:
|
||||
yield resp
|
||||
|
||||
async def ensure_active_token(self, token):
|
||||
async with self._token_refresh_lock:
|
||||
if self.token.is_expired(leeway=self.leeway):
|
||||
refresh_token = token.get('refresh_token')
|
||||
url = self.metadata.get('token_endpoint')
|
||||
if refresh_token and url:
|
||||
await self.refresh_token(url, refresh_token=refresh_token)
|
||||
elif self.metadata.get('grant_type') == 'client_credentials':
|
||||
access_token = token['access_token']
|
||||
new_token = await self.fetch_token(url, grant_type='client_credentials')
|
||||
if self.update_token:
|
||||
await self.update_token(new_token, access_token=access_token)
|
||||
else:
|
||||
raise InvalidTokenError()
|
||||
|
||||
async def _fetch_token(self, url, body='', headers=None, auth=USE_CLIENT_DEFAULT,
|
||||
method='POST', **kwargs):
|
||||
if method.upper() == 'POST':
|
||||
resp = await self.post(
|
||||
url, data=dict(url_decode(body)), headers=headers,
|
||||
auth=auth, **kwargs)
|
||||
else:
|
||||
if '?' in url:
|
||||
url = '&'.join([url, body])
|
||||
else:
|
||||
url = '?'.join([url, body])
|
||||
resp = await self.get(url, headers=headers, auth=auth, **kwargs)
|
||||
|
||||
for hook in self.compliance_hook['access_token_response']:
|
||||
resp = hook(resp)
|
||||
|
||||
return self.parse_response_token(resp)
|
||||
|
||||
async def _refresh_token(self, url, refresh_token=None, body='',
|
||||
headers=None, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
resp = await self.post(
|
||||
url, data=dict(url_decode(body)), headers=headers,
|
||||
auth=auth, **kwargs)
|
||||
|
||||
for hook in self.compliance_hook['refresh_token_response']:
|
||||
resp = hook(resp)
|
||||
|
||||
token = self.parse_response_token(resp)
|
||||
if 'refresh_token' not in token:
|
||||
self.token['refresh_token'] = refresh_token
|
||||
|
||||
if self.update_token:
|
||||
await self.update_token(self.token, refresh_token=refresh_token)
|
||||
|
||||
return self.token
|
||||
|
||||
def _http_post(self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs):
|
||||
return self.post(
|
||||
url, data=dict(url_decode(body)),
|
||||
headers=headers, auth=auth, **kwargs)
|
||||
|
||||
|
||||
class OAuth2Client(_OAuth2Client, httpx.Client):
|
||||
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
|
||||
|
||||
client_auth_class = OAuth2ClientAuth
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
|
||||
def __init__(self, client_id=None, client_secret=None,
|
||||
token_endpoint_auth_method=None,
|
||||
revocation_endpoint_auth_method=None,
|
||||
scope=None, redirect_uri=None,
|
||||
token=None, token_placement='header',
|
||||
update_token=None, **kwargs):
|
||||
|
||||
# extract httpx.Client kwargs
|
||||
client_kwargs = self._extract_session_request_params(kwargs)
|
||||
httpx.Client.__init__(self, **client_kwargs)
|
||||
|
||||
_OAuth2Client.__init__(
|
||||
self, session=self,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token_endpoint_auth_method=token_endpoint_auth_method,
|
||||
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
|
||||
scope=scope, redirect_uri=redirect_uri,
|
||||
token=token, token_placement=token_placement,
|
||||
update_token=update_token, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_error(error_type, error_description):
|
||||
raise OAuthError(error_type, error_description)
|
||||
|
||||
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
if not self.ensure_active_token(self.token):
|
||||
raise InvalidTokenError()
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
return super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
|
||||
def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
if not self.ensure_active_token(self.token):
|
||||
raise InvalidTokenError()
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
return super().stream(
|
||||
method, url, auth=auth, **kwargs)
|
||||
@@ -0,0 +1,30 @@
|
||||
from httpx import Request
|
||||
|
||||
HTTPX_CLIENT_KWARGS = [
|
||||
'headers', 'cookies', 'verify', 'cert', 'http1', 'http2',
|
||||
'proxies', 'timeout', 'follow_redirects', 'limits', 'max_redirects',
|
||||
'event_hooks', 'base_url', 'transport', 'app', 'trust_env',
|
||||
]
|
||||
|
||||
|
||||
def extract_client_kwargs(kwargs):
|
||||
client_kwargs = {}
|
||||
for k in HTTPX_CLIENT_KWARGS:
|
||||
if k in kwargs:
|
||||
client_kwargs[k] = kwargs.pop(k)
|
||||
return client_kwargs
|
||||
|
||||
|
||||
def build_request(url, headers, body, initial_request: Request) -> Request:
|
||||
"""Make sure that all the data from initial request is passed to the updated object"""
|
||||
updated_request = Request(
|
||||
method=initial_request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body
|
||||
)
|
||||
|
||||
if hasattr(initial_request, 'extensions'):
|
||||
updated_request.extensions = initial_request.extensions
|
||||
|
||||
return updated_request
|
||||
@@ -0,0 +1,22 @@
|
||||
from .oauth1_session import OAuth1Session, OAuth1Auth
|
||||
from .oauth2_session import OAuth2Session, OAuth2Auth
|
||||
from .assertion_session import AssertionSession
|
||||
from ..base_client import OAuthError
|
||||
from authlib.oauth1 import (
|
||||
SIGNATURE_HMAC_SHA1,
|
||||
SIGNATURE_RSA_SHA1,
|
||||
SIGNATURE_PLAINTEXT,
|
||||
SIGNATURE_TYPE_HEADER,
|
||||
SIGNATURE_TYPE_QUERY,
|
||||
SIGNATURE_TYPE_BODY,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuthError',
|
||||
'OAuth1Session', 'OAuth1Auth',
|
||||
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
|
||||
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
|
||||
'OAuth2Session', 'OAuth2Auth',
|
||||
'AssertionSession',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user