Skip to content

Commit 1c8ab3f

Browse files
authored
Upd/hf-dataset-get-format (#522)
* fix: update row retrieval to return rows as dictionaries in ParquetLoader
* fix: update tests to assert dictionary structure for parquet dataset rows
1 parent 3602a36 commit 1c8ab3f

File tree

2 files changed

+25
-28
lines changed

2 files changed

+25
-28
lines changed

src/litdata/streaming/item_loader.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,8 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i
674674
del self._chunk_row_group_item_read_count[chunk_index][row_group_index]
675675

676676
# Return the specific row from the dataframe
677-
return row_group_df.row(row_index_within_group) # type: ignore
677+
# Note: The `named=True` argument is used to return the row as a dictionary
678+
return row_group_df.row(row_index_within_group, named=True) # type: ignore
678679

679680
def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any:
680681
"""Retrieve a dataframe row from a parquet chunk by loading the entire chunk into memory.
@@ -695,7 +696,10 @@ def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any:
695696

696697
if chunk_index not in self._df:
697698
self._df[chunk_index] = pl.scan_parquet(chunk_filepath, low_memory=True).collect()
698-
return self._df[chunk_index].row(index)
699+
700+
# Retrieve the specific row from the dataframe
701+
# Note: The `named=True` argument is used to return the row as a dictionary
702+
return self._df[chunk_index].row(index, named=True)
699703

700704
def delete(self, chunk_index: int, chunk_filepath: str) -> None:
701705
"""Delete a chunk from the local filesystem."""

tests/streaming/test_parquet.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def test_parquet_index_write(
8181

8282
for i, _ds in enumerate(ds):
8383
idx = i % 5
84-
assert len(_ds) == 3
85-
assert _ds[0] == pq_data["name"][idx]
86-
assert _ds[1] == pq_data["weight"][idx]
87-
assert _ds[2] == pq_data["height"][idx]
84+
assert isinstance(_ds, dict)
85+
assert _ds["name"] == pq_data["name"][idx]
86+
assert _ds["weight"] == pq_data["weight"][idx]
87+
assert _ds["height"] == pq_data["height"][idx]
8888

8989

9090
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Fails on windows and test gets cancelled")
@@ -168,7 +168,9 @@ def test_get_parquet_indexer_cls(pq_url, cls, expectation, monkeypatch, fsspec_m
168168
@pytest.mark.usefixtures("clean_pq_index_cache")
169169
@patch("litdata.utilities.parquet._HF_HUB_AVAILABLE", True)
170170
@patch("litdata.streaming.downloader._HF_HUB_AVAILABLE", True)
171-
def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data):
171+
@pytest.mark.parametrize(("pre_load_chunk"), [False, True])
172+
@pytest.mark.parametrize(("low_memory"), [False, True])
173+
def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data, pre_load_chunk, low_memory):
172174
hf_url = "hf://datasets/some_org/some_repo/some_path"
173175

174176
# Test case 1: Invalid item_loader
@@ -180,27 +182,18 @@ def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data
180182
assert len(ds) == 25 # 5 datasets for 5 loops
181183
for i, _ds in enumerate(ds):
182184
idx = i % 5
183-
assert len(_ds) == 3
184-
assert _ds[0] == pq_data["name"][idx]
185-
assert _ds[1] == pq_data["weight"][idx]
186-
assert _ds[2] == pq_data["height"][idx]
187-
188-
# Test case 3: Streaming with ParquetLoader as item_loader and low_memory=False
189-
ds = StreamingDataset(hf_url, item_loader=ParquetLoader(low_memory=False))
190-
assert len(ds) == 25
191-
for i, _ds in enumerate(ds):
192-
idx = i % 5
193-
assert len(_ds) == 3
194-
assert _ds[0] == pq_data["name"][idx]
195-
assert _ds[1] == pq_data["weight"][idx]
196-
assert _ds[2] == pq_data["height"][idx]
197-
198-
# Test case 4: Streaming with ParquetLoader and low_memory=True
199-
ds = StreamingDataset(hf_url, item_loader=ParquetLoader(low_memory=True))
185+
assert isinstance(_ds, dict)
186+
assert _ds["name"] == pq_data["name"][idx]
187+
assert _ds["weight"] == pq_data["weight"][idx]
188+
assert _ds["height"] == pq_data["height"][idx]
189+
190+
# Test case 3: Streaming with passing item_loader
191+
print("pre_load_chunk", pre_load_chunk, "low_memory", low_memory)
192+
ds = StreamingDataset(hf_url, item_loader=ParquetLoader(pre_load_chunk, low_memory))
200193
assert len(ds) == 25
201194
for i, _ds in enumerate(ds):
202195
idx = i % 5
203-
assert len(_ds) == 3
204-
assert _ds[0] == pq_data["name"][idx]
205-
assert _ds[1] == pq_data["weight"][idx]
206-
assert _ds[2] == pq_data["height"][idx]
196+
assert isinstance(_ds, dict)
197+
assert _ds["name"] == pq_data["name"][idx]
198+
assert _ds["weight"] == pq_data["weight"][idx]
199+
assert _ds["height"] == pq_data["height"][idx]

0 commit comments

Comments (0)