venv added, updated
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
# flake8: noqa
|
||||
|
||||
from ..base_client import BaseOAuth, OAuthError
|
||||
from .integration import StarletteIntegration
|
||||
from .apps import StarletteOAuth1App, StarletteOAuth2App
|
||||
|
||||
|
||||
class OAuth(BaseOAuth):
|
||||
oauth1_client_cls = StarletteOAuth1App
|
||||
oauth2_client_cls = StarletteOAuth2App
|
||||
framework_integration_cls = StarletteIntegration
|
||||
|
||||
def __init__(self, config=None, cache=None, fetch_token=None, update_token=None):
|
||||
super().__init__(
|
||||
cache=cache, fetch_token=fetch_token, update_token=update_token)
|
||||
self.config = config
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OAuth', 'OAuthError',
|
||||
'StarletteIntegration', 'StarletteOAuth1App', 'StarletteOAuth2App',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from starlette.datastructures import URL
|
||||
from starlette.responses import RedirectResponse
|
||||
from ..base_client import OAuthError
|
||||
from ..base_client import BaseApp
|
||||
from ..base_client.async_app import AsyncOAuth1Mixin, AsyncOAuth2Mixin
|
||||
from ..base_client.async_openid import AsyncOpenIDMixin
|
||||
from ..httpx_client import AsyncOAuth1Client, AsyncOAuth2Client
|
||||
|
||||
|
||||
class StarletteAppMixin:
|
||||
async def save_authorize_data(self, request, **kwargs):
|
||||
state = kwargs.pop('state', None)
|
||||
if state:
|
||||
if self.framework.cache:
|
||||
session = None
|
||||
else:
|
||||
session = request.session
|
||||
await self.framework.set_state_data(session, state, kwargs)
|
||||
else:
|
||||
raise RuntimeError('Missing state value')
|
||||
|
||||
async def authorize_redirect(self, request, redirect_uri=None, **kwargs):
|
||||
"""Create a HTTP Redirect for Authorization Endpoint.
|
||||
|
||||
:param request: HTTP request instance from Starlette view.
|
||||
:param redirect_uri: Callback or redirect URI for authorization.
|
||||
:param kwargs: Extra parameters to include.
|
||||
:return: A HTTP redirect response.
|
||||
"""
|
||||
|
||||
# Handle Starlette >= 0.26.0 where redirect_uri may now be a URL and not a string
|
||||
if redirect_uri and isinstance(redirect_uri, URL):
|
||||
redirect_uri = str(redirect_uri)
|
||||
rv = await self.create_authorization_url(redirect_uri, **kwargs)
|
||||
await self.save_authorize_data(request, redirect_uri=redirect_uri, **rv)
|
||||
return RedirectResponse(rv['url'], status_code=302)
|
||||
|
||||
|
||||
class StarletteOAuth1App(StarletteAppMixin, AsyncOAuth1Mixin, BaseApp):
|
||||
client_cls = AsyncOAuth1Client
|
||||
|
||||
async def authorize_access_token(self, request, **kwargs):
|
||||
params = dict(request.query_params)
|
||||
state = params.get('oauth_token')
|
||||
if not state:
|
||||
raise OAuthError(description='Missing "oauth_token" parameter')
|
||||
|
||||
data = await self.framework.get_state_data(request.session, state)
|
||||
if not data:
|
||||
raise OAuthError(description='Missing "request_token" in temporary data')
|
||||
|
||||
params['request_token'] = data['request_token']
|
||||
params.update(kwargs)
|
||||
await self.framework.clear_state_data(request.session, state)
|
||||
return await self.fetch_access_token(**params)
|
||||
|
||||
|
||||
class StarletteOAuth2App(StarletteAppMixin, AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp):
|
||||
client_cls = AsyncOAuth2Client
|
||||
|
||||
async def authorize_access_token(self, request, **kwargs):
|
||||
error = request.query_params.get('error')
|
||||
if error:
|
||||
description = request.query_params.get('error_description')
|
||||
raise OAuthError(error=error, description=description)
|
||||
|
||||
params = {
|
||||
'code': request.query_params.get('code'),
|
||||
'state': request.query_params.get('state'),
|
||||
}
|
||||
|
||||
if self.framework.cache:
|
||||
session = None
|
||||
else:
|
||||
session = request.session
|
||||
|
||||
claims_options = kwargs.pop('claims_options', None)
|
||||
state_data = await self.framework.get_state_data(session, params.get('state'))
|
||||
await self.framework.clear_state_data(session, params.get('state'))
|
||||
params = self._format_state_params(state_data, params)
|
||||
token = await self.fetch_access_token(**params, **kwargs)
|
||||
|
||||
if 'id_token' in token and 'nonce' in state_data:
|
||||
userinfo = await self.parse_id_token(token, nonce=state_data['nonce'], claims_options=claims_options)
|
||||
token['userinfo'] = userinfo
|
||||
return token
|
||||
@@ -0,0 +1,71 @@
|
||||
import json
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Hashable,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from ..base_client import FrameworkIntegration
|
||||
|
||||
|
||||
class StarletteIntegration(FrameworkIntegration):
|
||||
async def _get_cache_data(self, key: Hashable):
|
||||
value = await self.cache.get(key)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
async def get_state_data(self, session: Optional[Dict[str, Any]], state: str) -> Dict[str, Any]:
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
value = await self._get_cache_data(key)
|
||||
elif session is not None:
|
||||
value = session.get(key)
|
||||
else:
|
||||
value = None
|
||||
|
||||
if value:
|
||||
return value.get('data')
|
||||
return None
|
||||
|
||||
async def set_state_data(self, session: Optional[Dict[str, Any]], state: str, data: Any):
|
||||
key_prefix = f'_state_{self.name}_'
|
||||
key = f'{key_prefix}{state}'
|
||||
if self.cache:
|
||||
await self.cache.set(key, json.dumps({'data': data}), self.expires_in)
|
||||
elif session is not None:
|
||||
# clear old state data to avoid session size growing
|
||||
for old_key in list(session.keys()):
|
||||
if old_key.startswith(key_prefix):
|
||||
session.pop(old_key)
|
||||
now = time.time()
|
||||
session[key] = {'data': data, 'exp': now + self.expires_in}
|
||||
|
||||
async def clear_state_data(self, session: Optional[Dict[str, Any]], state: str):
|
||||
key = f'_state_{self.name}_{state}'
|
||||
if self.cache:
|
||||
await self.cache.delete(key)
|
||||
elif session is not None:
|
||||
session.pop(key, None)
|
||||
self._clear_session_state(session)
|
||||
|
||||
def update_token(self, token, refresh_token=None, access_token=None):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def load_config(oauth, name, params):
|
||||
if not oauth.config:
|
||||
return {}
|
||||
|
||||
rv = {}
|
||||
for k in params:
|
||||
conf_key = f'{name}_{k}'.upper()
|
||||
v = oauth.config.get(conf_key, default=None)
|
||||
if v is not None:
|
||||
rv[k] = v
|
||||
return rv
|
||||
Reference in New Issue
Block a user