2
2
import array
3
3
import asyncio
4
4
import zlib
5
- from typing import Any , Iterable
5
+ from typing import Any , Generator , Iterable
6
6
from unittest import mock
7
7
8
8
import pytest
13
13
from aiohttp .test_utils import make_mocked_coro
14
14
15
15
16
+ @pytest .fixture
17
+ def enable_writelines () -> Generator [None , None , None ]:
18
+ with mock .patch ("aiohttp.http_writer.SKIP_WRITELINES" , False ):
19
+ yield
20
+
21
+
22
+ @pytest .fixture
23
+ def force_writelines_small_payloads () -> Generator [None , None , None ]:
24
+ with mock .patch ("aiohttp.http_writer.MIN_PAYLOAD_FOR_WRITELINES" , 1 ):
25
+ yield
26
+
27
+
16
28
@pytest .fixture
17
29
def buf () -> bytearray :
18
30
return bytearray ()
@@ -136,6 +148,33 @@ async def test_write_large_payload_deflate_compression_data_in_eof(
136
148
assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
137
149
138
150
151
+ @pytest .mark .usefixtures ("enable_writelines" )
152
+ async def test_write_large_payload_deflate_compression_data_in_eof_writelines (
153
+ protocol : BaseProtocol ,
154
+ transport : asyncio .Transport ,
155
+ loop : asyncio .AbstractEventLoop ,
156
+ ) -> None :
157
+ msg = http .StreamWriter (protocol , loop )
158
+ msg .enable_compression ("deflate" )
159
+
160
+ await msg .write (b"data" * 4096 )
161
+ assert transport .write .called # type: ignore[attr-defined]
162
+ chunks = [c [1 ][0 ] for c in list (transport .write .mock_calls )] # type: ignore[attr-defined]
163
+ transport .write .reset_mock () # type: ignore[attr-defined]
164
+ assert not transport .writelines .called # type: ignore[attr-defined]
165
+
166
+ # This payload compresses to 20447 bytes
167
+ payload = b"" .join (
168
+ [bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 ) for _ in range (64 )]
169
+ )
170
+ await msg .write_eof (payload )
171
+ assert not transport .write .called # type: ignore[attr-defined]
172
+ assert transport .writelines .called # type: ignore[attr-defined]
173
+ chunks .extend (transport .writelines .mock_calls [0 ][1 ][0 ]) # type: ignore[attr-defined]
174
+ content = b"" .join (chunks )
175
+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
176
+
177
+
139
178
async def test_write_payload_chunked_filter (
140
179
protocol : BaseProtocol ,
141
180
transport : asyncio .Transport ,
@@ -207,6 +246,26 @@ async def test_write_payload_deflate_compression_chunked(
207
246
assert content == expected
208
247
209
248
249
+ @pytest .mark .usefixtures ("enable_writelines" )
250
+ @pytest .mark .usefixtures ("force_writelines_small_payloads" )
251
+ async def test_write_payload_deflate_compression_chunked_writelines (
252
+ protocol : BaseProtocol ,
253
+ transport : asyncio .Transport ,
254
+ loop : asyncio .AbstractEventLoop ,
255
+ ) -> None :
256
+ expected = b"2\r \n x\x9c \r \n a\r \n KI,I\x04 \x00 \x04 \x00 \x01 \x9b \r \n 0\r \n \r \n "
257
+ msg = http .StreamWriter (protocol , loop )
258
+ msg .enable_compression ("deflate" )
259
+ msg .enable_chunking ()
260
+ await msg .write (b"data" )
261
+ await msg .write_eof ()
262
+
263
+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
264
+ assert all (chunks )
265
+ content = b"" .join (chunks )
266
+ assert content == expected
267
+
268
+
210
269
async def test_write_payload_deflate_and_chunked (
211
270
buf : bytearray ,
212
271
protocol : BaseProtocol ,
@@ -243,6 +302,26 @@ async def test_write_payload_deflate_compression_chunked_data_in_eof(
243
302
assert content == expected
244
303
245
304
305
+ @pytest .mark .usefixtures ("enable_writelines" )
306
+ @pytest .mark .usefixtures ("force_writelines_small_payloads" )
307
+ async def test_write_payload_deflate_compression_chunked_data_in_eof_writelines (
308
+ protocol : BaseProtocol ,
309
+ transport : asyncio .Transport ,
310
+ loop : asyncio .AbstractEventLoop ,
311
+ ) -> None :
312
+ expected = b"2\r \n x\x9c \r \n d\r \n KI,IL\xcd K\x01 \x00 \x0b @\x02 \xd2 \r \n 0\r \n \r \n "
313
+ msg = http .StreamWriter (protocol , loop )
314
+ msg .enable_compression ("deflate" )
315
+ msg .enable_chunking ()
316
+ await msg .write (b"data" )
317
+ await msg .write_eof (b"end" )
318
+
319
+ chunks = [b"" .join (c [1 ][0 ]) for c in list (transport .writelines .mock_calls )] # type: ignore[attr-defined]
320
+ assert all (chunks )
321
+ content = b"" .join (chunks )
322
+ assert content == expected
323
+
324
+
246
325
async def test_write_large_payload_deflate_compression_chunked_data_in_eof (
247
326
protocol : BaseProtocol ,
248
327
transport : asyncio .Transport ,
@@ -269,6 +348,34 @@ async def test_write_large_payload_deflate_compression_chunked_data_in_eof(
269
348
assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
270
349
271
350
351
+ @pytest .mark .usefixtures ("enable_writelines" )
352
+ @pytest .mark .usefixtures ("force_writelines_small_payloads" )
353
+ async def test_write_large_payload_deflate_compression_chunked_data_in_eof_writelines (
354
+ protocol : BaseProtocol ,
355
+ transport : asyncio .Transport ,
356
+ loop : asyncio .AbstractEventLoop ,
357
+ ) -> None :
358
+ msg = http .StreamWriter (protocol , loop )
359
+ msg .enable_compression ("deflate" )
360
+ msg .enable_chunking ()
361
+
362
+ await msg .write (b"data" * 4096 )
363
+ # This payload compresses to 1111 bytes
364
+ payload = b"" .join ([bytes ((* range (0 , i ), * range (i , 0 , - 1 ))) for i in range (255 )])
365
+ await msg .write_eof (payload )
366
+ assert not transport .write .called # type: ignore[attr-defined]
367
+
368
+ chunks = []
369
+ for write_lines_call in transport .writelines .mock_calls : # type: ignore[attr-defined]
370
+ chunked_payload = list (write_lines_call [1 ][0 ])[1 :]
371
+ chunked_payload .pop ()
372
+ chunks .extend (chunked_payload )
373
+
374
+ assert all (chunks )
375
+ content = b"" .join (chunks )
376
+ assert zlib .decompress (content ) == (b"data" * 4096 ) + payload
377
+
378
+
272
379
async def test_write_payload_deflate_compression_chunked_connection_lost (
273
380
protocol : BaseProtocol ,
274
381
transport : asyncio .Transport ,
0 commit comments