
Commit 2143e35

[Feature] Read streams by 1MB chunks by default. (#817)
## What changes are proposed in this pull request?

This PR changes the `_BaseClient` to read streams in chunks of 1MB by default. 1MB was chosen as a good compromise between speed and memory usage (see PR #319).

Note that this is not a new feature per se, as it was already possible to configure the chunk size on the returned `_StreamingResponse` before calling its read method. However, that functionality was not easy to discover and led several users to experience memory issues. The new default behavior is more defensive.

## How is this tested?

Added a few test cases to verify that streams are chunked as expected.

---------

Signed-off-by: Renaud Hartert <renaud.hartert@databricks.com>
1 parent f7f9a68 commit 2143e35
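To make the discoverability point concrete, here is a rough before/after sketch from a caller's perspective. The endpoint path is hypothetical and the `do("GET", path, raw=True)` call shape is an assumption; the diff below only guarantees that raw responses are returned under `resp["contents"]`.

```python
# Before: chunking was opt-in and easy to miss.
resp = client.do("GET", "/api/2.0/example/download", raw=True)  # hypothetical endpoint
stream = resp["contents"]           # _StreamingResponse
stream.set_chunk_size(1024 * 1024)  # users who skipped this step hit memory issues
data = stream.read()

# After: every raw response already streams in 1MB chunks by default.
resp = client.do("GET", "/api/2.0/example/download", raw=True)
data = resp["contents"].read()
```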

File tree

2 files changed (+50, -2 lines)

databricks/sdk/_base_client.py

Lines changed: 12 additions & 2 deletions
@@ -50,7 +50,8 @@ def __init__(self,
                  http_timeout_seconds: float = None,
                  extra_error_customizers: List[_ErrorCustomizer] = None,
                  debug_headers: bool = False,
-                 clock: Clock = None):
+                 clock: Clock = None,
+                 streaming_buffer_size: int = 1024 * 1024): # 1MB
         """
         :param debug_truncate_bytes:
         :param retry_timeout_seconds:
@@ -68,6 +69,7 @@ def __init__(self,
         :param extra_error_customizers:
         :param debug_headers: Whether to include debug headers in the request log.
         :param clock: Clock object to use for time-related operations.
+        :param streaming_buffer_size: The size of the buffer to use for streaming responses.
         """

         self._debug_truncate_bytes = debug_truncate_bytes or 96
@@ -78,6 +80,7 @@ def __init__(self,
         self._clock = clock or RealClock()
         self._session = requests.Session()
         self._session.auth = self._authenticate
+        self._streaming_buffer_size = streaming_buffer_size

         # We don't use `max_retries` from HTTPAdapter to align with a more production-ready
         # retry strategy established in the Databricks SDK for Go. See _is_retryable and
@@ -158,7 +161,9 @@ def do(self,
         for header in response_headers if response_headers else []:
             resp[header] = response.headers.get(Casing.to_header_case(header))
         if raw:
-            resp["contents"] = _StreamingResponse(response)
+            streaming_response = _StreamingResponse(response)
+            streaming_response.set_chunk_size(self._streaming_buffer_size)
+            resp["contents"] = streaming_response
             return resp
         if not len(response.content):
             return resp
@@ -283,6 +288,11 @@ def isatty(self) -> bool:
         return False

     def read(self, n: int = -1) -> bytes:
+        """
+        Read up to n bytes from the response stream. If n is negative, read
+        until the end of the stream.
+        """
+
         self._open()
         read_everything = n < 0
         remaining_bytes = n
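For intuition, here is a minimal, self-contained sketch of how a buffered reader like `_StreamingResponse` can sit on top of `requests.Response.iter_content`. This illustrates the technique only; it is not the SDK's actual implementation.

```python
import requests


class _ChunkedReader:
    """Illustrative buffered reader over requests' iter_content.

    A simplified sketch, not the SDK's actual _StreamingResponse.
    """

    def __init__(self, response: requests.Response, chunk_size: int = 1024 * 1024):
        # iter_content yields at most chunk_size bytes per network pull.
        self._chunks = response.iter_content(chunk_size=chunk_size)
        self._buffer = b""

    def read(self, n: int = -1) -> bytes:
        if n < 0:
            # Negative n: drain the buffer and the rest of the stream.
            data, self._buffer = self._buffer + b"".join(self._chunks), b""
            return data
        # Pull whole chunks until the buffer can satisfy the request.
        while len(self._buffer) < n:
            chunk = next(self._chunks, b"")
            if not chunk:
                break  # stream exhausted
            self._buffer += chunk
        result, self._buffer = self._buffer[:n], self._buffer[n:]
        return result
```

Even when a caller issues many tiny `read(1)` calls, the underlying HTTP stream is still consumed in `chunk_size` increments, which is exactly what the tests below verify against a mocked response.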

tests/test_base_client.py

Lines changed: 38 additions & 0 deletions
@@ -1,5 +1,7 @@
+import random
 from http.server import BaseHTTPRequestHandler
 from typing import Iterator, List
+from unittest.mock import Mock

 import pytest
 import requests
@@ -276,3 +278,39 @@ def inner(h: BaseHTTPRequestHandler):
     assert 'foo' in res

     assert len(requests) == 2
+
+
+@pytest.mark.parametrize('chunk_size,expected_chunks,data_size',
+                         [(5, 20, 100), # 100 / 5 bytes per chunk = 20 chunks
+                          (10, 10, 100), # 100 / 10 bytes per chunk = 10 chunks
+                          (200, 1, 100), # 100 / 200 bytes per chunk = 1 chunk
+                          ])
+def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
+    rng = random.Random(42)
+    test_data = bytes(rng.getrandbits(8) for _ in range(data_size))
+
+    content_chunks = []
+    mock_response = Mock(spec=requests.Response)
+
+    def mock_iter_content(chunk_size):
+        # Simulate how requests would chunk the data.
+        for i in range(0, len(test_data), chunk_size):
+            chunk = test_data[i:i + chunk_size]
+            content_chunks.append(chunk) # track chunks for verification
+            yield chunk
+
+    mock_response.iter_content = mock_iter_content
+    stream = _StreamingResponse(mock_response)
+    stream.set_chunk_size(chunk_size)
+
+    # Read all data one byte at a time.
+    received_data = b""
+    while True:
+        chunk = stream.read(1)
+        if not chunk:
+            break
+        received_data += chunk
+
+    assert received_data == test_data # all data was received correctly
+    assert len(content_chunks) == expected_chunks # correct number of chunks
+    assert all(len(c) <= chunk_size for c in content_chunks) # chunks don't exceed size
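Since the chunk size is now a constructor argument, it can also be tuned once per client rather than per response. A minimal sketch, assuming `_BaseClient` is instantiated directly (it is an internal class, so real applications would normally reach it through the SDK's client setup):

```python
from databricks.sdk._base_client import _BaseClient

# Defensive 1MB default introduced by this PR.
client = _BaseClient()

# Hypothetical tuning: larger 8MB chunks trade memory for fewer network pulls.
bulk_client = _BaseClient(streaming_buffer_size=8 * 1024 * 1024)
```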
