venv added, updated
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
from .registry import BaseOAuth
|
||||
from .sync_app import BaseApp, OAuth1Mixin, OAuth2Mixin
|
||||
from .sync_openid import OpenIDMixin
|
||||
from .framework_integration import FrameworkIntegration
|
||||
from .errors import (
|
||||
OAuthError, MissingRequestTokenError, MissingTokenError,
|
||||
TokenExpiredError, InvalidTokenError, UnsupportedTokenTypeError,
|
||||
MismatchingStateError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'BaseOAuth',
|
||||
'BaseApp', 'OAuth1Mixin', 'OAuth2Mixin',
|
||||
'OpenIDMixin', 'FrameworkIntegration',
|
||||
'OAuthError', 'MissingRequestTokenError', 'MissingTokenError',
|
||||
'TokenExpiredError', 'InvalidTokenError', 'UnsupportedTokenTypeError',
|
||||
'MismatchingStateError',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,144 @@
|
||||
import time
|
||||
import logging
|
||||
from authlib.common.urls import urlparse
|
||||
from .errors import (
|
||||
MissingRequestTokenError,
|
||||
MissingTokenError,
|
||||
)
|
||||
from .sync_app import OAuth1Base, OAuth2Base
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['AsyncOAuth1Mixin', 'AsyncOAuth2Mixin']
|
||||
|
||||
|
||||
class AsyncOAuth1Mixin(OAuth1Base):
|
||||
async def request(self, method, url, token=None, **kwargs):
|
||||
async with self._get_oauth_client() as session:
|
||||
return await _http_request(self, session, method, url, token, kwargs)
|
||||
|
||||
async def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
if not self.authorize_url:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
async with self._get_oauth_client() as client:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = {}
|
||||
if self.request_token_params:
|
||||
params.update(self.request_token_params)
|
||||
request_token = await client.fetch_request_token(self.request_token_url, **params)
|
||||
log.debug(f'Fetch request token: {request_token!r}')
|
||||
url = client.create_authorization_url(self.authorize_url, **kwargs)
|
||||
state = request_token['oauth_token']
|
||||
return {'url': url, 'request_token': request_token, 'state': state}
|
||||
|
||||
async def fetch_access_token(self, request_token=None, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:param request_token: A previous request token for OAuth 1.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
async with self._get_oauth_client() as client:
|
||||
if request_token is None:
|
||||
raise MissingRequestTokenError()
|
||||
# merge request token with verifier
|
||||
token = {}
|
||||
token.update(request_token)
|
||||
token.update(kwargs)
|
||||
client.token = token
|
||||
params = self.access_token_params or {}
|
||||
token = await client.fetch_access_token(self.access_token_url, **params)
|
||||
return token
|
||||
|
||||
|
||||
class AsyncOAuth2Mixin(OAuth2Base):
|
||||
async def _on_update_token(self, token, refresh_token=None, access_token=None):
|
||||
if self._update_token:
|
||||
await self._update_token(
|
||||
token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
async def load_server_metadata(self):
|
||||
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
|
||||
async with self.client_cls(**self.client_kwargs) as client:
|
||||
resp = await client.request('GET', self._server_metadata_url, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
metadata = resp.json()
|
||||
metadata['_loaded_at'] = time.time()
|
||||
self.server_metadata.update(metadata)
|
||||
return self.server_metadata
|
||||
|
||||
async def request(self, method, url, token=None, **kwargs):
|
||||
metadata = await self.load_server_metadata()
|
||||
async with self._get_oauth_client(**metadata) as session:
|
||||
return await _http_request(self, session, method, url, token, kwargs)
|
||||
|
||||
async def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
metadata = await self.load_server_metadata()
|
||||
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
|
||||
if not authorization_endpoint:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
async with self._get_oauth_client(**metadata) as client:
|
||||
client.redirect_uri = redirect_uri
|
||||
return self._create_oauth2_authorization_url(
|
||||
client, authorization_endpoint, **kwargs)
|
||||
|
||||
async def fetch_access_token(self, redirect_uri=None, **kwargs):
|
||||
"""Fetch access token in the final step.
|
||||
|
||||
:param redirect_uri: Callback or Redirect URI that is used in
|
||||
previous :meth:`authorize_redirect`.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
metadata = await self.load_server_metadata()
|
||||
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
|
||||
async with self._get_oauth_client(**metadata) as client:
|
||||
if redirect_uri is not None:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = {}
|
||||
if self.access_token_params:
|
||||
params.update(self.access_token_params)
|
||||
params.update(kwargs)
|
||||
token = await client.fetch_token(token_endpoint, **params)
|
||||
return token
|
||||
|
||||
|
||||
async def _http_request(ctx, session, method, url, token, kwargs):
|
||||
request = kwargs.pop('request', None)
|
||||
withhold_token = kwargs.get('withhold_token')
|
||||
if ctx.api_base_url and not url.startswith(('https://', 'http://')):
|
||||
url = urlparse.urljoin(ctx.api_base_url, url)
|
||||
|
||||
if withhold_token:
|
||||
return await session.request(method, url, **kwargs)
|
||||
|
||||
if token is None and ctx._fetch_token and request:
|
||||
token = await ctx._fetch_token(request)
|
||||
if token is None:
|
||||
raise MissingTokenError()
|
||||
|
||||
session.token = token
|
||||
return await session.request(method, url, **kwargs)
|
||||
@@ -0,0 +1,79 @@
|
||||
from authlib.jose import JsonWebToken, JsonWebKey
|
||||
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
|
||||
|
||||
__all__ = ['AsyncOpenIDMixin']
|
||||
|
||||
|
||||
class AsyncOpenIDMixin:
|
||||
async def fetch_jwk_set(self, force=False):
|
||||
metadata = await self.load_server_metadata()
|
||||
jwk_set = metadata.get('jwks')
|
||||
if jwk_set and not force:
|
||||
return jwk_set
|
||||
|
||||
uri = metadata.get('jwks_uri')
|
||||
if not uri:
|
||||
raise RuntimeError('Missing "jwks_uri" in metadata')
|
||||
|
||||
async with self.client_cls(**self.client_kwargs) as client:
|
||||
resp = await client.request('GET', uri, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
jwk_set = resp.json()
|
||||
|
||||
self.server_metadata['jwks'] = jwk_set
|
||||
return jwk_set
|
||||
|
||||
async def userinfo(self, **kwargs):
|
||||
"""Fetch user info from ``userinfo_endpoint``."""
|
||||
metadata = await self.load_server_metadata()
|
||||
resp = await self.get(metadata['userinfo_endpoint'], **kwargs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return UserInfo(data)
|
||||
|
||||
async def parse_id_token(self, token, nonce, claims_options=None):
|
||||
"""Return an instance of UserInfo from token's ``id_token``."""
|
||||
claims_params = dict(
|
||||
nonce=nonce,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
if 'access_token' in token:
|
||||
claims_params['access_token'] = token['access_token']
|
||||
claims_cls = CodeIDToken
|
||||
else:
|
||||
claims_cls = ImplicitIDToken
|
||||
|
||||
metadata = await self.load_server_metadata()
|
||||
if claims_options is None and 'issuer' in metadata:
|
||||
claims_options = {'iss': {'values': [metadata['issuer']]}}
|
||||
|
||||
alg_values = metadata.get('id_token_signing_alg_values_supported')
|
||||
if not alg_values:
|
||||
alg_values = ['RS256']
|
||||
|
||||
jwt = JsonWebToken(alg_values)
|
||||
|
||||
jwk_set = await self.fetch_jwk_set()
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
token['id_token'],
|
||||
key=JsonWebKey.import_key_set(jwk_set),
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
except ValueError:
|
||||
jwk_set = await self.fetch_jwk_set(force=True)
|
||||
claims = jwt.decode(
|
||||
token['id_token'],
|
||||
key=JsonWebKey.import_key_set(jwk_set),
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
|
||||
# https://github.com/lepture/authlib/issues/259
|
||||
if claims.get('nonce_supported') is False:
|
||||
claims.params['nonce'] = None
|
||||
claims.validate(leeway=120)
|
||||
return UserInfo(claims)
|
||||
@@ -0,0 +1,30 @@
|
||||
from authlib.common.errors import AuthlibBaseError
|
||||
|
||||
|
||||
class OAuthError(AuthlibBaseError):
|
||||
error = 'oauth_error'
|
||||
|
||||
|
||||
class MissingRequestTokenError(OAuthError):
|
||||
error = 'missing_request_token'
|
||||
|
||||
|
||||
class MissingTokenError(OAuthError):
|
||||
error = 'missing_token'
|
||||
|
||||
|
||||
class TokenExpiredError(OAuthError):
|
||||
error = 'token_expired'
|
||||
|
||||
|
||||
class InvalidTokenError(OAuthError):
|
||||
error = 'token_invalid'
|
||||
|
||||
|
||||
class UnsupportedTokenTypeError(OAuthError):
|
||||
error = 'unsupported_token_type'
|
||||
|
||||
|
||||
class MismatchingStateError(OAuthError):
|
||||
error = 'mismatching_state'
|
||||
description = 'CSRF Warning! State not equal in request and response.'
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
|
||||
class FrameworkIntegration:
|
||||
expires_in = 3600
|
||||
|
||||
def __init__(self, name, cache=None):
|
||||
self.name = name
|
||||
self.cache = cache
|
||||
|
||||
def _get_cache_data(self, key):
|
||||
value = self.cache.get(key)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def _clear_session_state(self, session):
|
||||
now = time.time()
|
||||
for key in dict(session):
|
||||
if '_authlib_' in key:
|
||||
# TODO: remove in future
|
||||
session.pop(key)
|
||||
elif key.startswith('_state_'):
|
||||
value = session[key]
|
||||
exp = value.get('exp')
|
||||
if not exp or exp < now:
|
||||
session.pop(key)
|
||||
|
||||
def get_state_data(self, session, state):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
value = self._get_cache_data(key)
|
||||
else:
|
||||
value = session.get(key)
|
||||
if value:
|
||||
return value.get('data')
|
||||
return None
|
||||
|
||||
def set_state_data(self, session, state, data):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
self.cache.set(key, json.dumps({'data': data}), self.expires_in)
|
||||
else:
|
||||
now = time.time()
|
||||
session[key] = {'data': data, 'exp': now + self.expires_in}
|
||||
|
||||
def clear_state_data(self, session, state):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
self.cache.delete(key)
|
||||
else:
|
||||
session.pop(key, None)
|
||||
self._clear_session_state(session)
|
||||
|
||||
def update_token(self, token, refresh_token=None, access_token=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def load_config(oauth, name, params):
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,131 @@
|
||||
import functools
|
||||
from .framework_integration import FrameworkIntegration
|
||||
|
||||
__all__ = ['BaseOAuth']
|
||||
|
||||
|
||||
OAUTH_CLIENT_PARAMS = (
|
||||
'client_id', 'client_secret',
|
||||
'request_token_url', 'request_token_params',
|
||||
'access_token_url', 'access_token_params',
|
||||
'refresh_token_url', 'refresh_token_params',
|
||||
'authorize_url', 'authorize_params',
|
||||
'api_base_url', 'client_kwargs',
|
||||
'server_metadata_url',
|
||||
)
|
||||
|
||||
|
||||
class BaseOAuth:
|
||||
"""Registry for oauth clients.
|
||||
|
||||
Create an instance for registry::
|
||||
|
||||
oauth = OAuth()
|
||||
"""
|
||||
oauth1_client_cls = None
|
||||
oauth2_client_cls = None
|
||||
framework_integration_cls = FrameworkIntegration
|
||||
|
||||
def __init__(self, cache=None, fetch_token=None, update_token=None):
|
||||
self._registry = {}
|
||||
self._clients = {}
|
||||
self.cache = cache
|
||||
self.fetch_token = fetch_token
|
||||
self.update_token = update_token
|
||||
|
||||
def create_client(self, name):
|
||||
"""Create or get the given named OAuth client. For instance, the
|
||||
OAuth registry has ``.register`` a twitter client, developers may
|
||||
access the client with::
|
||||
|
||||
client = oauth.create_client('twitter')
|
||||
|
||||
:param: name: Name of the remote application
|
||||
:return: OAuth remote app
|
||||
"""
|
||||
if name in self._clients:
|
||||
return self._clients[name]
|
||||
|
||||
if name not in self._registry:
|
||||
return None
|
||||
|
||||
overwrite, config = self._registry[name]
|
||||
client_cls = config.pop('client_cls', None)
|
||||
|
||||
if client_cls and client_cls.OAUTH_APP_CONFIG:
|
||||
kwargs = client_cls.OAUTH_APP_CONFIG
|
||||
kwargs.update(config)
|
||||
else:
|
||||
kwargs = config
|
||||
|
||||
kwargs = self.generate_client_kwargs(name, overwrite, **kwargs)
|
||||
framework = self.framework_integration_cls(name, self.cache)
|
||||
if client_cls:
|
||||
client = client_cls(framework, name, **kwargs)
|
||||
elif kwargs.get('request_token_url'):
|
||||
client = self.oauth1_client_cls(framework, name, **kwargs)
|
||||
else:
|
||||
client = self.oauth2_client_cls(framework, name, **kwargs)
|
||||
|
||||
self._clients[name] = client
|
||||
return client
|
||||
|
||||
def register(self, name, overwrite=False, **kwargs):
|
||||
"""Registers a new remote application.
|
||||
|
||||
:param name: Name of the remote application.
|
||||
:param overwrite: Overwrite existing config with framework settings.
|
||||
:param kwargs: Parameters for :class:`RemoteApp`.
|
||||
|
||||
Find parameters for the given remote app class. When a remote app is
|
||||
registered, it can be accessed with *named* attribute::
|
||||
|
||||
oauth.register('twitter', client_id='', ...)
|
||||
oauth.twitter.get('timeline')
|
||||
"""
|
||||
self._registry[name] = (overwrite, kwargs)
|
||||
return self.create_client(name)
|
||||
|
||||
def generate_client_kwargs(self, name, overwrite, **kwargs):
|
||||
fetch_token = kwargs.pop('fetch_token', None)
|
||||
update_token = kwargs.pop('update_token', None)
|
||||
|
||||
config = self.load_config(name, OAUTH_CLIENT_PARAMS)
|
||||
if config:
|
||||
kwargs = _config_client(config, kwargs, overwrite)
|
||||
|
||||
if not fetch_token and self.fetch_token:
|
||||
fetch_token = functools.partial(self.fetch_token, name)
|
||||
|
||||
kwargs['fetch_token'] = fetch_token
|
||||
|
||||
if not kwargs.get('request_token_url'):
|
||||
if not update_token and self.update_token:
|
||||
update_token = functools.partial(self.update_token, name)
|
||||
|
||||
kwargs['update_token'] = update_token
|
||||
return kwargs
|
||||
|
||||
def load_config(self, name, params):
|
||||
return self.framework_integration_cls.load_config(self, name, params)
|
||||
|
||||
def __getattr__(self, key):
|
||||
try:
|
||||
return object.__getattribute__(self, key)
|
||||
except AttributeError:
|
||||
if key in self._registry:
|
||||
return self.create_client(key)
|
||||
raise AttributeError('No such client: %s' % key)
|
||||
|
||||
|
||||
def _config_client(config, kwargs, overwrite):
|
||||
for k in OAUTH_CLIENT_PARAMS:
|
||||
v = config.get(k, None)
|
||||
if k not in kwargs:
|
||||
kwargs[k] = v
|
||||
elif overwrite and v:
|
||||
if isinstance(kwargs[k], dict):
|
||||
kwargs[k].update(v)
|
||||
else:
|
||||
kwargs[k] = v
|
||||
return kwargs
|
||||
@@ -0,0 +1,348 @@
|
||||
import time
|
||||
import logging
|
||||
from authlib.common.urls import urlparse
|
||||
from authlib.consts import default_user_agent
|
||||
from authlib.common.security import generate_token
|
||||
from .errors import (
|
||||
MismatchingStateError,
|
||||
MissingRequestTokenError,
|
||||
MissingTokenError,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseApp:
|
||||
client_cls = None
|
||||
OAUTH_APP_CONFIG = None
|
||||
|
||||
def request(self, method, url, token=None, **kwargs):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
"""Invoke GET http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.get('users/lepture')
|
||||
"""
|
||||
return self.request('GET', url, **kwargs)
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
"""Invoke POST http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.post('timeline', json={'text': 'Hi'})
|
||||
"""
|
||||
return self.request('POST', url, **kwargs)
|
||||
|
||||
def patch(self, url, **kwargs):
|
||||
"""Invoke PATCH http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.patch('profile', json={'name': 'Hsiaoming Yang'})
|
||||
"""
|
||||
return self.request('PATCH', url, **kwargs)
|
||||
|
||||
def put(self, url, **kwargs):
|
||||
"""Invoke PUT http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.put('profile', json={'name': 'Hsiaoming Yang'})
|
||||
"""
|
||||
return self.request('PUT', url, **kwargs)
|
||||
|
||||
def delete(self, url, **kwargs):
|
||||
"""Invoke DELETE http request.
|
||||
|
||||
If ``api_base_url`` configured, shortcut is available::
|
||||
|
||||
client.delete('posts/123')
|
||||
"""
|
||||
return self.request('DELETE', url, **kwargs)
|
||||
|
||||
|
||||
class _RequestMixin:
|
||||
def _get_requested_token(self, request):
|
||||
if self._fetch_token and request:
|
||||
return self._fetch_token(request)
|
||||
|
||||
def _send_token_request(self, session, method, url, token, kwargs):
|
||||
request = kwargs.pop('request', None)
|
||||
withhold_token = kwargs.get('withhold_token')
|
||||
if self.api_base_url and not url.startswith(('https://', 'http://')):
|
||||
url = urlparse.urljoin(self.api_base_url, url)
|
||||
|
||||
if withhold_token:
|
||||
return session.request(method, url, **kwargs)
|
||||
|
||||
if token is None:
|
||||
token = self._get_requested_token(request)
|
||||
|
||||
if token is None:
|
||||
raise MissingTokenError()
|
||||
|
||||
session.token = token
|
||||
return session.request(method, url, **kwargs)
|
||||
|
||||
|
||||
class OAuth1Base:
|
||||
client_cls = None
|
||||
|
||||
def __init__(
|
||||
self, framework, name=None, fetch_token=None,
|
||||
client_id=None, client_secret=None,
|
||||
request_token_url=None, request_token_params=None,
|
||||
access_token_url=None, access_token_params=None,
|
||||
authorize_url=None, authorize_params=None,
|
||||
api_base_url=None, client_kwargs=None, user_agent=None, **kwargs):
|
||||
self.framework = framework
|
||||
self.name = name
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.request_token_url = request_token_url
|
||||
self.request_token_params = request_token_params
|
||||
self.access_token_url = access_token_url
|
||||
self.access_token_params = access_token_params
|
||||
self.authorize_url = authorize_url
|
||||
self.authorize_params = authorize_params
|
||||
self.api_base_url = api_base_url
|
||||
self.client_kwargs = client_kwargs or {}
|
||||
|
||||
self._fetch_token = fetch_token
|
||||
self._user_agent = user_agent or default_user_agent
|
||||
self._kwargs = kwargs
|
||||
|
||||
def _get_oauth_client(self):
|
||||
session = self.client_cls(self.client_id, self.client_secret, **self.client_kwargs)
|
||||
session.headers['User-Agent'] = self._user_agent
|
||||
return session
|
||||
|
||||
|
||||
class OAuth1Mixin(_RequestMixin, OAuth1Base):
|
||||
def request(self, method, url, token=None, **kwargs):
|
||||
with self._get_oauth_client() as session:
|
||||
return self._send_token_request(session, method, url, token, kwargs)
|
||||
|
||||
def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
if not self.authorize_url:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
with self._get_oauth_client() as client:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = self.request_token_params or {}
|
||||
request_token = client.fetch_request_token(self.request_token_url, **params)
|
||||
log.debug(f'Fetch request token: {request_token!r}')
|
||||
url = client.create_authorization_url(self.authorize_url, **kwargs)
|
||||
state = request_token['oauth_token']
|
||||
return {'url': url, 'request_token': request_token, 'state': state}
|
||||
|
||||
def fetch_access_token(self, request_token=None, **kwargs):
|
||||
"""Fetch access token in one step.
|
||||
|
||||
:param request_token: A previous request token for OAuth 1.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
with self._get_oauth_client() as client:
|
||||
if request_token is None:
|
||||
raise MissingRequestTokenError()
|
||||
# merge request token with verifier
|
||||
token = {}
|
||||
token.update(request_token)
|
||||
token.update(kwargs)
|
||||
client.token = token
|
||||
params = self.access_token_params or {}
|
||||
token = client.fetch_access_token(self.access_token_url, **params)
|
||||
return token
|
||||
|
||||
|
||||
class OAuth2Base:
|
||||
client_cls = None
|
||||
|
||||
def __init__(
|
||||
self, framework, name=None, fetch_token=None, update_token=None,
|
||||
client_id=None, client_secret=None,
|
||||
access_token_url=None, access_token_params=None,
|
||||
authorize_url=None, authorize_params=None,
|
||||
api_base_url=None, client_kwargs=None, server_metadata_url=None,
|
||||
compliance_fix=None, client_auth_methods=None, user_agent=None, **kwargs):
|
||||
self.framework = framework
|
||||
self.name = name
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.access_token_url = access_token_url
|
||||
self.access_token_params = access_token_params
|
||||
self.authorize_url = authorize_url
|
||||
self.authorize_params = authorize_params
|
||||
self.api_base_url = api_base_url
|
||||
self.client_kwargs = client_kwargs or {}
|
||||
|
||||
self.compliance_fix = compliance_fix
|
||||
self.client_auth_methods = client_auth_methods
|
||||
self._fetch_token = fetch_token
|
||||
self._update_token = update_token
|
||||
self._user_agent = user_agent or default_user_agent
|
||||
|
||||
self._server_metadata_url = server_metadata_url
|
||||
self.server_metadata = kwargs
|
||||
|
||||
def _on_update_token(self, token, refresh_token=None, access_token=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_oauth_client(self, **metadata):
|
||||
client_kwargs = {}
|
||||
client_kwargs.update(self.client_kwargs)
|
||||
client_kwargs.update(metadata)
|
||||
|
||||
if self.authorize_url:
|
||||
client_kwargs['authorization_endpoint'] = self.authorize_url
|
||||
if self.access_token_url:
|
||||
client_kwargs['token_endpoint'] = self.access_token_url
|
||||
|
||||
session = self.client_cls(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret,
|
||||
update_token=self._on_update_token,
|
||||
**client_kwargs
|
||||
)
|
||||
if self.client_auth_methods:
|
||||
for f in self.client_auth_methods:
|
||||
session.register_client_auth_method(f)
|
||||
|
||||
if self.compliance_fix:
|
||||
self.compliance_fix(session)
|
||||
|
||||
session.headers['User-Agent'] = self._user_agent
|
||||
return session
|
||||
|
||||
@staticmethod
|
||||
def _format_state_params(state_data, params):
|
||||
if state_data is None:
|
||||
raise MismatchingStateError()
|
||||
|
||||
code_verifier = state_data.get('code_verifier')
|
||||
if code_verifier:
|
||||
params['code_verifier'] = code_verifier
|
||||
|
||||
redirect_uri = state_data.get('redirect_uri')
|
||||
if redirect_uri:
|
||||
params['redirect_uri'] = redirect_uri
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def _create_oauth2_authorization_url(client, authorization_endpoint, **kwargs):
|
||||
rv = {}
|
||||
if client.code_challenge_method:
|
||||
code_verifier = kwargs.get('code_verifier')
|
||||
if not code_verifier:
|
||||
code_verifier = generate_token(48)
|
||||
kwargs['code_verifier'] = code_verifier
|
||||
rv['code_verifier'] = code_verifier
|
||||
log.debug(f'Using code_verifier: {code_verifier!r}')
|
||||
|
||||
scope = kwargs.get('scope', client.scope)
|
||||
scope = (
|
||||
(scope if isinstance(scope, (list, tuple)) else scope.split())
|
||||
if scope
|
||||
else None
|
||||
)
|
||||
if scope and "openid" in scope:
|
||||
# this is an OpenID Connect service
|
||||
nonce = kwargs.get('nonce')
|
||||
if not nonce:
|
||||
nonce = generate_token(20)
|
||||
kwargs['nonce'] = nonce
|
||||
rv['nonce'] = nonce
|
||||
|
||||
url, state = client.create_authorization_url(
|
||||
authorization_endpoint, **kwargs)
|
||||
rv['url'] = url
|
||||
rv['state'] = state
|
||||
return rv
|
||||
|
||||
|
||||
class OAuth2Mixin(_RequestMixin, OAuth2Base):
|
||||
def _on_update_token(self, token, refresh_token=None, access_token=None):
|
||||
if callable(self._update_token):
|
||||
self._update_token(
|
||||
token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.framework.update_token(
|
||||
token,
|
||||
refresh_token=refresh_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
def request(self, method, url, token=None, **kwargs):
|
||||
metadata = self.load_server_metadata()
|
||||
with self._get_oauth_client(**metadata) as session:
|
||||
return self._send_token_request(session, method, url, token, kwargs)
|
||||
|
||||
def load_server_metadata(self):
|
||||
if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
|
||||
with self.client_cls(**self.client_kwargs) as session:
|
||||
resp = session.request('GET', self._server_metadata_url, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
metadata = resp.json()
|
||||
|
||||
metadata['_loaded_at'] = time.time()
|
||||
self.server_metadata.update(metadata)
|
||||
return self.server_metadata
|
||||
|
||||
def create_authorization_url(self, redirect_uri=None, **kwargs):
|
||||
"""Generate the authorization url and state for HTTP redirect.
|
||||
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: dict
|
||||
"""
|
||||
metadata = self.load_server_metadata()
|
||||
authorization_endpoint = self.authorize_url or metadata.get('authorization_endpoint')
|
||||
|
||||
if not authorization_endpoint:
|
||||
raise RuntimeError('Missing "authorize_url" value')
|
||||
|
||||
if self.authorize_params:
|
||||
kwargs.update(self.authorize_params)
|
||||
|
||||
|
||||
with self._get_oauth_client(**metadata) as client:
|
||||
if redirect_uri is not None:
|
||||
client.redirect_uri = redirect_uri
|
||||
return self._create_oauth2_authorization_url(
|
||||
client, authorization_endpoint, **kwargs)
|
||||
|
||||
def fetch_access_token(self, redirect_uri=None, **kwargs):
|
||||
"""Fetch access token in the final step.
|
||||
|
||||
:param redirect_uri: Callback or Redirect URI that is used in
|
||||
previous :meth:`authorize_redirect`.
|
||||
:param kwargs: Extra parameters to fetch access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
metadata = self.load_server_metadata()
|
||||
token_endpoint = self.access_token_url or metadata.get('token_endpoint')
|
||||
with self._get_oauth_client(**metadata) as client:
|
||||
if redirect_uri is not None:
|
||||
client.redirect_uri = redirect_uri
|
||||
params = {}
|
||||
if self.access_token_params:
|
||||
params.update(self.access_token_params)
|
||||
params.update(kwargs)
|
||||
token = client.fetch_token(token_endpoint, **params)
|
||||
return token
|
||||
@@ -0,0 +1,82 @@
|
||||
from authlib.jose import jwt, JsonWebToken, JsonWebKey
|
||||
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
|
||||
|
||||
|
||||
class OpenIDMixin:
|
||||
def fetch_jwk_set(self, force=False):
|
||||
metadata = self.load_server_metadata()
|
||||
jwk_set = metadata.get('jwks')
|
||||
if jwk_set and not force:
|
||||
return jwk_set
|
||||
|
||||
uri = metadata.get('jwks_uri')
|
||||
if not uri:
|
||||
raise RuntimeError('Missing "jwks_uri" in metadata')
|
||||
|
||||
with self.client_cls(**self.client_kwargs) as session:
|
||||
resp = session.request('GET', uri, withhold_token=True)
|
||||
resp.raise_for_status()
|
||||
jwk_set = resp.json()
|
||||
|
||||
self.server_metadata['jwks'] = jwk_set
|
||||
return jwk_set
|
||||
|
||||
def userinfo(self, **kwargs):
|
||||
"""Fetch user info from ``userinfo_endpoint``."""
|
||||
metadata = self.load_server_metadata()
|
||||
resp = self.get(metadata['userinfo_endpoint'], **kwargs)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return UserInfo(data)
|
||||
|
||||
def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
|
||||
"""Return an instance of UserInfo from token's ``id_token``."""
|
||||
if 'id_token' not in token:
|
||||
return None
|
||||
|
||||
load_key = self.create_load_key()
|
||||
|
||||
claims_params = dict(
|
||||
nonce=nonce,
|
||||
client_id=self.client_id,
|
||||
)
|
||||
if 'access_token' in token:
|
||||
claims_params['access_token'] = token['access_token']
|
||||
claims_cls = CodeIDToken
|
||||
else:
|
||||
claims_cls = ImplicitIDToken
|
||||
|
||||
metadata = self.load_server_metadata()
|
||||
if claims_options is None and 'issuer' in metadata:
|
||||
claims_options = {'iss': {'values': [metadata['issuer']]}}
|
||||
|
||||
alg_values = metadata.get('id_token_signing_alg_values_supported')
|
||||
if alg_values:
|
||||
_jwt = JsonWebToken(alg_values)
|
||||
else:
|
||||
_jwt = jwt
|
||||
|
||||
claims = _jwt.decode(
|
||||
token['id_token'], key=load_key,
|
||||
claims_cls=claims_cls,
|
||||
claims_options=claims_options,
|
||||
claims_params=claims_params,
|
||||
)
|
||||
# https://github.com/lepture/authlib/issues/259
|
||||
if claims.get('nonce_supported') is False:
|
||||
claims.params['nonce'] = None
|
||||
|
||||
claims.validate(leeway=leeway)
|
||||
return UserInfo(claims)
|
||||
|
||||
def create_load_key(self):
|
||||
def load_key(header, _):
|
||||
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
|
||||
try:
|
||||
return jwk_set.find_by_kid(header.get('kid'))
|
||||
except ValueError:
|
||||
# re-try with new jwk set
|
||||
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
|
||||
return jwk_set.find_by_kid(header.get('kid'))
|
||||
|
||||
return load_key
|
||||
Reference in New Issue
Block a user