Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add wsgi support to GunicornWebWorker #1418

Merged
merged 2 commits into from
Nov 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions aiohttp/worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Async gunicorn worker for aiohttp.web"""

import asyncio
import logging
import os
import re
import signal
Expand All @@ -12,6 +13,8 @@
from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat

from aiohttp.helpers import AccessLogger, ensure_future
from aiohttp.web_server import WebServer
from aiohttp.wsgi import WSGIServerHttpProtocol

__all__ = ('GunicornWebWorker', 'GunicornUVLoopWebWorker')

Expand All @@ -37,7 +40,8 @@ def init_process(self):
super().init_process()

def run(self):
self.loop.run_until_complete(self.wsgi.startup())
if hasattr(self.wsgi, 'startup'):
self.loop.run_until_complete(self.wsgi.startup())
self._runner = ensure_future(self._run(), loop=self.loop)

try:
Expand All @@ -48,13 +52,16 @@ def run(self):
sys.exit(self.exit_code)

def make_handler(self, app):
return app.make_handler(
logger=self.log,
slow_request_timeout=self.cfg.timeout,
keepalive_timeout=self.cfg.keepalive,
access_log=self.log.access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
if hasattr(self.wsgi, 'make_handler'):
return app.make_handler(
logger=self.log,
slow_request_timeout=self.cfg.timeout,
keepalive_timeout=self.cfg.keepalive,
access_log=self.log.access_log,
access_log_format=self._get_valid_log_format(
self.cfg.access_log_format))
else:
return WSGIServer(self.wsgi, self)

@asyncio.coroutine
def close(self):
Expand All @@ -70,7 +77,8 @@ def close(self):
yield from server.wait_closed()

# send on_shutdown event
yield from self.wsgi.shutdown()
if hasattr(self.wsgi, 'shutdown'):
yield from self.wsgi.shutdown()

# stop alive connections
tasks = [
Expand All @@ -80,7 +88,8 @@ def close(self):
yield from asyncio.gather(*tasks, loop=self.loop)

# cleanup application
yield from self.wsgi.cleanup()
if hasattr(self.wsgi, 'cleanup'):
yield from self.wsgi.cleanup()

@asyncio.coroutine
def _run(self):
Expand Down Expand Up @@ -184,6 +193,26 @@ def _get_valid_log_format(self, source_format):
return source_format


class WSGIServer(WebServer):

def __init__(self, app, worker):
super().__init__(app, loop=worker.loop)

self.worker = worker
self.access_log_format = worker._get_valid_log_format(
worker.cfg.access_log_format)

def __call__(self):
return WSGIServerHttpProtocol(
self.handler, readpayload=True,
loop=self._loop,
logger=self.worker.log,
debug=self.worker.log.loglevel == logging.DEBUG,
keep_alive=self.worker.cfg.keepalive,
access_log=self.worker.log.access_log,
access_log_format=self.access_log_format)


class GunicornUVLoopWebWorker(GunicornWebWorker):

def init_process(self):
Expand Down
47 changes: 47 additions & 0 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ def test_run(worker, loop):
assert loop.is_closed()


def test_run_wsgi(worker, loop):
worker.wsgi = lambda env, start_resp: start_resp()

worker.loop = loop
worker._run = mock.Mock(
wraps=asyncio.coroutine(lambda: None))
with pytest.raises(SystemExit):
worker.run()
assert worker._run.called
assert loop.is_closed()


def test_handle_quit(worker):
worker.handle_quit(object(), object())
assert not worker.alive
Expand Down Expand Up @@ -108,6 +120,20 @@ def test_make_handler(worker, mocker):
assert worker._get_valid_log_format.called


def test_make_handler_wsgi(worker, mocker):
worker.wsgi = lambda env, start_resp: start_resp()
worker.loop = mock.Mock()
worker.log = mock.Mock()
worker.cfg = mock.Mock()
worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT
mocker.spy(worker, '_get_valid_log_format')

f = worker.make_handler(worker.wsgi)
assert isinstance(f, base_worker.WSGIServer)
assert isinstance(f(), base_worker.WSGIServerHttpProtocol)
assert worker._get_valid_log_format.called


@pytest.mark.parametrize('source,result', [
(ACCEPTABLE_LOG_FORMAT, ACCEPTABLE_LOG_FORMAT),
(AsyncioWorker.DEFAULT_GUNICORN_LOG_FORMAT,
Expand Down Expand Up @@ -254,6 +280,27 @@ def test_close(worker, loop):
yield from worker.close()


@asyncio.coroutine
def test_close_wsgi(worker, loop):
srv = mock.Mock()
srv.wait_closed = make_mocked_coro(None)
handler = mock.Mock()
worker.servers = {srv: handler}
worker.log = mock.Mock()
worker.loop = loop
worker.wsgi = lambda env, start_resp: start_resp()
handler.connections = [object()]
handler.shutdown.return_value = helpers.create_future(loop)
handler.shutdown.return_value.set_result(1)

yield from worker.close()
handler.shutdown.assert_called_with(timeout=95.0)
srv.close.assert_called_with()
assert worker.servers is None

yield from worker.close()


@asyncio.coroutine
def test__run_ok_no_max_requests(worker, loop):
worker.ppid = 1
Expand Down