Skip to content

Commit

Permalink
Fixes socket timeout on WS connection not respecting ws_connect's tim…
Browse files Browse the repository at this point in the history
…eouts

Added read_timeout property to ResponseHandler to allow override

After WS(S) connection is established, adjust `conn.proto.read_timeout` to
be the largest of the `read_timeout` and the `ws_connect`'s
`timeout` or `receive_timeout`, whichever are specified.

fixes aio-libs#8444
  • Loading branch information
arcivanov committed Jun 8, 2024
1 parent f662958 commit 43b6adc
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES/8444.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``ws_connect`` not respecting ``ws_receive`` timeout for WS(S) connection.
-- by :user:`arcivanov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Anes Abismail
Antoine Pietri
Anton Kasyanov
Anton Zhdan-Pushkin
Arcadiy Ivanov
Arie Bovenberg
Arseny Timoniq
Artem Yushkovskiy
Expand Down
11 changes: 11 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,17 @@ async def _ws_connect(
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None

# For WS connection the read_timeout must be either ws_timeout.ws_receive or greater
# None == no timeout, i.e. infinite timeout, so None is the max timeout possible
if ws_timeout.ws_receive is None:
# Reset regardless
conn_proto.read_timeout = None
elif conn_proto.read_timeout is not None:
conn_proto.read_timeout = max(
ws_timeout.ws_receive, conn_proto.read_timeout
)

transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
Expand Down
8 changes: 8 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def _reschedule_timeout(self) -> None:
def start_timeout(self) -> None:
self._reschedule_timeout()

@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout

@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout

def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
Expand Down
104 changes: 104 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import aiohttp
from aiohttp import client, hdrs
from aiohttp.client_ws import ClientWSTimeout
from aiohttp.http import WS_KEY
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -39,6 +40,7 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -54,6 +56,91 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]


async def test_ws_connect_read_timeout_is_reset_to_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org", protocols=("t1", "t2", "chat")
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_stays_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
timeout=ClientWSTimeout(0.5),
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_reset_to_max(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
timeout=ClientWSTimeout(1.0),
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout == 1.0


async def test_ws_connect_with_origin(key_data: Any, loop: Any) -> None:
resp = mock.Mock()
resp.status = 403
Expand Down Expand Up @@ -84,6 +171,7 @@ async def test_ws_connect_with_params(ws_key: Any, loop: Any, key_data: Any) ->
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -111,6 +199,7 @@ def read(self, decode=False):
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -233,6 +322,7 @@ async def mock_get(*args, **kwargs):
hdrs.SEC_WEBSOCKET_ACCEPT: accept,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
return resp

with mock.patch("aiohttp.client.os") as m_os:
Expand Down Expand Up @@ -263,6 +353,7 @@ async def test_close(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -303,6 +394,7 @@ async def test_close_eofstream(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -333,6 +425,7 @@ async def test_close_exc(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -365,6 +458,7 @@ async def test_close_exc2(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -397,6 +491,7 @@ async def test_send_data_after_close(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -425,6 +520,7 @@ async def test_send_data_type_errors(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand All @@ -451,6 +547,7 @@ async def test_reader_read_exception(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
hresp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -519,6 +616,7 @@ async def test_ws_connect_non_overlapped_protocols(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -543,6 +641,7 @@ async def test_ws_connect_non_overlapped_protocols_2(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -567,6 +666,7 @@ async def test_ws_connect_deflate(loop: Any, ws_key: Any, key_data: Any) -> None
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -592,6 +692,7 @@ async def test_ws_connect_deflate_per_message(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -626,6 +727,7 @@ async def test_ws_connect_deflate_server_not_support(
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -652,6 +754,7 @@ async def test_ws_connect_deflate_notakeover(
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_no_context_takeover",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -678,6 +781,7 @@ async def test_ws_connect_deflate_client_wbits(
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_max_window_bits=10",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down

0 comments on commit 43b6adc

Please sign in to comment.