Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CORS in WebMessageHandler #106

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 114 additions & 2 deletions brubeck/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,31 @@ def coro_spawn(function, app, message, *a, **kw):
import base64
import hmac
import cPickle as pickle
import functools
from itertools import chain
import os, sys
from dictshield.base import ShieldException
from request import Request, to_bytes, to_unicode

import ujson as json


###
### Decorators
###
def cors(method):
"""Decorate request handler methods with this to allow CORS requests to
use them."""
WebMessageHandler.cors_allow_methods.add(method.__name__.upper())
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
# quit early to avoid processing invalid requests.
if self.cors_request() == False:
return self.render_error(self._FORBIDDEN, self.cors_error)
return method(self, *args, **kwargs)
return wrapper


###
### Common helpers
###
Expand Down Expand Up @@ -356,6 +374,23 @@ class WebMessageHandler(MessageHandler):
"""A base class for common functionality in a request handler.

Tornado's design inspired this design.

This handler defines some attributes that can be set to enable CORS
requests support:

`cors_allow_origin`
The set of hosts allowed to make CORS requests to this handler.
`cors_allow_methods`
The set of HTTP methods that are exposed to CORS requests.
Use the `@cors` decorator on the methods to expose and they'll be
added to this list automatically.
`cors_allow_credentials`
Whether the service accepts cookies and auth set on the request.
`cors_allow_headers`
The set of headers the client is allowed to send in the request.
`cors_expose_headers`
The set of response headers you want the client to have access to.

"""
_DEFAULT_STATUS = 500 # default to server error
_SUCCESS_CODE = 200
Expand Down Expand Up @@ -401,13 +436,90 @@ def set_body(self, body, headers=None, status_code=_SUCCESS_CODE):
self.headers = headers

###
### Supported HTTP request methods are mapped to these functions
### CORS support
###
# We whitelist some simple headers by default
# to work with Webkit browsers implementation.

cors_allow_origin = set()
cors_allow_methods = set()
cors_allow_headers = set(('Accept', 'Authorization', 'Origin'))
cors_expose_headers = set()
cors_allow_credentials = False

def cors_verify_origin(self, origin):
allowed = self.cors_allow_origin
return origin and ('*' in allowed or origin in allowed)

def cors_verify_method(self, request_method):
"""Case-sensitive match of method name."""
return request_method in self.cors_allow_methods

def cors_verify_headers(self, fields):
"""Case-insensitive match of header field names."""
allowed = set(h.lower() for h in self.cors_allow_headers)
return all(f.lower() in allowed for f in fields)

def cors_preflight(self):
"""Handle CORS preflight"""
request_headers = self.message.headers
origin = request_headers.get('Origin')
request_method = request_headers.get('Access-Control-Request-Method')
field_names = request_headers.get('Access-Control-Request-Headers', '')
field_names = [f.strip() for f in field_names.split(',') if field_names]
# validate headers
if (self.cors_verify_origin(origin) and
self.cors_verify_method(request_method) and
self.cors_verify_headers(field_names)):
# set response headers
self.headers['Access-Control-Allow-Origin'] = origin
self.headers['Access-Control-Allow-Methods'] = str.join(', ',
self.cors_allow_methods)
self.headers['Access-Control-Allow-Headers'] = str.join(', ',
self.cors_allow_headers)
if self.cors_allow_credentials:
self.headers['Access-Control-Allow-Credentials'] = 'true'
elif '*' in self.cors_allow_origin:
# only non-credential request allows response with wildcard
self.headers['Access-Control-Allow-Origin'] = '*'
return True

def cors_request(self):
"""Handle CORS request"""
origin = self.message.headers.get('Origin')
# ignore non-CORS requests
if not origin:
return
# handle preflight
if self.message.method.lower() == 'options':
return self.cors_preflight()
# validate origin
if not self.cors_verify_origin(origin):
return False
# set response headers
self.headers['Access-Control-Allow-Origin'] = origin
if self.cors_allow_credentials:
self.headers['Access-Control-Allow-Credentials'] = 'true'
elif '*' in self.cors_allow_origin:
# only non-credential request allows response with wildcard
self.headers['Access-Control-Allow-Origin'] = '*'
if self.cors_expose_headers:
self.headers['Access-Control-Expose-Headers'] = str.join(', ',
self.cors_expose_headers)
return True

def cors_error(self):
"""Handler for incorrect CORS requests"""
self.set_status(403, status_msg='Invalid CORS request')

###
### Supported HTTP request methods are mapped to these functions
###
def options(self, *args, **kwargs):
"""Default to allowing all of the methods you have defined and public
"""
self.headers["Access-Control-Allow-Methods"] = self.supported_methods
methods = str.join(', ', map(str.upper, self.supported_methods))
self.headers['Allow'] = methods
self.set_status(200)
return self.render()

Expand Down