Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the experiment client #176

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions flask_oauthlib/contrib/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy

from flask import current_app
from werkzeug.local import LocalProxy

from .application import OAuth1Application, OAuth2Application


Expand All @@ -26,7 +29,7 @@ def __init__(self, app=None):

def init_app(self, app):
app.extensions = getattr(app, 'extensions', {})
app.extensions[self.state_key] = self
app.extensions[self.state_key] = OAuthState()

def add_remote_app(self, remote_app, name=None, **kwargs):
"""Adds remote application and applies custom attributes on it.
Expand All @@ -46,6 +49,8 @@ def add_remote_app(self, remote_app, name=None, **kwargs):
remote_app = copy.copy(remote_app)
remote_app.name = name
vars(remote_app).update(kwargs)
if not hasattr(remote_app, 'clients'):
remote_app.clients = cached_clients
self.remote_apps[name] = remote_app
return remote_app

Expand All @@ -62,9 +67,9 @@ def remote_app(self, name, version=None, **kwargs):
else:
version = '2'
if version == '1':
remote_app = OAuth1Application(name)
remote_app = OAuth1Application(name, clients=cached_clients)
elif version == '2':
remote_app = OAuth2Application(name)
remote_app = OAuth2Application(name, clients=cached_clients)
else:
raise ValueError('unkonwn version %r' % version)
return self.add_remote_app(remote_app, **kwargs)
Expand All @@ -80,3 +85,20 @@ def __getattr__(self, key):
if app:
return app
raise AttributeError('No such app: %s' % key)


class OAuthState(object):

def __init__(self):
self.cached_clients = {}


def get_cached_clients():
"""Gets the cached clients dictionary in current context."""
if OAuth.state_key not in current_app.extensions:
raise RuntimeError('%r is not initialized.' % current_app)
state = current_app.extensions[OAuth.state_key]
return state.cached_clients


cached_clients = LocalProxy(get_cached_clients)
126 changes: 97 additions & 29 deletions flask_oauthlib/contrib/client/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import contextlib
import warnings
import functools
try:
from urllib.parse import urljoin
except ImportError:
Expand All @@ -27,13 +28,25 @@


class BaseApplication(object):
"""The base class of OAuth application."""
"""The base class of OAuth application.

An application instance could be used in mupltiple context. It never stores
any session-scope state in the ``__dict__`` of itself.

:param name: the name of this application.
:param clients: optional. a reference to the cached clients dictionary.
"""

session_class = None
endpoint_url = OAuthProperty('endpoint_url', default='')

def __init__(self, name, **kwargs):
def __init__(self, name, clients=None, **kwargs):
# oauth property required
self.name = name

if clients:
self.clients = clients

# other descriptor assignable attributes
for k, v in kwargs.items():
if not hasattr(self.__class__, k):
Expand All @@ -49,6 +62,11 @@ def tokengetter(self, fn):
return fn

def obtain_token(self):
"""Obtains the access token by calling ``tokengetter`` which was
defined by users.

:returns: token or ``None``.
"""
tokengetter = getattr(self, '_tokengetter', None)
if tokengetter is None:
raise RuntimeError('%r missing tokengetter' % self)
Expand All @@ -61,7 +79,24 @@ def client(self):

:returns: The OAuth session instance or ``None`` while token missing.
"""
raise NotImplementedError
token = self.obtain_token()
if token is None:
raise AccessTokenNotFound
return self._make_client_with_token()

def _make_client_with_token(self, token):
"""Uses cached client or create new one with specific token."""
cached_clients = getattr(self, 'clients', None)
hashed_token = _hash_token(self, token)

if cached_clients and hashed_token in cached_clients:
return cached_clients[hashed_token]

client = self.make_client(token) # implemented in subclasses
if cached_clients:
cached_clients[hashed_token] = client

return client

