Skip to content

Commit

Permalink
Refactor testing providers
Browse files Browse the repository at this point in the history
  • Loading branch information
lepture committed Apr 2, 2015
1 parent e82c0b4 commit 4914e46
Show file tree
Hide file tree
Showing 3 changed files with 452 additions and 0 deletions.
Empty file added tests/test_oauth2/__init__.py
Empty file.
312 changes: 312 additions & 0 deletions tests/test_oauth2/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
# coding: utf-8

import os
import unittest
from datetime import datetime, timedelta
from flask import Flask
from flask import g, render_template, request, jsonify, make_response
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import relationship
from flask_oauthlib.provider import OAuth2Provider
from flask_oauthlib.contrib.oauth2 import bind_sqlalchemy
from flask_oauthlib.contrib.oauth2 import bind_cache_grant

os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = 'true'

db = SQLAlchemy()


class User(db.Model):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(40), unique=True, index=True,
nullable=False)

def check_password(self, password):
return True


class Client(db.Model):
# id = db.Column(db.Integer, primary_key=True)
# human readable name
name = db.Column(db.String(40))
client_id = db.Column(db.String(40), primary_key=True)
client_secret = db.Column(db.String(55), unique=True, index=True,
nullable=False)
is_confidential = db.Column(db.Boolean, default=False)
_redirect_uris = db.Column(db.Text)
default_scope = db.Column(db.Text, default='email address')

@property
def client_type(self):
if self.is_confidential:
return 'confidential'
return 'public'

@property
def user(self):
return User.query.get(1)

@property
def redirect_uris(self):
if self._redirect_uris:
return self._redirect_uris.split()
return []

@property
def default_redirect_uri(self):
return self.redirect_uris[0]

@property
def default_scopes(self):
if self.default_scope:
return self.default_scope.split()
return []

@property
def allowed_grant_types(self):
return ['authorization_code', 'password', 'client_credentials',
'refresh_token']


class Grant(db.Model):
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(
db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
)
user = relationship('User')

client_id = db.Column(
db.String(40), db.ForeignKey('client.client_id', ondelete='CASCADE'),
nullable=False,
)
client = relationship('Client')
code = db.Column(db.String(255), index=True, nullable=False)

redirect_uri = db.Column(db.String(255))
scope = db.Column(db.Text)
expires = db.Column(db.DateTime)

def delete(self):
db.session.delete(self)
db.session.commit()
return self

@property
def scopes(self):
if self.scope:
return self.scope.split()
return None


class Token(db.Model):
id = db.Column(db.Integer, primary_key=True)
client_id = db.Column(
db.String(40), db.ForeignKey('client.client_id', ondelete='CASCADE'),
nullable=False,
)
user_id = db.Column(
db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
)
user = relationship('User')
client = relationship('Client')
token_type = db.Column(db.String(40))
access_token = db.Column(db.String(255))
refresh_token = db.Column(db.String(255))
expires = db.Column(db.DateTime)
scope = db.Column(db.Text)

def __init__(self, **kwargs):
expires_in = kwargs.pop('expires_in')
self.expires = datetime.utcnow() + timedelta(seconds=expires_in)
for k, v in kwargs.items():
setattr(self, k, v)

@property
def scopes(self):
if self.scope:
return self.scope.split()
return []

def delete(self):
db.session.delete(self)
db.session.commit()
return self


def current_user():
return g.user


def cache_provider(app):
oauth = OAuth2Provider(app)

bind_sqlalchemy(oauth, db.session, user=User,
token=Token, client=Client)

app.config.update({'OAUTH2_CACHE_TYPE': 'simple'})
bind_cache_grant(app, oauth, current_user)
return oauth


def sqlalchemy_provider(app):
oauth = OAuth2Provider(app)

bind_sqlalchemy(oauth, db.session, user=User, token=Token,
client=Client, grant=Grant, current_user=current_user)

