venv added, updated
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
from authlib.oauth1 import (
|
||||
SIGNATURE_HMAC_SHA1,
|
||||
SIGNATURE_RSA_SHA1,
|
||||
SIGNATURE_PLAINTEXT,
|
||||
SIGNATURE_TYPE_HEADER,
|
||||
SIGNATURE_TYPE_QUERY,
|
||||
SIGNATURE_TYPE_BODY,
|
||||
)
|
||||
from .oauth1_client import OAuth1Auth, AsyncOAuth1Client, OAuth1Client
|
||||
from .oauth2_client import (
|
||||
OAuth2Auth, OAuth2Client, OAuth2ClientAuth,
|
||||
AsyncOAuth2Client,
|
||||
)
|
||||
from .assertion_client import AssertionClient, AsyncAssertionClient
|
||||
from ..base_client import OAuthError
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuthError',
|
||||
'OAuth1Auth', 'AsyncOAuth1Client',
|
||||
'SIGNATURE_HMAC_SHA1', 'SIGNATURE_RSA_SHA1', 'SIGNATURE_PLAINTEXT',
|
||||
'SIGNATURE_TYPE_HEADER', 'SIGNATURE_TYPE_QUERY', 'SIGNATURE_TYPE_BODY',
|
||||
'OAuth2Auth', 'OAuth2ClientAuth', 'OAuth2Client', 'AsyncOAuth2Client',
|
||||
'AssertionClient', 'AsyncAssertionClient',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,81 @@
|
||||
import httpx
|
||||
from httpx import Response, USE_CLIENT_DEFAULT
|
||||
from authlib.oauth2.rfc7521 import AssertionClient as _AssertionClient
|
||||
from authlib.oauth2.rfc7523 import JWTBearerGrant
|
||||
from .utils import extract_client_kwargs
|
||||
from .oauth2_client import OAuth2Auth
|
||||
from ..base_client import OAuthError
|
||||
|
||||
__all__ = ['AsyncAssertionClient']
|
||||
|
||||
|
||||
class AsyncAssertionClient(_AssertionClient, httpx.AsyncClient):
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
|
||||
ASSERTION_METHODS = {
|
||||
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
|
||||
}
|
||||
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
|
||||
|
||||
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
|
||||
claims=None, token_placement='header', scope=None, **kwargs):
|
||||
|
||||
client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.AsyncClient.__init__(self, **client_kwargs)
|
||||
|
||||
_AssertionClient.__init__(
|
||||
self, session=None,
|
||||
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
|
||||
audience=audience, grant_type=grant_type, claims=claims,
|
||||
token_placement=token_placement, scope=scope, **kwargs
|
||||
)
|
||||
|
||||
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs) -> Response:
|
||||
"""Send request with auto refresh token feature."""
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token or self.token.is_expired():
|
||||
await self.refresh_token()
|
||||
|
||||
auth = self.token_auth
|
||||
return await super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
|
||||
async def _refresh_token(self, data):
|
||||
resp = await self.request(
|
||||
'POST', self.token_endpoint, data=data, withhold_token=True)
|
||||
|
||||
return self.parse_response_token(resp)
|
||||
|
||||
|
||||
class AssertionClient(_AssertionClient, httpx.Client):
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
JWT_BEARER_GRANT_TYPE = JWTBearerGrant.GRANT_TYPE
|
||||
ASSERTION_METHODS = {
|
||||
JWT_BEARER_GRANT_TYPE: JWTBearerGrant.sign,
|
||||
}
|
||||
DEFAULT_GRANT_TYPE = JWT_BEARER_GRANT_TYPE
|
||||
|
||||
def __init__(self, token_endpoint, issuer, subject, audience=None, grant_type=None,
|
||||
claims=None, token_placement='header', scope=None, **kwargs):
|
||||
|
||||
client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.Client.__init__(self, **client_kwargs)
|
||||
|
||||
_AssertionClient.__init__(
|
||||
self, session=self,
|
||||
token_endpoint=token_endpoint, issuer=issuer, subject=subject,
|
||||
audience=audience, grant_type=grant_type, claims=claims,
|
||||
token_placement=token_placement, scope=scope, **kwargs
|
||||
)
|
||||
|
||||
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
"""Send request with auto refresh token feature."""
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token or self.token.is_expired():
|
||||
self.refresh_token()
|
||||
|
||||
auth = self.token_auth
|
||||
return super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
@@ -0,0 +1,102 @@
|
||||
import typing
|
||||
import httpx
|
||||
from httpx import Auth, Request, Response
|
||||
from authlib.oauth1 import (
|
||||
SIGNATURE_HMAC_SHA1,
|
||||
SIGNATURE_TYPE_HEADER,
|
||||
)
|
||||
from authlib.common.encoding import to_unicode
|
||||
from authlib.oauth1 import ClientAuth
|
||||
from authlib.oauth1.client import OAuth1Client as _OAuth1Client
|
||||
from .utils import build_request, extract_client_kwargs
|
||||
from ..base_client import OAuthError
|
||||
|
||||
|
||||
class OAuth1Auth(Auth, ClientAuth):
|
||||
"""Signs the httpx request using OAuth 1 (RFC5849)"""
|
||||
requires_request_body = True
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
url, headers, body = self.prepare(
|
||||
request.method, str(request.url), request.headers, request.content)
|
||||
headers['Content-Length'] = str(len(body))
|
||||
yield build_request(url=url, headers=headers, body=body, initial_request=request)
|
||||
|
||||
|
||||
class AsyncOAuth1Client(_OAuth1Client, httpx.AsyncClient):
|
||||
auth_class = OAuth1Auth
|
||||
|
||||
def __init__(self, client_id, client_secret=None,
|
||||
token=None, token_secret=None,
|
||||
redirect_uri=None, rsa_key=None, verifier=None,
|
||||
signature_method=SIGNATURE_HMAC_SHA1,
|
||||
signature_type=SIGNATURE_TYPE_HEADER,
|
||||
force_include_body=False, **kwargs):
|
||||
|
||||
_client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.AsyncClient.__init__(self, **_client_kwargs)
|
||||
|
||||
_OAuth1Client.__init__(
|
||||
self, None,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token=token, token_secret=token_secret,
|
||||
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
|
||||
signature_method=signature_method, signature_type=signature_type,
|
||||
force_include_body=force_include_body, **kwargs)
|
||||
|
||||
async def fetch_access_token(self, url, verifier=None, **kwargs):
|
||||
"""Method for fetching an access token from the token endpoint.
|
||||
|
||||
This is the final step in the OAuth 1 workflow. An access token is
|
||||
obtained using all previously obtained credentials, including the
|
||||
verifier from the authorization step.
|
||||
|
||||
:param url: Access Token endpoint.
|
||||
:param verifier: A verifier string to prove authorization was granted.
|
||||
:param kwargs: Extra parameters to include for fetching access token.
|
||||
:return: A token dict.
|
||||
"""
|
||||
if verifier:
|
||||
self.auth.verifier = verifier
|
||||
if not self.auth.verifier:
|
||||
self.handle_error('missing_verifier', 'Missing "verifier" value')
|
||||
token = await self._fetch_token(url, **kwargs)
|
||||
self.auth.verifier = None
|
||||
return token
|
||||
|
||||
async def _fetch_token(self, url, **kwargs):
|
||||
resp = await self.post(url, **kwargs)
|
||||
text = await resp.aread()
|
||||
token = self.parse_response_token(resp.status_code, to_unicode(text))
|
||||
self.token = token
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
def handle_error(error_type, error_description):
|
||||
raise OAuthError(error_type, error_description)
|
||||
|
||||
|
||||
class OAuth1Client(_OAuth1Client, httpx.Client):
|
||||
auth_class = OAuth1Auth
|
||||
|
||||
def __init__(self, client_id, client_secret=None,
|
||||
token=None, token_secret=None,
|
||||
redirect_uri=None, rsa_key=None, verifier=None,
|
||||
signature_method=SIGNATURE_HMAC_SHA1,
|
||||
signature_type=SIGNATURE_TYPE_HEADER,
|
||||
force_include_body=False, **kwargs):
|
||||
|
||||
_client_kwargs = extract_client_kwargs(kwargs)
|
||||
httpx.Client.__init__(self, **_client_kwargs)
|
||||
|
||||
_OAuth1Client.__init__(
|
||||
self, self,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token=token, token_secret=token_secret,
|
||||
redirect_uri=redirect_uri, rsa_key=rsa_key, verifier=verifier,
|
||||
signature_method=signature_method, signature_type=signature_type,
|
||||
force_include_body=force_include_body, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def handle_error(error_type, error_description):
|
||||
raise OAuthError(error_type, error_description)
|
||||
@@ -0,0 +1,220 @@
|
||||
import typing
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import httpx
|
||||
from httpx import Auth, Request, Response, USE_CLIENT_DEFAULT
|
||||
from anyio import Lock # Import after httpx so import errors refer to httpx
|
||||
from authlib.common.urls import url_decode
|
||||
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
|
||||
from authlib.oauth2.auth import ClientAuth, TokenAuth
|
||||
from .utils import HTTPX_CLIENT_KWARGS, build_request
|
||||
from ..base_client import (
|
||||
OAuthError,
|
||||
InvalidTokenError,
|
||||
MissingTokenError,
|
||||
UnsupportedTokenTypeError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'OAuth2Auth', 'OAuth2ClientAuth',
|
||||
'AsyncOAuth2Client', 'OAuth2Client',
|
||||
]
|
||||
|
||||
|
||||
class OAuth2Auth(Auth, TokenAuth):
|
||||
"""Sign requests for OAuth 2.0, currently only bearer token is supported."""
|
||||
requires_request_body = True
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
try:
|
||||
url, headers, body = self.prepare(
|
||||
str(request.url), request.headers, request.content)
|
||||
headers['Content-Length'] = str(len(body))
|
||||
yield build_request(url=url, headers=headers, body=body, initial_request=request)
|
||||
except KeyError as error:
|
||||
description = f'Unsupported token_type: {str(error)}'
|
||||
raise UnsupportedTokenTypeError(description=description)
|
||||
|
||||
|
||||
class OAuth2ClientAuth(Auth, ClientAuth):
|
||||
requires_request_body = True
|
||||
|
||||
def auth_flow(self, request: Request) -> typing.Generator[Request, Response, None]:
|
||||
url, headers, body = self.prepare(
|
||||
request.method, str(request.url), request.headers, request.content)
|
||||
headers['Content-Length'] = str(len(body))
|
||||
yield build_request(url=url, headers=headers, body=body, initial_request=request)
|
||||
|
||||
|
||||
class AsyncOAuth2Client(_OAuth2Client, httpx.AsyncClient):
|
||||
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
|
||||
|
||||
client_auth_class = OAuth2ClientAuth
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
|
||||
def __init__(self, client_id=None, client_secret=None,
|
||||
token_endpoint_auth_method=None,
|
||||
revocation_endpoint_auth_method=None,
|
||||
scope=None, redirect_uri=None,
|
||||
token=None, token_placement='header',
|
||||
update_token=None, leeway=60, **kwargs):
|
||||
|
||||
# extract httpx.Client kwargs
|
||||
client_kwargs = self._extract_session_request_params(kwargs)
|
||||
httpx.AsyncClient.__init__(self, **client_kwargs)
|
||||
|
||||
# We use a Lock to synchronize coroutines to prevent
|
||||
# multiple concurrent attempts to refresh the same token
|
||||
self._token_refresh_lock = Lock()
|
||||
|
||||
_OAuth2Client.__init__(
|
||||
self, session=None,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token_endpoint_auth_method=token_endpoint_auth_method,
|
||||
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
|
||||
scope=scope, redirect_uri=redirect_uri,
|
||||
token=token, token_placement=token_placement,
|
||||
update_token=update_token, leeway=leeway, **kwargs
|
||||
)
|
||||
|
||||
async def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
await self.ensure_active_token(self.token)
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
return await super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
|
||||
@asynccontextmanager
|
||||
async def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
await self.ensure_active_token(self.token)
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
async with super().stream(
|
||||
method, url, auth=auth, **kwargs) as resp:
|
||||
yield resp
|
||||
|
||||
async def ensure_active_token(self, token):
|
||||
async with self._token_refresh_lock:
|
||||
if self.token.is_expired(leeway=self.leeway):
|
||||
refresh_token = token.get('refresh_token')
|
||||
url = self.metadata.get('token_endpoint')
|
||||
if refresh_token and url:
|
||||
await self.refresh_token(url, refresh_token=refresh_token)
|
||||
elif self.metadata.get('grant_type') == 'client_credentials':
|
||||
access_token = token['access_token']
|
||||
new_token = await self.fetch_token(url, grant_type='client_credentials')
|
||||
if self.update_token:
|
||||
await self.update_token(new_token, access_token=access_token)
|
||||
else:
|
||||
raise InvalidTokenError()
|
||||
|
||||
async def _fetch_token(self, url, body='', headers=None, auth=USE_CLIENT_DEFAULT,
|
||||
method='POST', **kwargs):
|
||||
if method.upper() == 'POST':
|
||||
resp = await self.post(
|
||||
url, data=dict(url_decode(body)), headers=headers,
|
||||
auth=auth, **kwargs)
|
||||
else:
|
||||
if '?' in url:
|
||||
url = '&'.join([url, body])
|
||||
else:
|
||||
url = '?'.join([url, body])
|
||||
resp = await self.get(url, headers=headers, auth=auth, **kwargs)
|
||||
|
||||
for hook in self.compliance_hook['access_token_response']:
|
||||
resp = hook(resp)
|
||||
|
||||
return self.parse_response_token(resp)
|
||||
|
||||
async def _refresh_token(self, url, refresh_token=None, body='',
|
||||
headers=None, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
resp = await self.post(
|
||||
url, data=dict(url_decode(body)), headers=headers,
|
||||
auth=auth, **kwargs)
|
||||
|
||||
for hook in self.compliance_hook['refresh_token_response']:
|
||||
resp = hook(resp)
|
||||
|
||||
token = self.parse_response_token(resp)
|
||||
if 'refresh_token' not in token:
|
||||
self.token['refresh_token'] = refresh_token
|
||||
|
||||
if self.update_token:
|
||||
await self.update_token(self.token, refresh_token=refresh_token)
|
||||
|
||||
return self.token
|
||||
|
||||
def _http_post(self, url, body=None, auth=USE_CLIENT_DEFAULT, headers=None, **kwargs):
|
||||
return self.post(
|
||||
url, data=dict(url_decode(body)),
|
||||
headers=headers, auth=auth, **kwargs)
|
||||
|
||||
|
||||
class OAuth2Client(_OAuth2Client, httpx.Client):
|
||||
SESSION_REQUEST_PARAMS = HTTPX_CLIENT_KWARGS
|
||||
|
||||
client_auth_class = OAuth2ClientAuth
|
||||
token_auth_class = OAuth2Auth
|
||||
oauth_error_class = OAuthError
|
||||
|
||||
def __init__(self, client_id=None, client_secret=None,
|
||||
token_endpoint_auth_method=None,
|
||||
revocation_endpoint_auth_method=None,
|
||||
scope=None, redirect_uri=None,
|
||||
token=None, token_placement='header',
|
||||
update_token=None, **kwargs):
|
||||
|
||||
# extract httpx.Client kwargs
|
||||
client_kwargs = self._extract_session_request_params(kwargs)
|
||||
httpx.Client.__init__(self, **client_kwargs)
|
||||
|
||||
_OAuth2Client.__init__(
|
||||
self, session=self,
|
||||
client_id=client_id, client_secret=client_secret,
|
||||
token_endpoint_auth_method=token_endpoint_auth_method,
|
||||
revocation_endpoint_auth_method=revocation_endpoint_auth_method,
|
||||
scope=scope, redirect_uri=redirect_uri,
|
||||
token=token, token_placement=token_placement,
|
||||
update_token=update_token, **kwargs
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def handle_error(error_type, error_description):
|
||||
raise OAuthError(error_type, error_description)
|
||||
|
||||
def request(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
if not self.ensure_active_token(self.token):
|
||||
raise InvalidTokenError()
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
return super().request(
|
||||
method, url, auth=auth, **kwargs)
|
||||
|
||||
def stream(self, method, url, withhold_token=False, auth=USE_CLIENT_DEFAULT, **kwargs):
|
||||
if not withhold_token and auth is USE_CLIENT_DEFAULT:
|
||||
if not self.token:
|
||||
raise MissingTokenError()
|
||||
|
||||
if not self.ensure_active_token(self.token):
|
||||
raise InvalidTokenError()
|
||||
|
||||
auth = self.token_auth
|
||||
|
||||
return super().stream(
|
||||
method, url, auth=auth, **kwargs)
|
||||
@@ -0,0 +1,30 @@
|
||||
from httpx import Request
|
||||
|
||||
HTTPX_CLIENT_KWARGS = [
|
||||
'headers', 'cookies', 'verify', 'cert', 'http1', 'http2',
|
||||
'proxies', 'timeout', 'follow_redirects', 'limits', 'max_redirects',
|
||||
'event_hooks', 'base_url', 'transport', 'app', 'trust_env',
|
||||
]
|
||||
|
||||
|
||||
def extract_client_kwargs(kwargs):
|
||||
client_kwargs = {}
|
||||
for k in HTTPX_CLIENT_KWARGS:
|
||||
if k in kwargs:
|
||||
client_kwargs[k] = kwargs.pop(k)
|
||||
return client_kwargs
|
||||
|
||||
|
||||
def build_request(url, headers, body, initial_request: Request) -> Request:
|
||||
"""Make sure that all the data from initial request is passed to the updated object"""
|
||||
updated_request = Request(
|
||||
method=initial_request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
content=body
|
||||
)
|
||||
|
||||
if hasattr(initial_request, 'extensions'):
|
||||
updated_request.extensions = initial_request.extensions
|
||||
|
||||
return updated_request
|
||||
Reference in New Issue
Block a user