Skip to content

Commit

Permalink
ServerHttpProtocol refactoring (#1060)
Browse files Browse the repository at this point in the history
* Drop _slow_request_timeout_handle

* Drop unused method

* Refactor ServerHttpProtocol.start()

* Drop always true conditions

* Replace a test with more obvious version

* Fix docstring

* Rename private attrs

* Refactor finalizing code
  • Loading branch information
asvetlov authored Aug 19, 2016
1 parent 4599080 commit 81c00e4
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 277 deletions.
6 changes: 1 addition & 5 deletions aiohttp/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,7 @@ def __init__(self, line=''):
self.line = line


class ParserError(Exception):
"""Base parser error."""


class LineLimitExceededParserError(ParserError):
class LineLimitExceededParserError(HttpBadRequest):
"""Line is too long."""

def __init__(self, msg, limit):
Expand Down
214 changes: 75 additions & 139 deletions aiohttp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import socket
import traceback
import warnings
from contextlib import suppress
from html import escape as html_escape
from math import ceil

import aiohttp
from aiohttp import errors, hdrs, helpers, streams
from aiohttp.helpers import _get_kwarg, ensure_future
from aiohttp.helpers import Timeout, _get_kwarg, ensure_future
from aiohttp.log import access_logger, server_logger

__all__ = ('ServerHttpProtocol',)
Expand Down Expand Up @@ -53,15 +53,11 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive: int or None
:type keepalive_timeout: int or None
:param bool tcp_keepalive: TCP keep-alive is on, default is on
:param int timeout: slow request timeout
:param allowed_methods: (optional) List of allowed request methods.
Set to empty list to allow all methods.
:type allowed_methods: tuple
:param int slow_request_timeout: slow request timeout
:param bool debug: enable debug mode
Expand All @@ -85,9 +81,7 @@ class ServerHttpProtocol(aiohttp.StreamProtocol):
_request_count = 0
_request_handler = None
_reading_request = False
_keep_alive = False # keep transport open
_keep_alive_handle = None # keep alive timer handle
_slow_request_timeout_handle = None # slow request timer handle
_keepalive = False # keep transport open

def __init__(self, *, loop=None,
keepalive_timeout=75, # NGINX default value is 75 secs
Expand Down Expand Up @@ -138,6 +132,7 @@ def __init__(self, *, loop=None,
access_log_format)
else:
self.access_logger = None
self._closing = False

@property
def keep_alive_timeout(self):
Expand All @@ -150,57 +145,38 @@ def keep_alive_timeout(self):
def keepalive_timeout(self):
return self._keepalive_timeout

def closing(self, timeout=15.0):
@asyncio.coroutine
def shutdown(self, timeout=15.0):
"""Worker process is about to exit, we need cleanup everything and
stop accepting requests. It is especially important for keep-alive
connections."""
self._keep_alive = False
self._tcp_keep_alive = False
self._keepalive_timeout = None

if (not self._reading_request and self.transport is not None):
if self._request_handler:
self._request_handler.cancel()
self._request_handler = None

self.transport.close()
self.transport = None
elif self.transport is not None and timeout:
if self._slow_request_timeout_handle is not None:
self._slow_request_timeout_handle.cancel()

# use slow request timeout for closing
# connection_lost cleans timeout handler
now = self._loop.time()
self._slow_request_timeout_handle = self._loop.call_at(
ceil(now+timeout), self.cancel_slow_request)
if self._request_handler is None:
return
self._closing = True

if timeout:
canceller = self._loop.call_later(timeout,
self._request_handler.cancel)
with suppress(asyncio.CancelledError):
yield from self._request_handler
canceller.cancel()
else:
self._request_handler.cancel()

def connection_made(self, transport):
super().connection_made(transport)

self._request_handler = ensure_future(self.start(), loop=self._loop)

# start slow request timer
if self._slow_request_timeout:
now = self._loop.time()
self._slow_request_timeout_handle = self._loop.call_at(
ceil(now+self._slow_request_timeout), self.cancel_slow_request)

if self._tcp_keepalive:
tcp_keepalive(self, transport)

def connection_lost(self, exc):
super().connection_lost(exc)

self._closing = True
if self._request_handler is not None:
self._request_handler.cancel()
self._request_handler = None
if self._keep_alive_handle is not None:
self._keep_alive_handle.cancel()
self._keep_alive_handle = None
if self._slow_request_timeout_handle is not None:
self._slow_request_timeout_handle.cancel()
self._slow_request_timeout_handle = None

def data_received(self, data):
super().data_received(data)
Expand All @@ -209,17 +185,12 @@ def data_received(self, data):
if not self._reading_request:
self._reading_request = True

# stop keep-alive timer
if self._keep_alive_handle is not None:
self._keep_alive_handle.cancel()
self._keep_alive_handle = None

def keep_alive(self, val):
"""Set keep-alive connection mode.
:param bool val: new state.
"""
self._keep_alive = val
self._keepalive = val

def log_access(self, message, environ, response, time):
if self.access_logger:
Expand All @@ -233,16 +204,6 @@ def log_debug(self, *args, **kw):
def log_exception(self, *args, **kw):
self.logger.exception(*args, **kw)

def cancel_slow_request(self):
if self._request_handler is not None:
self._request_handler.cancel()
self._request_handler = None

if self.transport is not None:
self.transport.close()

self.log_debug('Close slow request.')

@asyncio.coroutine
def start(self):
"""Start processing of incoming requests.
Expand All @@ -255,44 +216,35 @@ def start(self):
"""
reader = self.reader

while True:
message = None
self._keep_alive = False
self._request_count += 1
self._reading_request = False

payload = None
try:
# read HTTP request method
prefix = reader.set_parser(self._request_prefix)
yield from prefix.read()

# start reading request
self._reading_request = True

# start slow request timer
if (self._slow_request_timeout and
self._slow_request_timeout_handle is None):
now = self._loop.time()
self._slow_request_timeout_handle = self._loop.call_at(
ceil(now+self._slow_request_timeout),
self.cancel_slow_request)

# read request headers
httpstream = reader.set_parser(self._request_parser)
message = yield from httpstream.read()

# cancel slow request timer
if self._slow_request_timeout_handle is not None:
self._slow_request_timeout_handle.cancel()
self._slow_request_timeout_handle = None
try:
while not self._closing:
message = None
self._keepalive = False
self._request_count += 1
self._reading_request = False

payload = None
with Timeout(max(self._slow_request_timeout,
self._keepalive_timeout),
loop=self._loop):
# read HTTP request method
prefix = reader.set_parser(self._request_prefix)
yield from prefix.read()

# start reading request
self._reading_request = True

# start slow request timer
# read request headers
httpstream = reader.set_parser(self._request_parser)
message = yield from httpstream.read()

# request may not have payload
try:
content_length = int(
message.headers.get(hdrs.CONTENT_LENGTH, 0))
except ValueError:
content_length = 0
raise errors.InvalidHeader(hdrs.CONTENT_LENGTH) from None

if (content_length > 0 or
message.method == 'CONNECT' or
Expand All @@ -308,55 +260,39 @@ def start(self):

yield from self.handle_request(message, payload)

except asyncio.CancelledError:
return
except errors.ClientDisconnectedError:
self.log_debug(
'Ignored premature client disconnection #1.')
return
except errors.HttpProcessingError as exc:
if self.transport is not None:
yield from self.handle_error(exc.code, message,
None, exc, exc.headers,
exc.message)
except errors.LineLimitExceededParserError as exc:
yield from self.handle_error(400, message, None, exc)
except Exception as exc:
yield from self.handle_error(500, message, None, exc)
finally:
if self.transport is None:
self.log_debug(
'Ignored premature client disconnection #2.')
return

if payload and not payload.is_eof():
self.log_debug('Uncompleted request.')
self._request_handler = None
self.transport.close()
return
self._closing = True
else:
reader.unset_parser()

if self._request_handler:
if self._keep_alive and self._keepalive_timeout:
self.log_debug(
'Start keep-alive timer for %s sec.',
self._keepalive_timeout)
now = self._loop.time()
self._keep_alive_handle = self._loop.call_at(
ceil(now+self._keepalive_timeout),
self.transport.close)
elif self._keep_alive:
# do nothing, rely on kernel or upstream server
pass
else:
self.log_debug('Close client connection.')
self._request_handler = None
self.transport.close()
return
else:
# connection is closed
return
if not self._keepalive or not self._keepalive_timeout:
self._closing = True

except asyncio.CancelledError:
self.log_debug(
'Request handler cancelled.')
return
except asyncio.TimeoutError:
self.log_debug(
'Request handler timed out.')
return
except errors.ClientDisconnectedError:
self.log_debug(
'Ignored premature client disconnection #1.')
return
except errors.HttpProcessingError as exc:
yield from self.handle_error(exc.code, message,
None, exc, exc.headers,
exc.message)
except Exception as exc:
yield from self.handle_error(500, message, None, exc)
finally:
self._request_handler = None
if self.transport is None:
self.log_debug(
'Ignored premature client disconnection #2.')
else:
self.transport.close()

def handle_error(self, status=500, message=None,
payload=None, exc=None, headers=None, reason=None):
Expand All @@ -366,7 +302,7 @@ def handle_error(self, status=500, message=None,
information. It always closes current connection."""
now = self._loop.time()
try:
if self._request_handler is None:
if self.transport is None:
# client has been disconnected during writing.
return ()

Expand Down
30 changes: 2 additions & 28 deletions aiohttp/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,10 @@ def connection_lost(self, handler, exc=None):
if handler in self._connections:
del self._connections[handler]

@asyncio.coroutine
def _connections_cleanup(self):
sleep = 0.05
while self._connections:
yield from asyncio.sleep(sleep, loop=self._loop)
if sleep < 5:
sleep = sleep * 2

@asyncio.coroutine
def finish_connections(self, timeout=None):
# try to close connections in 90% of graceful timeout
timeout90 = None
if timeout:
timeout90 = timeout / 100 * 90

for handler in self._connections.keys():
handler.closing(timeout=timeout90)

if timeout:
try:
yield from asyncio.wait_for(
self._connections_cleanup(), timeout, loop=self._loop)
except asyncio.TimeoutError:
self._app.logger.warning(
"Not all connections are closed (pending: %d)",
len(self._connections))

for transport in self._connections.values():
transport.close()

coros = [conn.shutdown(timeout) for conn in self._connections]
yield from asyncio.gather(*coros, loop=self._loop)
self._connections.clear()

def __call__(self):
Expand Down
Loading

0 comments on commit 81c00e4

Please sign in to comment.