Skip to content

Commit

Permalink
[PR #8445 backport][3.10] Fixes read_timeout on WS connection not res…
Browse files Browse the repository at this point in the history
…pecting ws_connect's timeouts #8444 (#8447)

Co-authored-by: J. Nick Koston <nick@koston.org>
  • Loading branch information
arcivanov and bdraco authored Jul 17, 2024
1 parent 266559c commit 776ebc6
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGES/8444.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``ws_connect`` not respecting `receive_timeout`` on WS(S) connection.
-- by :user:`arcivanov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Anes Abismail
Antoine Pietri
Anton Kasyanov
Anton Zhdan-Pushkin
Arcadiy Ivanov
Arseny Timoniq
Artem Yushkovskiy
Arthur Darcet
Expand Down
10 changes: 10 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,16 @@ async def _ws_connect(
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None

# For WS connection the read_timeout must be either receive_timeout or greater
# None == no timeout, i.e. infinite timeout, so None is the max timeout possible
if receive_timeout is None:
# Reset regardless
conn_proto.read_timeout = receive_timeout
elif conn_proto.read_timeout is not None:
# If read_timeout was set check which wins
conn_proto.read_timeout = max(receive_timeout, conn_proto.read_timeout)

transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
Expand Down
8 changes: 8 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ def _reschedule_timeout(self) -> None:
def start_timeout(self) -> None:
self._reschedule_timeout()

@property
def read_timeout(self) -> Optional[float]:
return self._read_timeout

@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout

def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
Expand Down
109 changes: 109 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -38,6 +39,97 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]


async def test_ws_connect_read_timeout_is_reset_to_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch(
"aiohttp.client.ClientSession.request"
) as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org", protocols=("t1", "t2", "chat")
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_stays_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os, mock.patch(
"aiohttp.client.ClientSession.request"
) as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
receive_timeout=0.5,
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_reset_to_max(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch(
"aiohttp.client.ClientSession.request"
) as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
receive_timeout=1.0,
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout == 1.0


async def test_ws_connect_with_origin(key_data, loop) -> None:
resp = mock.Mock()
resp.status = 403
Expand Down Expand Up @@ -68,6 +160,7 @@ async def test_ws_connect_with_params(ws_key, loop, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -93,6 +186,7 @@ def read(self, decode=False):
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -215,6 +309,7 @@ async def mock_get(*args, **kwargs):
hdrs.SEC_WEBSOCKET_ACCEPT: accept,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
return resp

with mock.patch("aiohttp.client.os") as m_os:
Expand Down Expand Up @@ -245,6 +340,7 @@ async def test_close(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -285,6 +381,7 @@ async def test_close_eofstream(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -315,6 +412,7 @@ async def test_close_exc(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -347,6 +445,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -381,6 +480,7 @@ async def test_send_data_after_close(ws_key, key_data, loop) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -409,6 +509,7 @@ async def test_send_data_type_errors(ws_key, key_data, loop) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -437,6 +538,7 @@ async def test_reader_read_exception(ws_key, key_data, loop) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
hresp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -501,6 +603,7 @@ async def test_ws_connect_non_overlapped_protocols(ws_key, loop, key_data) -> No
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -523,6 +626,7 @@ async def test_ws_connect_non_overlapped_protocols_2(ws_key, loop, key_data) ->
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -547,6 +651,7 @@ async def test_ws_connect_deflate(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -570,6 +675,7 @@ async def test_ws_connect_deflate_per_message(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -602,6 +708,7 @@ async def test_ws_connect_deflate_server_not_support(loop, ws_key, key_data) ->
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -626,6 +733,7 @@ async def test_ws_connect_deflate_notakeover(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_no_context_takeover",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -650,6 +758,7 @@ async def test_ws_connect_deflate_client_wbits(loop, ws_key, key_data) -> None:
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_max_window_bits=10",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down

0 comments on commit 776ebc6

Please sign in to comment.