Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/url_redirections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,11 @@ async def redirect_handler(response, **server):
return ''


@app.route('/redirect')
async def redirect_helper_handler(response, **server):
# or use a helper. it's an instance of `tremolo.exceptions.HTTPRedirect`
raise response.redirect('http://example.com/', code=301)


if __name__ == '__main__':
app.run('0.0.0.0', 8000)
8 changes: 8 additions & 0 deletions tests/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ async def timeouts(request, response):
await asyncio.sleep(10)


@app.route('/redirect')
async def redirect(request, response):
if request.query_string:
raise response.redirect('/new', code=int(request.query_string))

raise response.redirect('/new')


@app.route('/reload')
async def reload(request, **server):
assert server != {}
Expand Down
12 changes: 6 additions & 6 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,24 @@ def test_from_requesttimeout_override_args(self):
self.assertEqual(b.args, ('bar',))

def test_from_methodnotallowed(self):
a = MethodNotAllowed('foo', methods=(b'GET',))
a = MethodNotAllowed('foo', allow=b'GET')
b = HTTPException(cause=a)

self.assertTrue(b is a)
self.assertEqual(b.__class__, MethodNotAllowed)
self.assertEqual(b.__cause__, None)
self.assertEqual(b.args, ('foo',))
self.assertEqual(b.methods, (b'GET',))
self.assertEqual(b.options['allow'], b'GET')

def test_from_methodnotallowed_override_args_methods(self):
a = MethodNotAllowed('foo', methods=(b'GET',))
b = HTTPException('bar', cause=a, methods=(b'POST',))
def test_from_methodnotallowed_override_args_options(self):
a = MethodNotAllowed('foo', allow=b'GET')
b = HTTPException('bar', cause=a, allow=b'POST')

self.assertTrue(b is a)
self.assertEqual(b.__class__, MethodNotAllowed)
self.assertEqual(b.__cause__, None)
self.assertEqual(b.args, ('bar',))
self.assertEqual(b.methods, (b'POST',))
self.assertEqual(b.options['allow'], b'POST')

def test_to_forbidden(self):
a = ValueError('foo')
Expand Down
35 changes: 34 additions & 1 deletion tests/test_http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,40 @@ def test_class_post_methodnotallowed(self):

self.assertEqual(header[:header.find(b'\r\n')],
b'HTTP/1.1 405 Method Not Allowed')
self.assertTrue(b'\r\nAllow: GET' in header)
self.assertTrue(b'\r\nallow: GET' in header)

def test_redirect(self):
header, body = getcontents(host=HTTP_HOST,
port=HTTP_PORT,
method='GET',
url='/redirect?303',
version='1.0')

self.assertEqual(header[:header.find(b'\r\n')],
b'HTTP/1.0 302 Found')
self.assertTrue(b'\r\nlocation: /new' in header)

def test_redirect_301(self):
header, body = getcontents(host=HTTP_HOST,
port=HTTP_PORT,
method='GET',
url='/redirect?301',
version='1.0')

self.assertEqual(header[:header.find(b'\r\n')],
b'HTTP/1.0 301 Moved Permanently')
self.assertTrue(b'\r\nlocation: /new' in header)

def test_redirect_303(self):
header, body = getcontents(host=HTTP_HOST,
port=HTTP_PORT,
method='GET',
url='/redirect?303',
version='1.1')

self.assertEqual(header[:header.find(b'\r\n')],
b'HTTP/1.1 303 See Other')
self.assertTrue(b'\r\nlocation: /new' in header)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions tremolo/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .lib.http_exceptions import ( # noqa: F401
TremoloException,
HTTPException,
HTTPRedirect,
BadRequest,
Unauthorized,
Forbidden,
Expand Down
2 changes: 1 addition & 1 deletion tremolo/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,6 @@ async def request_received(self, request, response):
key = -1

if methods:
raise MethodNotAllowed(methods=methods)
raise MethodNotAllowed(allow=b', '.join(methods))

raise NotFound
21 changes: 11 additions & 10 deletions tremolo/lib/http_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __new__(cls, *args, cause=None, **kwargs):
return super().__new__(cls)

def __init__(self, *args, code=None, message=None, content_type=None,
cause=None):
cause=None, **options):
if isinstance(code, int):
self.code = code

Expand All @@ -45,13 +45,17 @@ def __init__(self, *args, code=None, message=None, content_type=None,
self.content_type = content_type

if isinstance(cause, Exception):
if cause is not self:
if cause is self:
if not options:
options = self.options
else:
self.__cause__ = cause

if cause.args and not args:
args = cause.args

self.args = args
self.options = options

@property
def encoding(self):
Expand All @@ -62,6 +66,11 @@ def encoding(self):
return 'utf-8'


class HTTPRedirect(HTTPException):
code = 302
message = 'Found'


class BadRequest(HTTPException):
code = 400
message = 'Bad Request'
Expand All @@ -86,14 +95,6 @@ class MethodNotAllowed(HTTPException):
code = 405
message = 'Method Not Allowed'

def __init__(self, *args, methods=(), **kwargs):
super().__init__(*args, **kwargs)

if kwargs.get('cause') is self and not methods:
methods = self.methods

self.methods = methods


class RequestTimeout(HTTPException):
code = 408
Expand Down
20 changes: 15 additions & 5 deletions tremolo/lib/http_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from tremolo.utils import parse_fields
from .http_exceptions import (
HTTPException,
HTTPRedirect,
BadRequest,
InternalServerError,
MethodNotAllowed,
RangeNotSatisfiable,
WebSocketException,
WebSocketServerClosed
Expand Down Expand Up @@ -429,6 +429,16 @@ def run_sync(func, *args):

self.close(keepalive=True)

def redirect(self, url, code=None):
if code == 301:
message = 'Moved Permanently'
elif code == 303 and self.request.version != b'1.0':
message = 'See Other'
else:
return HTTPRedirect('', location=url)

return HTTPRedirect('', code=code, message=message, location=url)

async def handle_exception(self, exc, *args, data=None):
if self.request.protocol is None or self.request.transport is None:
return
Expand All @@ -437,7 +447,7 @@ async def handle_exception(self, exc, *args, data=None):
self.request.transport.abort()
return

if not isinstance(exc, asyncio.CancelledError):
if not isinstance(exc, (asyncio.CancelledError, HTTPRedirect)):
self.request.protocol.print_exception(
exc, *args, quote(unquote_to_bytes(self.request.path))
)
Expand Down Expand Up @@ -467,10 +477,10 @@ async def handle_exception(self, exc, *args, data=None):

data = str(exc)

if isinstance(exc, MethodNotAllowed):
self.set_header(b'Allow', b', '.join(exc.methods))
for name, value in exc.options.items():
self.set_header(name.replace('_', '-'), value)

if isinstance(data, str):
data = data.encode(exc.encoding)

await self.end(data, keepalive=False)
await self.end(data, keepalive=exc.code < 400)