forked from lepture/flask-oauthlib
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
452 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
# coding: utf-8 | ||
|
||
import os | ||
import unittest | ||
from datetime import datetime, timedelta | ||
from flask import Flask | ||
from flask import g, render_template, request, jsonify, make_response | ||
from flask_sqlalchemy import SQLAlchemy | ||
from sqlalchemy.orm import relationship | ||
from flask_oauthlib.provider import OAuth2Provider | ||
from flask_oauthlib.contrib.oauth2 import bind_sqlalchemy | ||
from flask_oauthlib.contrib.oauth2 import bind_cache_grant | ||
|
||
os.environ['OAUTHLIB_INSECURE_TRANSPORT'] = 'true' | ||
|
||
db = SQLAlchemy() | ||
|
||
|
||
class User(db.Model): | ||
id = db.Column(db.Integer, primary_key=True) | ||
username = db.Column(db.String(40), unique=True, index=True, | ||
nullable=False) | ||
|
||
def check_password(self, password): | ||
return True | ||
|
||
|
||
class Client(db.Model): | ||
# id = db.Column(db.Integer, primary_key=True) | ||
# human readable name | ||
name = db.Column(db.String(40)) | ||
client_id = db.Column(db.String(40), primary_key=True) | ||
client_secret = db.Column(db.String(55), unique=True, index=True, | ||
nullable=False) | ||
is_confidential = db.Column(db.Boolean, default=False) | ||
_redirect_uris = db.Column(db.Text) | ||
default_scope = db.Column(db.Text, default='email address') | ||
|
||
@property | ||
def client_type(self): | ||
if self.is_confidential: | ||
return 'confidential' | ||
return 'public' | ||
|
||
@property | ||
def user(self): | ||
return User.query.get(1) | ||
|
||
@property | ||
def redirect_uris(self): | ||
if self._redirect_uris: | ||
return self._redirect_uris.split() | ||
return [] | ||
|
||
@property | ||
def default_redirect_uri(self): | ||
return self.redirect_uris[0] | ||
|
||
@property | ||
def default_scopes(self): | ||
if self.default_scope: | ||
return self.default_scope.split() | ||
return [] | ||
|
||
@property | ||
def allowed_grant_types(self): | ||
return ['authorization_code', 'password', 'client_credentials', | ||
'refresh_token'] | ||
|
||
|
||
class Grant(db.Model): | ||
id = db.Column(db.Integer, primary_key=True) | ||
user_id = db.Column( | ||
db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') | ||
) | ||
user = relationship('User') | ||
|
||
client_id = db.Column( | ||
db.String(40), db.ForeignKey('client.client_id', ondelete='CASCADE'), | ||
nullable=False, | ||
) | ||
client = relationship('Client') | ||
code = db.Column(db.String(255), index=True, nullable=False) | ||
|
||
redirect_uri = db.Column(db.String(255)) | ||
scope = db.Column(db.Text) | ||
expires = db.Column(db.DateTime) | ||
|
||
def delete(self): | ||
db.session.delete(self) | ||
db.session.commit() | ||
return self | ||
|
||
@property | ||
def scopes(self): | ||
if self.scope: | ||
return self.scope.split() | ||
return None | ||
|
||
|
||
class Token(db.Model): | ||
id = db.Column(db.Integer, primary_key=True) | ||
client_id = db.Column( | ||
db.String(40), db.ForeignKey('client.client_id', ondelete='CASCADE'), | ||
nullable=False, | ||
) | ||
user_id = db.Column( | ||
db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') | ||
) | ||
user = relationship('User') | ||
client = relationship('Client') | ||
token_type = db.Column(db.String(40)) | ||
access_token = db.Column(db.String(255)) | ||
refresh_token = db.Column(db.String(255)) | ||
expires = db.Column(db.DateTime) | ||
scope = db.Column(db.Text) | ||
|
||
def __init__(self, **kwargs): | ||
expires_in = kwargs.pop('expires_in') | ||
self.expires = datetime.utcnow() + timedelta(seconds=expires_in) | ||
for k, v in kwargs.items(): | ||
setattr(self, k, v) | ||
|
||
@property | ||
def scopes(self): | ||
if self.scope: | ||
return self.scope.split() | ||
return [] | ||
|
||
def delete(self): | ||
db.session.delete(self) | ||
db.session.commit() | ||
return self | ||
|
||
|
||
def current_user(): | ||
return g.user | ||
|
||
|
||
def cache_provider(app): | ||
oauth = OAuth2Provider(app) | ||
|
||
bind_sqlalchemy(oauth, db.session, user=User, | ||
token=Token, client=Client) | ||
|
||
app.config.update({'OAUTH2_CACHE_TYPE': 'simple'}) | ||
bind_cache_grant(app, oauth, current_user) | ||
return oauth | ||
|
||
|
||
def sqlalchemy_provider(app): | ||
oauth = OAuth2Provider(app) | ||
|
||
bind_sqlalchemy(oauth, db.session, user=User, token=Token, | ||
client=Client, grant=Grant, current_user=current_user) | ||
|
||
return oauth | ||
|
||
|
||
def default_provider(app): | ||
oauth = OAuth2Provider(app) | ||
|
||
@oauth.clientgetter | ||
def get_client(client_id): | ||
return Client.query.filter_by(client_id=client_id).first() | ||
|
||
@oauth.grantgetter | ||
def get_grant(client_id, code): | ||
return Grant.query.filter_by(client_id=client_id, code=code).first() | ||
|
||
@oauth.tokengetter | ||
def get_token(access_token=None, refresh_token=None): | ||
if access_token: | ||
return Token.query.filter_by(access_token=access_token).first() | ||
if refresh_token: | ||
return Token.query.filter_by(refresh_token=refresh_token).first() | ||
return None | ||
|
||
@oauth.grantsetter | ||
def set_grant(client_id, code, request, *args, **kwargs): | ||
expires = datetime.utcnow() + timedelta(seconds=100) | ||
grant = Grant( | ||
client_id=client_id, | ||
code=code['code'], | ||
redirect_uri=request.redirect_uri, | ||
scope=' '.join(request.scopes), | ||
user_id=g.user.id, | ||
expires=expires, | ||
) | ||
db.session.add(grant) | ||
db.session.commit() | ||
|
||
@oauth.tokensetter | ||
def set_token(token, request, *args, **kwargs): | ||
# In real project, a token is unique bound to user and client. | ||
# Which means, you don't need to create a token every time. | ||
tok = Token(**token) | ||
tok.user_id = request.user.id | ||
tok.client_id = request.client.client_id | ||
db.session.add(tok) | ||
db.session.commit() | ||
|
||
@oauth.usergetter | ||
def get_user(username, password, *args, **kwargs): | ||
# This is optional, if you don't need password credential | ||
# there is no need to implement this method | ||
return User.query.filter_by(username=username).first() | ||
|
||
return oauth | ||
|
||
|
||
def create_server(app, oauth=None): | ||
if not oauth: | ||
oauth = default_provider(app) | ||
|
||
@app.before_request | ||
def load_current_user(): | ||
user = User.query.get(1) | ||
g.user = user | ||
|
||
@app.route('/home') | ||
def home(): | ||
return render_template('home.html') | ||
|
||
@app.route('/oauth/authorize', methods=['GET', 'POST']) | ||
@oauth.authorize_handler | ||
def authorize(*args, **kwargs): | ||
# NOTICE: for real project, you need to require login | ||
if request.method == 'GET': | ||
# render a page for user to confirm the authorization | ||
return 'confirm page' | ||
|
||
if request.method == 'HEAD': | ||
# if HEAD is supported properly, request parameters like | ||
# client_id should be validated the same way as for 'GET' | ||
response = make_response('', 200) | ||
response.headers['X-Client-ID'] = kwargs.get('client_id') | ||
return response | ||
|
||
confirm = request.form.get('confirm', 'no') | ||
return confirm == 'yes' | ||
|
||
@app.route('/oauth/token', methods=['POST', 'GET']) | ||
@oauth.token_handler | ||
def access_token(): | ||
return {} | ||
|
||
@app.route('/oauth/revoke', methods=['POST']) | ||
@oauth.revoke_handler | ||
def revoke_token(): | ||
return {} | ||
|
||
@app.route('/api/email') | ||
@oauth.require_oauth('email') | ||
def email_api(): | ||
oauth = request.oauth | ||
return jsonify(email='me@oauth.net', username=oauth.user.username) | ||
|
||
@app.route('/api/client') | ||
@oauth.require_oauth() | ||
def client_api(): | ||
oauth = request.oauth | ||
return jsonify(client=oauth.client.name) | ||
|
||
@app.route('/api/address/<city>') | ||
@oauth.require_oauth('address') | ||
def address_api(city): | ||
oauth = request.oauth | ||
return jsonify(address=city, username=oauth.user.username) | ||
|
||
@app.route('/api/method', methods=['GET', 'POST', 'PUT', 'DELETE']) | ||
@oauth.require_oauth() | ||
def method_api(): | ||
return jsonify(method=request.method) | ||
|
||
@oauth.invalid_response | ||
def require_oauth_invalid(req): | ||
return jsonify(message=req.error_message), 401 | ||
|
||
return app | ||
|
||
|
||
class TestCase(unittest.TestCase): | ||
def setUp(self): | ||
app = self.create_app() | ||
|
||
app.testing = True | ||
self._ctx = app.app_context() | ||
self._ctx.push() | ||
|
||
db.init_app(app) | ||
db.create_all() | ||
|
||
self.app = app | ||
self.client = app.test_client() | ||
self.prepare_data() | ||
|
||
def tearDown(self): | ||
db.drop_all() | ||
self._ctx.pop() | ||
|
||
def prepare_data(self): | ||
return True | ||
|
||
def create_app(self): | ||
app = Flask(__name__) | ||
app.debug = True | ||
app.secret_key = 'testing' | ||
app.config.update({ | ||
'SQLALCHEMY_DATABASE_URI': 'sqlite://' | ||
}) | ||
return app |
Oops, something went wrong.