Skip to content

Upd/hf-dataset-get-format #522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,8 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i
del self._chunk_row_group_item_read_count[chunk_index][row_group_index]

# Return the specific row from the dataframe
return row_group_df.row(row_index_within_group) # type: ignore
# Note: The `named=True` argument is used to return the row as a dictionary
return row_group_df.row(row_index_within_group, named=True) # type: ignore

def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any:
"""Retrieve a dataframe row from a parquet chunk by loading the entire chunk into memory.
Expand All @@ -695,7 +696,10 @@ def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any:

if chunk_index not in self._df:
self._df[chunk_index] = pl.scan_parquet(chunk_filepath, low_memory=True).collect()
return self._df[chunk_index].row(index)

# Retrieve the specific row from the dataframe
# Note: The `named=True` argument is used to return the row as a dictionary
return self._df[chunk_index].row(index, named=True)

def delete(self, chunk_index: int, chunk_filepath: str) -> None:
"""Delete a chunk from the local filesystem."""
Expand Down
45 changes: 19 additions & 26 deletions tests/streaming/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def test_parquet_index_write(

for i, _ds in enumerate(ds):
idx = i % 5
assert len(_ds) == 3
assert _ds[0] == pq_data["name"][idx]
assert _ds[1] == pq_data["weight"][idx]
assert _ds[2] == pq_data["height"][idx]
assert isinstance(_ds, dict)
assert _ds["name"] == pq_data["name"][idx]
assert _ds["weight"] == pq_data["weight"][idx]
assert _ds["height"] == pq_data["height"][idx]


@pytest.mark.skipif(condition=sys.platform == "win32", reason="Fails on windows and test gets cancelled")
Expand Down Expand Up @@ -168,7 +168,9 @@ def test_get_parquet_indexer_cls(pq_url, cls, expectation, monkeypatch, fsspec_m
@pytest.mark.usefixtures("clean_pq_index_cache")
@patch("litdata.utilities.parquet._HF_HUB_AVAILABLE", True)
@patch("litdata.streaming.downloader._HF_HUB_AVAILABLE", True)
def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data):
@pytest.mark.parametrize(("pre_load_chunk"), [False, True])
@pytest.mark.parametrize(("low_memory"), [False, True])
def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data, pre_load_chunk, low_memory):
hf_url = "hf://datasets/some_org/some_repo/some_path"

# Test case 1: Invalid item_loader
Expand All @@ -180,27 +182,18 @@ def test_stream_hf_parquet_dataset(monkeypatch, huggingface_hub_fs_mock, pq_data
assert len(ds) == 25 # 5 datasets for 5 loops
for i, _ds in enumerate(ds):
idx = i % 5
assert len(_ds) == 3
assert _ds[0] == pq_data["name"][idx]
assert _ds[1] == pq_data["weight"][idx]
assert _ds[2] == pq_data["height"][idx]

# Test case 3: Streaming with ParquetLoader as item_loader and low_memory=False
ds = StreamingDataset(hf_url, item_loader=ParquetLoader(low_memory=False))
assert len(ds) == 25
for i, _ds in enumerate(ds):
idx = i % 5
assert len(_ds) == 3
assert _ds[0] == pq_data["name"][idx]
assert _ds[1] == pq_data["weight"][idx]
assert _ds[2] == pq_data["height"][idx]

# Test case 4: Streaming with ParquetLoader and low_memory=True
ds = StreamingDataset(hf_url, item_loader=ParquetLoader(low_memory=True))
assert isinstance(_ds, dict)
assert _ds["name"] == pq_data["name"][idx]
assert _ds["weight"] == pq_data["weight"][idx]
assert _ds["height"] == pq_data["height"][idx]

# Test case 3: Streaming with passing item_loader
print("pre_load_chunk", pre_load_chunk, "low_memory", low_memory)
ds = StreamingDataset(hf_url, item_loader=ParquetLoader(pre_load_chunk, low_memory))
assert len(ds) == 25
for i, _ds in enumerate(ds):
idx = i % 5
assert len(_ds) == 3
assert _ds[0] == pq_data["name"][idx]
assert _ds[1] == pq_data["weight"][idx]
assert _ds[2] == pq_data["height"][idx]
assert isinstance(_ds, dict)
assert _ds["name"] == pq_data["name"][idx]
assert _ds["weight"] == pq_data["weight"][idx]
assert _ds["height"] == pq_data["height"][idx]
Loading