def authorize(self, callback_uri, code=302):
"""Redirects to third-part URL and authorizes.
Expand Down Expand Up @@ -91,25 +126,36 @@ def authorized_response(self):
'patch',
])

# magic: generate methods which forward to self.client
def _make_method(_method_name):
def _method(self, url, *args, **kwargs):
url = urljoin(self.endpoint_url, url)
return getattr(self.client, _method_name)(url, *args, **kwargs)
return _method
for _method_name in forward_methods:
_method = _make_method(_method_name)
_method.func_name = _method.__name__ = _method_name
locals()[_method_name] = _method
del _make_method
del _method
del _method_name
def request(self, method, url, token=None, *args, **kwargs):
if token is None:
client = self.client
else:
client = self._make_client_with_token(token)
url = urljoin(self.endpoint_url, url)
return getattr(client, method)(url, *args, **kwargs)

def head(self, *args, **kwargs):
return self.request('head', *args, **kwargs)

def get(self, *args, **kwargs):
return self.request('get', *args, **kwargs)

def post(self, *args, **kwargs):
return self.request('post', *args, **kwargs)

def put(self, *args, **kwargs):
return self.request('put', *args, **kwargs)

def delete(self, *args, **kwargs):
return self.request('delete', *args, **kwargs)

def patch(self, *args, **kwargs):
return self.request('patch', *args, **kwargs)


class OAuth1Application(BaseApplication):
"""The remote application for OAuth 1.0a."""

endpoint_url = OAuthProperty('endpoint_url', default='')
request_token_url = OAuthProperty('request_token_url')
access_token_url = OAuthProperty('access_token_url')
authorization_url = OAuthProperty('authorization_url')
Expand All @@ -121,12 +167,19 @@ class OAuth1Application(BaseApplication):

_session_request_token = WebSessionData('req_token')

@property
def client(self):
token = self.obtain_token()
if token is None:
raise AccessTokenNotFound
access_token, access_token_secret = token
def make_client(self, token):
"""Creates a client with specific access token pair.

:param token: a tuple of access token pair ``(token, token_secret)``
or a dictionary of access token response.
:returns: a :class:`requests_oauthlib.oauth1_session.OAuth1Session`
object.
"""
if isinstance(token, dict):
access_token = token['token']
access_token_secret = token['token_secret']
else:
access_token, access_token_secret = token
return self.make_oauth_session(
resource_owner_key=access_token,
resource_owner_secret=access_token_secret)
Expand Down Expand Up @@ -183,7 +236,6 @@ class OAuth2Application(BaseApplication):

session_class = OAuth2Session

endpoint_url = OAuthProperty('endpoint_url', default='')
access_token_url = OAuthProperty('access_token_url')
authorization_url = OAuthProperty('authorization_url')
refresh_token_url = OAuthProperty('refresh_token_url', default='')
Expand All @@ -197,11 +249,13 @@ class OAuth2Application(BaseApplication):
_session_state = WebSessionData('state')
_session_redirect_url = WebSessionData('redir')

@property
def client(self):
token = self.obtain_token()
if token is None:
raise AccessTokenNotFound
def make_client(self, token):
"""Creates a client with specific access token dictionary.

:param token: a dictionary of access token response.
:returns: a :class:`requests_oauthlib.oauth2_session.OAuth2Session`
object.
"""
return self.session_class(self.client_id, token=token)

def tokensaver(self, fn):
Expand Down Expand Up @@ -294,3 +348,17 @@ def insecure_transport(self):
' It may put you in danger of the Man-in-the-middle attack'
' while using OAuth 2.', RuntimeWarning)
yield


def _hash_token(application, token):
"""Creates a hashable object for given token then we could use it as a
dictionary key.
"""
if isinstance(token, dict):
hashed_token = tuple(sorted(token.items()))
elif isinstance(token, tuple):
hashed_token = token
else:
raise TypeError('%r is unknown type of token' % token)

return (application.__class__.__name__, application.name, hashed_token)