diff --git a/aiohttp/client.py b/aiohttp/client.py index f76c619673f..9bdd51af39f 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -18,7 +18,7 @@ from .client_reqrep import ClientRequest, ClientResponse from .client_ws import ClientWebSocketResponse from .errors import WSServerHandshakeError -from .helpers import CookieJar +from .helpers import CookieJar, Timeout __all__ = ('ClientSession', 'request', 'get', 'options', 'head', 'delete', 'post', 'put', 'patch', 'ws_connect') @@ -106,7 +106,8 @@ def request(self, method, url, *, expect100=False, read_until_eof=True, proxy=None, - proxy_auth=None): + proxy_auth=None, + timeout=5*60): """Perform HTTP request.""" return _RequestContextManager( @@ -127,7 +128,8 @@ def request(self, method, url, *, expect100=expect100, read_until_eof=read_until_eof, proxy=proxy, - proxy_auth=proxy_auth,)) + proxy_auth=proxy_auth, + timeout=timeout)) @asyncio.coroutine def _request(self, method, url, *, @@ -145,7 +147,8 @@ def _request(self, method, url, *, expect100=False, read_until_eof=True, proxy=None, - proxy_auth=None): + proxy_auth=None, + timeout=5*60): if version is not None: warnings.warn("HTTP version should be specified " @@ -187,9 +190,10 @@ def _request(self, method, url, *, auth=auth, version=version, compress=compress, chunked=chunked, expect100=expect100, loop=self._loop, response_class=self._response_class, - proxy=proxy, proxy_auth=proxy_auth,) + proxy=proxy, proxy_auth=proxy_auth, timeout=timeout) - conn = yield from self._connector.connect(req) + with Timeout(timeout, loop=self._loop): + conn = yield from self._connector.connect(req) try: resp = req.send(conn.writer, conn.reader) try: diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 3c5a12c9099..57cb4569593 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -15,6 +15,7 @@ import aiohttp from . import hdrs, helpers, streams +from .helpers import Timeout from .log import client_logger from .multipart import MultipartWriter from .protocol import HttpMessage @@ -68,7 +69,8 @@ def __init__(self, method, url, *, version=aiohttp.HttpVersion11, compress=None, chunked=None, expect100=False, loop=None, response_class=None, - proxy=None, proxy_auth=None): + proxy=None, proxy_auth=None, + timeout=5*60): if loop is None: loop = asyncio.get_event_loop() @@ -80,6 +82,7 @@ def __init__(self, method, url, *, self.compress = compress self.loop = loop self.response_class = response_class or ClientResponse + self._timeout = timeout if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) @@ -502,7 +505,8 @@ def send(self, writer, reader): self.response = self.response_class( self.method, self.url, self.host, - writer=self._writer, continue100=self._continue) + writer=self._writer, continue100=self._continue, + timeout=self._timeout) self.response._post_init(self.loop) return self.response @@ -546,7 +550,8 @@ class ClientResponse: _loop = None _closed = True # to allow __del__ for non-initialized properly response - def __init__(self, method, url, host='', *, writer=None, continue100=None): + def __init__(self, method, url, host='', *, writer=None, continue100=None, + timeout=5*60): super().__init__() self.method = method @@ -558,6 +563,7 @@ def __init__(self, method, url, host='', *, writer=None, continue100=None): self._closed = False self._should_close = True # override by message.should_close later self._history = () + self._timeout = timeout def _post_init(self, loop): self._loop = loop @@ -609,7 +615,7 @@ def _setup_connection(self, connection): self._reader = connection.reader self._connection = connection self.content = self.flow_control_class( - connection.reader, loop=connection.loop) + connection.reader, loop=connection.loop, timeout=self._timeout) def _need_parse_response_body(self): return (self.method.lower() != 'head' and @@ -624,7 +630,8 @@ def start(self, connection, read_until_eof=False): httpstream = self._reader.set_parser(self._response_parser) # read response - message = yield from httpstream.read() + with Timeout(self._timeout, loop=self._loop): + message = yield from httpstream.read() if message.code != 100: break @@ -643,11 +650,11 @@ def start(self, connection, read_until_eof=False): self.raw_headers = tuple(message.raw_headers) # payload - response_with_body = self._need_parse_response_body() + rwb = self._need_parse_response_body() self._reader.set_parser( aiohttp.HttpPayloadParser(message, readall=read_until_eof, - response_with_body=response_with_body), + response_with_body=rwb), self.content) # cookies diff --git a/aiohttp/streams.py b/aiohttp/streams.py index d38eff4fc92..c5e6b395d3c 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -83,7 +83,7 @@ class StreamReader(asyncio.StreamReader, AsyncStreamReaderMixin): total_bytes = 0 - def __init__(self, limit=DEFAULT_LIMIT, loop=None): + def __init__(self, limit=DEFAULT_LIMIT, timeout=None, loop=None): self._limit = limit if loop is None: loop = asyncio.get_event_loop() @@ -93,8 +93,10 @@ def __init__(self, limit=DEFAULT_LIMIT, loop=None): self._buffer_offset = 0 self._eof = False self._waiter = None + self._canceller = None self._eof_waiter = None self._exception = None + self._timeout = timeout def __repr__(self): info = ['StreamReader'] @@ -122,6 +124,11 @@ def set_exception(self, exc): if not waiter.cancelled(): waiter.set_exception(exc) + canceller = self._canceller + if canceller is not None: + self._canceller = None + canceller.cancel() + def feed_eof(self): self._eof = True @@ -131,6 +138,11 @@ def feed_eof(self): if not waiter.cancelled(): waiter.set_result(True) + canceller = self._canceller + if canceller is not None: + self._canceller = None + canceller.cancel() + waiter = self._eof_waiter if waiter is not None: self._eof_waiter = None @@ -185,7 +197,13 @@ def feed_data(self, data): if not waiter.cancelled(): waiter.set_result(False) - def _create_waiter(self, func_name): + canceller = self._canceller + if canceller is not None: + self._canceller = None + canceller.cancel() + + @asyncio.coroutine + def _wait(self, func_name): # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not possible to know @@ -193,7 +211,18 @@ def _create_waiter(self, func_name): if self._waiter is not None: raise RuntimeError('%s() called while another coroutine is ' 'already waiting for incoming data' % func_name) - return helpers.create_future(self._loop) + waiter = self._waiter = helpers.create_future(self._loop) + if self._timeout: + self._canceller = self._loop.call_later(self._timeout, + self.set_exception, + asyncio.TimeoutError()) + try: + yield from waiter + finally: + self._waiter = None + if self._canceller is not None: + self._canceller.cancel() + self._canceller = None @asyncio.coroutine def readline(self): @@ -222,11 +251,7 @@ def readline(self): break if not_enough: - self._waiter = self._create_waiter('readline') - try: - yield from self._waiter - finally: - self._waiter = None + yield from self._wait('readline') return b''.join(line) @@ -265,11 +290,7 @@ def read(self, n=-1): return b''.join(blocks) if not self._buffer and not self._eof: - self._waiter = self._create_waiter('read') - try: - yield from self._waiter - finally: - self._waiter = None + yield from self._wait('read') return self._read_nowait(n) @@ -279,11 +300,7 @@ def readany(self): raise self._exception if not self._buffer and not self._eof: - self._waiter = self._create_waiter('readany') - try: - yield from self._waiter - finally: - self._waiter = None + yield from self._wait('readany') return self._read_nowait() diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 0cba3ba6f2e..bd426488b0c 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -474,3 +474,37 @@ def handler(request): resp = yield from client.delete('/') assert resp.status == 204 yield from resp.release() + + +@pytest.mark.run_loop +def test_timeout_on_reading_headers(create_app_and_client, loop): + + @asyncio.coroutine + def handler(request): + resp = web.StreamResponse() + yield from asyncio.sleep(0.1, loop=loop) + yield from resp.prepare(request) + return resp + + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) + with pytest.raises(asyncio.TimeoutError): + yield from client.get('/', timeout=0.01) + + +@pytest.mark.run_loop +def test_timeout_on_reading_data(create_app_and_client, loop): + + @asyncio.coroutine + def handler(request): + resp = web.StreamResponse() + yield from resp.prepare(request) + yield from asyncio.sleep(0.1, loop=loop) + return resp + + app, client = yield from create_app_and_client() + app.router.add_route('GET', '/', handler) + resp = yield from client.get('/', timeout=0.05) + + with pytest.raises(asyncio.TimeoutError): + yield from resp.read() diff --git a/tests/test_client_response.py b/tests/test_client_response.py index a03a7134eee..c60cb469fb6 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -244,11 +244,11 @@ def side_effect(*args, **kwargs): def test_override_flow_control(self): class MyResponse(ClientResponse): - flow_control_class = aiohttp.FlowControlDataQueue + flow_control_class = aiohttp.StreamReader response = MyResponse('get', 'http://my-cl-resp.org') response._post_init(self.loop) response._setup_connection(self.connection) - self.assertIsInstance(response.content, aiohttp.FlowControlDataQueue) + self.assertIsInstance(response.content, aiohttp.StreamReader) response.close() @mock.patch('aiohttp.client_reqrep.chardet') diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 5891bbe8dc1..deb22011088 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -363,7 +363,9 @@ def create_connection(req): assert e.strerror == err.strerror +@pytest.mark.run_loop def test_request_ctx_manager_props(loop): + yield from asyncio.sleep(0, loop=loop) # to make it a task with aiohttp.ClientSession(loop=loop) as client: ctx_mgr = client.get('http://example.com') diff --git a/tests/test_streams.py b/tests/test_streams.py index e401248c171..dbee463cf6f 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -24,7 +24,8 @@ def _make_one(self, *args, **kwargs): def test_create_waiter(self): stream = self._make_one() stream._waiter = helpers.create_future(self.loop) - self.assertRaises(RuntimeError, stream._create_waiter, 'test') + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(stream._wait('test')) @mock.patch('aiohttp.streams.asyncio') def test_ctor_global_loop(self, m_asyncio): @@ -518,7 +519,7 @@ def test_read_nowait_exception(self): def test_read_nowait_waiter(self): stream = self._make_one() stream.feed_data(b'line\n') - stream._waiter = stream._create_waiter('readany') + stream._waiter = helpers.create_future(self.loop) self.assertRaises(RuntimeError, stream.read_nowait) @@ -548,7 +549,7 @@ def test___repr__exception(self): def test___repr__waiter(self): stream = self._make_one() - stream._waiter = stream._create_waiter('test_waiter') + stream._waiter = helpers.create_future(self.loop) self.assertRegex( repr(stream), ">") @@ -557,6 +558,66 @@ def test___repr__waiter(self): stream._waiter = None self.assertEqual("", repr(stream)) + def test_unread_empty(self): + stream = self._make_one() + stream.feed_data(b'line1') + stream.feed_eof() + stream.unread_data(b'') + + data = self.loop.run_until_complete(stream.read(5)) + self.assertEqual(b'line1', data) + self.assertTrue(stream.at_eof()) + + def test_set_exception_cancels_timeout(self): + stream = self._make_one(timeout=1) + task = helpers.ensure_future(stream.readany(), loop=self.loop) + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + + self.assertIsNotNone(stream._canceller) + canceller = stream._canceller = mock.Mock() + stream.set_exception(ValueError()) + self.assertIsNone(stream._canceller) + canceller.cancel.assert_called_with() + self.assertRaises( + ValueError, self.loop.run_until_complete, task) + + def test_feed_eof_cancels_timeout(self): + stream = self._make_one(timeout=1) + task = helpers.ensure_future(stream.readany(), loop=self.loop) + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + + self.assertIsNotNone(stream._canceller) + canceller = stream._canceller = mock.Mock() + stream.feed_eof() + self.assertIsNone(stream._canceller) + canceller.cancel.assert_called_with() + self.assertEqual(b'', self.loop.run_until_complete(task)) + + def test_feed_data_cancels_timeout(self): + stream = self._make_one(timeout=1) + task = helpers.ensure_future(stream.readany(), loop=self.loop) + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + + self.assertIsNotNone(stream._canceller) + canceller = stream._canceller = mock.Mock() + stream.feed_data(b'data') + self.assertIsNone(stream._canceller) + canceller.cancel.assert_called_with() + self.assertEqual(b'data', self.loop.run_until_complete(task)) + + def test_wait_cancels_timeout(self): + # Read bytes. + stream = self._make_one(timeout=1) + task = helpers.ensure_future(stream._wait('test'), loop=self.loop) + self.loop.run_until_complete(asyncio.sleep(0, loop=self.loop)) + + self.assertIsNotNone(stream._canceller) + canceller = stream._canceller = mock.Mock() + stream._waiter.set_result(None) + self.loop.run_until_complete(task) + self.assertIsNone(stream._canceller) + canceller.cancel.assert_called_with() + class TestEmptyStreamReader(unittest.TestCase): diff --git a/tests/test_timeout.py b/tests/test_timeout.py index 70b6205c370..a4a0c3af44a 100644 --- a/tests/test_timeout.py +++ b/tests/test_timeout.py @@ -63,7 +63,7 @@ def long_running_task(): resp = yield from long_running_task() assert resp == 'done' dt = loop.time() - t0 - assert 0.09 < dt < 0.12, dt + assert 0.09 < dt < 0.13, dt @pytest.mark.run_loop