Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions tests/unit/test_spooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@

import unittest
from unittest.mock import MagicMock, patch
from trino.client import TrinoQuery, TrinoRequest, ClientSession, TrinoResult
from trino.client import SegmentIterator

class TestTrinoQueryLazy(unittest.TestCase):
def setUp(self):
self.mock_request = MagicMock(spec=TrinoRequest)
self.client_session = ClientSession("user")
self.mock_request.client_session = self.client_session

def test_fetch_returns_iterator_for_spooled_segments(self):
# Mock the initial POST response
post_response = MagicMock()
post_response.id = "query_1"
post_response.stats = {}
post_response.info_uri = "info"
post_response.next_uri = "next_1"
post_response.rows = [] # No rows initially

self.mock_request.process.return_value = post_response
self.mock_request.post.return_value = MagicMock()

query = TrinoQuery(self.mock_request, "SELECT 1")

# Execute should return empty result initially but try to fetch
# We need to mock fetch behavior too since execute calls it if rows are empty

# Mock the GET response for fetch()
get_response_status = MagicMock()
get_response_status.next_uri = None # Finished
get_response_status.stats = {}
# Status rows as dict indicates spooling protocol
get_response_status.rows = {
"encoding": "json",
"segments": [
{"type": "spooled", "uri": "u1", "ackUri": "a1", "metadata": {"segmentSize": "10", "uncompressedSize": "10"}}
],
"metadata": {}
}

# When execute calls fetch(), it calls request.get -> process -> returns get_response_status
self.mock_request.process.side_effect = [post_response, get_response_status]
self.mock_request.get.return_value = MagicMock()

# Mock _to_segments to return a list of decodable segments
# We can just verify that fetch returns a SegmentIterator
# But _to_segments is internal.

# We need to patch SegmentIterator or check the return type

result = query.execute()

# Verify result.rows is a SegmentIterator, NOT a list
self.assertIsInstance(result.rows, SegmentIterator)
self.assertNotIsInstance(result.rows, list)

def test_fetch_returns_list_for_normal_segments(self):
# Mock the initial POST response
post_response = MagicMock()
post_response.id = "query_1"
post_response.stats = {}
post_response.info_uri = "info"
post_response.next_uri = "next_1"
post_response.rows = []

# Mock the GET response for fetch()
get_response_status = MagicMock()
get_response_status.next_uri = None
get_response_status.stats = {}
get_response_status.rows = [[1], [2]] # Normal list rows

self.mock_request.process.side_effect = [post_response, get_response_status]

query = TrinoQuery(self.mock_request, "SELECT 1")
result = query.execute()

# Verify result.rows is a list (appended)
self.assertIsInstance(result.rows, list)
self.assertEqual(result.rows, [[1], [2]])

if __name__ == '__main__':
unittest.main()
34 changes: 29 additions & 5 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,9 +904,32 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
"""
Execute should block until at least one row is received or query is finished or cancelled

For Standard Execution, rows is a list, we can check len. the first response usually contains no rows (just stats),
so we need to continue fetching until we get some rows or query is finished or cancelled.

For Spooled Execution, rows start as empty list and eventually fetch returns the rows as iterator,
we can't check len of an iterator easily without peeking.

So, if we get rows as non empty list or iterator, we stop blocking and return it to the caller to consume it.
"""
Comment on lines +907 to +917
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring is too verbose, it should be a short comment as it was previously, for example:

# Execute should block until the query is finished or cancelled,
# or until at least one row is received (direct protocol),
# or an iterator is received (spooling protocol).


while not self.finished and not self.cancelled:
if isinstance(self._result.rows, list) and len(self._result.rows) == 0:
new_rows = self.fetch()
if isinstance(new_rows, list):
self._result.rows += new_rows
else:
# It's an iterator (spooled segments), replace rows with it
self._result.rows = new_rows
# We have an iterator now, so we can return result to user
break
else:
# We have data (list with items or an iterator), so return
break
Comment on lines +920 to +931
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(self._result.rows, list) and len(self._result.rows) == 0:
new_rows = self.fetch()
if isinstance(new_rows, list):
self._result.rows += new_rows
else:
# It's an iterator (spooled segments), replace rows with it
self._result.rows = new_rows
# We have an iterator now, so we can return result to user
break
else:
# We have data (list with items or an iterator), so return
break
# Stop if we have a non-empty list or an iterator
if not isinstance(self._result.rows, list) or self._result.rows:
break
new_rows = self.fetch()
if isinstance(new_rows, list):
self._result.rows.extend(new_rows)
elif isinstance(new_rows, SegmentIterator):
self._result.rows = new_rows
break
else:
raise TypeError(
f"fetch() returned {type(new_rows).__name__}, expected list or SegmentIterator"
)

This part could be made a bit more readable:

  • outer else can be avoided.
  • explicitly check for rows type. Raise an error if the type is neither list nor SegmentIterator
  • extend is more idiomatic than += for lists.
  • comments can be simplified a bit


return self._result

def _update_state(self, status):
Expand All @@ -920,7 +943,7 @@ def _update_state(self, status):
if status.columns:
self._columns = status.columns

def fetch(self) -> List[Union[List[Any]], Any]:
def fetch(self) -> Union[List[Union[List[Any], Any]], Iterator[List[Any]]]:
"""Continue fetching data for the current query_id"""
try:
response = self._request.get(self._request.next_uri)
Expand All @@ -941,7 +964,8 @@ def fetch(self) -> List[Union[List[Any]], Any]:
spooled = self._to_segments(rows)
if self._fetch_mode == "segments":
return spooled
return list(SegmentIterator(spooled, self._row_mapper))
# Return iterator directly, do NOT materialize with list()
return SegmentIterator(spooled, self._row_mapper)
elif isinstance(status.rows, list):
return self._row_mapper.map(rows)
else:
Expand Down
Loading