return oauth


def default_provider(app):
oauth = OAuth2Provider(app)

@oauth.clientgetter
def get_client(client_id):
return Client.query.filter_by(client_id=client_id).first()

@oauth.grantgetter
def get_grant(client_id, code):
return Grant.query.filter_by(client_id=client_id, code=code).first()

@oauth.tokengetter
def get_token(access_token=None, refresh_token=None):
if access_token:
return Token.query.filter_by(access_token=access_token).first()
if refresh_token:
return Token.query.filter_by(refresh_token=refresh_token).first()
return None

@oauth.grantsetter
def set_grant(client_id, code, request, *args, **kwargs):
expires = datetime.utcnow() + timedelta(seconds=100)
grant = Grant(
client_id=client_id,
code=code['code'],
redirect_uri=request.redirect_uri,
scope=' '.join(request.scopes),
user_id=g.user.id,
expires=expires,
)
db.session.add(grant)
db.session.commit()

@oauth.tokensetter
def set_token(token, request, *args, **kwargs):
# In real project, a token is unique bound to user and client.
# Which means, you don't need to create a token every time.
tok = Token(**token)
tok.user_id = request.user.id
tok.client_id = request.client.client_id
db.session.add(tok)
db.session.commit()

@oauth.usergetter
def get_user(username, password, *args, **kwargs):
# This is optional, if you don't need password credential
# there is no need to implement this method
return User.query.filter_by(username=username).first()

return oauth


def create_server(app, oauth=None):
if not oauth:
oauth = default_provider(app)

@app.before_request
def load_current_user():
user = User.query.get(1)
g.user = user

@app.route('/home')
def home():
return render_template('home.html')

@app.route('/oauth/authorize', methods=['GET', 'POST'])
@oauth.authorize_handler
def authorize(*args, **kwargs):
# NOTICE: for real project, you need to require login
if request.method == 'GET':
# render a page for user to confirm the authorization
return 'confirm page'

if request.method == 'HEAD':
# if HEAD is supported properly, request parameters like
# client_id should be validated the same way as for 'GET'
response = make_response('', 200)
response.headers['X-Client-ID'] = kwargs.get('client_id')
return response

confirm = request.form.get('confirm', 'no')
return confirm == 'yes'

@app.route('/oauth/token', methods=['POST', 'GET'])
@oauth.token_handler
def access_token():
return {}

@app.route('/oauth/revoke', methods=['POST'])
@oauth.revoke_handler
def revoke_token():
return {}

@app.route('/api/email')
@oauth.require_oauth('email')
def email_api():
oauth = request.oauth
return jsonify(email='me@oauth.net', username=oauth.user.username)

@app.route('/api/client')
@oauth.require_oauth()
def client_api():
oauth = request.oauth
return jsonify(client=oauth.client.name)

@app.route('/api/address/<city>')
@oauth.require_oauth('address')
def address_api(city):
oauth = request.oauth
return jsonify(address=city, username=oauth.user.username)

@app.route('/api/method', methods=['GET', 'POST', 'PUT', 'DELETE'])
@oauth.require_oauth()
def method_api():
return jsonify(method=request.method)

@oauth.invalid_response
def require_oauth_invalid(req):
return jsonify(message=req.error_message), 401

return app


class TestCase(unittest.TestCase):
def setUp(self):
app = self.create_app()

app.testing = True
self._ctx = app.app_context()
self._ctx.push()

db.init_app(app)
db.create_all()

self.app = app
self.client = app.test_client()
self.prepare_data()

def tearDown(self):
db.drop_all()
self._ctx.pop()

def prepare_data(self):
return True

def create_app(self):
app = Flask(__name__)
app.debug = True
app.secret_key = 'testing'
app.config.update({
'SQLALCHEMY_DATABASE_URI': 'sqlite://'
})
return app
Loading

0 comments on commit 4914e46

Please sign in to comment.