Skip to content

Commit 25c7f23

Browse files
bdracowebknjaz
andauthored
Restore zero copy writes on Python 3.12.9+/3.13.2+ (#10137)
Co-authored-by: 🇺🇦 Sviatoslav Sydorenko (Святослав Сидоренко) <sviat@redhat.com>
1 parent 95b28c7 commit 25c7f23

File tree

4 files changed

+129
-4
lines changed

4 files changed

+129
-4
lines changed

.github/workflows/ci-cd.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,11 @@ jobs:
266266
uses: actions/checkout@v4
267267
with:
268268
submodules: true
269-
- name: Setup Python 3.13
269+
- name: Setup Python 3.13.2
270270
id: python-install
271271
uses: actions/setup-python@v5
272272
with:
273-
python-version: 3.13
273+
python-version: 3.13.2
274274
cache: pip
275275
cache-dependency-path: requirements/*.txt
276276
- name: Update pip, wheel, setuptools, build, twine

CHANGES/10137.misc.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Restored support for zero copy writes when using Python 3.12 versions 3.12.9 and later or Python 3.13.2+ -- by :user:`bdraco`.
2+
3+
Zero copy writes were previously disabled due to :cve:`2024-12254` which is resolved in these Python versions.

aiohttp/http_writer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Http related parsers and protocol."""
22

33
import asyncio
4+
import sys
45
import zlib
56
from typing import ( # noqa
67
Any,
@@ -24,6 +25,17 @@
2425
__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
2526

2627

28+
MIN_PAYLOAD_FOR_WRITELINES = 2048
29+
IS_PY313_BEFORE_313_2 = (3, 13, 0) <= sys.version_info < (3, 13, 2)
30+
IS_PY_BEFORE_312_9 = sys.version_info < (3, 12, 9)
31+
SKIP_WRITELINES = IS_PY313_BEFORE_313_2 or IS_PY_BEFORE_312_9
32+
# writelines is not safe for use
33+
# on Python 3.12+ until 3.12.9
34+
# on Python 3.13+ until 3.13.2
35+
# and on older versions it not any faster than write
36+
# CVE-2024-12254: https://github.com/python/cpython/pull/127656
37+
38+
2739
class HttpVersion(NamedTuple):
2840
major: int
2941
minor: int
@@ -90,7 +102,10 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
90102
transport = self._protocol.transport
91103
if transport is None or transport.is_closing():
92104
raise ClientConnectionResetError("Cannot write to closing transport")
93-
transport.write(b"".join(chunks))
105+
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
106+
transport.write(b"".join(chunks))
107+
else:
108+
transport.writelines(chunks)
94109

95110
async def write(
96111
self,

tests/test_http_writer.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import array
33
import asyncio
44
import zlib
5-
from typing import Any, Iterable
5+
from typing import Any, Generator, Iterable
66
from unittest import mock
77

88
import pytest
@@ -13,6 +13,18 @@
1313
from aiohttp.test_utils import make_mocked_coro
1414

1515

16+
@pytest.fixture
17+
def enable_writelines() -> Generator[None, None, None]:
18+
with mock.patch("aiohttp.http_writer.SKIP_WRITELINES", False):
19+
yield
20+
21+
22+
@pytest.fixture
23+
def force_writelines_small_payloads() -> Generator[None, None, None]:
24+
with mock.patch("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES", 1):
25+
yield
26+
27+
1628
@pytest.fixture
1729
def buf() -> bytearray:
1830
return bytearray()
@@ -136,6 +148,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
136148
assert zlib.decompress(content) == (b"data" * 4096) + payload
137149

138150

151+
@pytest.mark.usefixtures("enable_writelines")
152+
async def test_write_large_payload_deflate_compression_data_in_eof_writelines(
153+
protocol: BaseProtocol,
154+
transport: asyncio.Transport,
155+
loop: asyncio.AbstractEventLoop,
156+
) -> None:
157+
msg = http.StreamWriter(protocol, loop)
158+
msg.enable_compression("deflate")
159+
160+
await msg.write(b"data" * 4096)
161+
assert transport.write.called # type: ignore[attr-defined]
162+
chunks = [c[1][0] for c in list(transport.write.mock_calls)] # type: ignore[attr-defined]
163+
transport.write.reset_mock() # type: ignore[attr-defined]
164+
assert not transport.writelines.called # type: ignore[attr-defined]
165+
166+
# This payload compresses to 20447 bytes
167+
payload = b"".join(
168+
[bytes((*range(0, i), *range(i, 0, -1))) for i in range(255) for _ in range(64)]
169+
)
170+
await msg.write_eof(payload)
171+
assert not transport.write.called # type: ignore[attr-defined]
172+
assert transport.writelines.called # type: ignore[attr-defined]
173+
chunks.extend(transport.writelines.mock_calls[0][1][0]) # type: ignore[attr-defined]
174+
content = b"".join(chunks)
175+
assert zlib.decompress(content) == (b"data" * 4096) + payload
176+
177+
139178
async def test_write_payload_chunked_filter(
140179
protocol: BaseProtocol,
141180
transport: asyncio.Transport,
@@ -207,6 +246,26 @@ async def test_write_payload_deflate_compression_chunked(
207246
assert content == expected
208247

209248

249+
@pytest.mark.usefixtures("enable_writelines")
250+
@pytest.mark.usefixtures("force_writelines_small_payloads")
251+
async def test_write_payload_deflate_compression_chunked_writelines(
252+
protocol: BaseProtocol,
253+
transport: asyncio.Transport,
254+
loop: asyncio.AbstractEventLoop,
255+
) -> None:
256+
expected = b"2\r\nx\x9c\r\na\r\nKI,I\x04\x00\x04\x00\x01\x9b\r\n0\r\n\r\n"
257+
msg = http.StreamWriter(protocol, loop)
258+
msg.enable_compression("deflate")
259+
msg.enable_chunking()
260+
await msg.write(b"data")
261+
await msg.write_eof()
262+
263+
chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
264+
assert all(chunks)
265+
content = b"".join(chunks)
266+
assert content == expected
267+
268+
210269
async def test_write_payload_deflate_and_chunked(
211270
buf: bytearray,
212271
protocol: BaseProtocol,
@@ -243,6 +302,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
243302
assert content == expected
244303

245304

305+
@pytest.mark.usefixtures("enable_writelines")
306+
@pytest.mark.usefixtures("force_writelines_small_payloads")
307+
async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines(
308+
protocol: BaseProtocol,
309+
transport: asyncio.Transport,
310+
loop: asyncio.AbstractEventLoop,
311+
) -> None:
312+
expected = b"2\r\nx\x9c\r\nd\r\nKI,IL\xcdK\x01\x00\x0b@\x02\xd2\r\n0\r\n\r\n"
313+
msg = http.StreamWriter(protocol, loop)
314+
msg.enable_compression("deflate")
315+
msg.enable_chunking()
316+
await msg.write(b"data")
317+
await msg.write_eof(b"end")
318+
319+
chunks = [b"".join(c[1][0]) for c in list(transport.writelines.mock_calls)] # type: ignore[attr-defined]
320+
assert all(chunks)
321+
content = b"".join(chunks)
322+
assert content == expected
323+
324+
246325
async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
247326
protocol: BaseProtocol,
248327
transport: asyncio.Transport,
@@ -269,6 +348,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
269348
assert zlib.decompress(content) == (b"data" * 4096) + payload
270349

271350

351+
@pytest.mark.usefixtures("enable_writelines")
352+
@pytest.mark.usefixtures("force_writelines_small_payloads")
353+
async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines(
354+
protocol: BaseProtocol,
355+
transport: asyncio.Transport,
356+
loop: asyncio.AbstractEventLoop,
357+
) -> None:
358+
msg = http.StreamWriter(protocol, loop)
359+
msg.enable_compression("deflate")
360+
msg.enable_chunking()
361+
362+
await msg.write(b"data" * 4096)
363+
# This payload compresses to 1111 bytes
364+
payload = b"".join([bytes((*range(0, i), *range(i, 0, -1))) for i in range(255)])
365+
await msg.write_eof(payload)
366+
assert not transport.write.called # type: ignore[attr-defined]
367+
368+
chunks = []
369+
for write_lines_call in transport.writelines.mock_calls: # type: ignore[attr-defined]
370+
chunked_payload = list(write_lines_call[1][0])[1:]
371+
chunked_payload.pop()
372+
chunks.extend(chunked_payload)
373+
374+
assert all(chunks)
375+
content = b"".join(chunks)
376+
assert zlib.decompress(content) == (b"data" * 4096) + payload
377+
378+
272379
async def test_write_payload_deflate_compression_chunked_connection_lost(
273380
protocol: BaseProtocol,
274381
transport: asyncio.Transport,

0 commit comments

Comments
 (0)