Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes read_timeout on WS connection not respecting ws_connect's timeouts #8445

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES/8444.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix ``ws_connect`` not respecting ``ws_receive`` timeout for WS(S) connection.
-- by :user:`arcivanov`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Anes Abismail
Antoine Pietri
Anton Kasyanov
Anton Zhdan-Pushkin
Arcadiy Ivanov
Arie Bovenberg
Arseny Timoniq
Artem Yushkovskiy
Expand Down
11 changes: 11 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,17 @@ async def _ws_connect(
assert conn is not None
conn_proto = conn.protocol
assert conn_proto is not None

# For WS connection the read_timeout must be either ws_timeout.ws_receive or greater
# None == no timeout, i.e. infinite timeout, so None is the max timeout possible
if ws_timeout.ws_receive is None:
# Reset regardless
conn_proto.read_timeout = None
elif conn_proto.read_timeout is not None:
conn_proto.read_timeout = max(
ws_timeout.ws_receive, conn_proto.read_timeout
)

Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
transport = conn.transport
assert transport is not None
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
Expand Down
8 changes: 8 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,14 @@ def _reschedule_timeout(self) -> None:
def start_timeout(self) -> None:
self._reschedule_timeout()

@property
bdraco marked this conversation as resolved.
Show resolved Hide resolved
def read_timeout(self) -> Optional[float]:
return self._read_timeout

@read_timeout.setter
def read_timeout(self, read_timeout: Optional[float]) -> None:
self._read_timeout = read_timeout

def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
Expand Down
110 changes: 110 additions & 0 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import aiohttp
from aiohttp import client, hdrs
from aiohttp.client_ws import ClientWSTimeout
from aiohttp.http import WS_KEY
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -39,6 +40,7 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -54,6 +56,97 @@ async def test_ws_connect(ws_key: Any, loop: Any, key_data: Any) -> None:
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]


async def test_ws_connect_read_timeout_is_reset_to_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch(
"aiohttp.client.ClientSession.request"
) as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org", protocols=("t1", "t2", "chat")
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_stays_inf(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os, mock.patch(
"aiohttp.client.ClientSession.request"
) as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
timeout=ClientWSTimeout(0.5),
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout is None


async def test_ws_connect_read_timeout_reset_to_max(
ws_key: Any, loop: Any, key_data: Any
) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
hdrs.UPGRADE: "websocket",
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = 0.5
with mock.patch("aiohttp.client.os") as m_os, mock.patch(
"aiohttp.client.ClientSession.request"
) as m_req:
m_os.urandom.return_value = key_data
m_req.return_value = loop.create_future()
m_req.return_value.set_result(resp)

res = await aiohttp.ClientSession().ws_connect(
"http://test.org",
protocols=("t1", "t2", "chat"),
timeout=ClientWSTimeout(1.0),
)

assert isinstance(res, client.ClientWebSocketResponse)
assert res.protocol == "chat"
assert hdrs.ORIGIN not in m_req.call_args[1]["headers"]
assert resp.connection.protocol.read_timeout == 1.0


async def test_ws_connect_with_origin(key_data: Any, loop: Any) -> None:
resp = mock.Mock()
resp.status = 403
Expand Down Expand Up @@ -84,6 +177,7 @@ async def test_ws_connect_with_params(ws_key: Any, loop: Any, key_data: Any) ->
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
bdraco marked this conversation as resolved.
Show resolved Hide resolved
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -111,6 +205,7 @@ def read(self, decode=False):
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -233,6 +328,7 @@ async def mock_get(*args, **kwargs):
hdrs.SEC_WEBSOCKET_ACCEPT: accept,
hdrs.SEC_WEBSOCKET_PROTOCOL: "chat",
}
resp.connection.protocol.read_timeout = None
return resp

with mock.patch("aiohttp.client.os") as m_os:
Expand Down Expand Up @@ -263,6 +359,7 @@ async def test_close(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -303,6 +400,7 @@ async def test_close_eofstream(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -333,6 +431,7 @@ async def test_close_exc(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -365,6 +464,7 @@ async def test_close_exc2(loop: Any, ws_key: Any, key_data: Any) -> None:
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -397,6 +497,7 @@ async def test_send_data_after_close(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down Expand Up @@ -425,6 +526,7 @@ async def test_send_data_type_errors(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand All @@ -451,6 +553,7 @@ async def test_reader_read_exception(ws_key: Any, key_data: Any, loop: Any) -> N
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
hresp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -519,6 +622,7 @@ async def test_ws_connect_non_overlapped_protocols(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -543,6 +647,7 @@ async def test_ws_connect_non_overlapped_protocols_2(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_PROTOCOL: "other,another",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -567,6 +672,7 @@ async def test_ws_connect_deflate(loop: Any, ws_key: Any, key_data: Any) -> None
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -592,6 +698,7 @@ async def test_ws_connect_deflate_per_message(
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.WebSocketWriter") as WebSocketWriter:
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
Expand Down Expand Up @@ -626,6 +733,7 @@ async def test_ws_connect_deflate_server_not_support(
hdrs.CONNECTION: "upgrade",
hdrs.SEC_WEBSOCKET_ACCEPT: ws_key,
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -652,6 +760,7 @@ async def test_ws_connect_deflate_notakeover(
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_no_context_takeover",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand All @@ -678,6 +787,7 @@ async def test_ws_connect_deflate_client_wbits(
hdrs.SEC_WEBSOCKET_EXTENSIONS: "permessage-deflate; "
"client_max_window_bits=10",
}
resp.connection.protocol.read_timeout = None
with mock.patch("aiohttp.client.os") as m_os:
with mock.patch("aiohttp.client.ClientSession.request") as m_req:
m_os.urandom.return_value = key_data
Expand Down
Loading