Commit 6f86e35

changes for feat:add-parquet-support PR review

1 parent: 7cbb3ef

File tree: 4 files changed (+34 −12 lines)

- README.md
- src/litdata/constants.py
- src/litdata/streaming/item_loader.py
- tests/streaming/test_writer.py

README.md (+8 −2)
````diff
@@ -643,14 +643,18 @@ The `overwrite` mode will delete the existing data and start from fresh.
 </details>
 
 <details>
-<summary> ✅ Index parquet datasets</summary>
+<summary> ✅ Stream parquet datasets</summary>
 
 &nbsp;
 
-If your dataset is already in Parquet format, you can index it directly and use it with StreamingDataset & DataLoader.
+You can stream Parquet datasets directly without the need to convert them into the LitData optimized binary format.
+
+If your dataset is already in Parquet format, you can index and use it with StreamingDataset and DataLoader for efficient streaming.
 
 Assumption:
 Your dataset directory contains one or more Parquet files.
 
+- **Index Parquet dataset**:
+
 ```python
 import litdata as ld
 
@@ -659,6 +663,8 @@ pq_data_uri = "gs://deep-litdata-parquet/my-parquet-data"
 ld.index_parquet_dataset(pq_data_uri)
 ```
 
+- **Stream the dataset with `StreamingDataset` and `ParquetLoader`**
+
 When using a Streaming Dataset, ensure you use `ParquetLoader`:
 
 ```python
````
src/litdata/constants.py (+1 −1)
```diff
@@ -36,7 +36,7 @@
 _AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
 _TQDM_AVAILABLE = RequirementCache("tqdm")
 _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
-_POLARS_AVAILABLE = RequirementCache("polars")
+_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")
 _DEBUG = bool(int(os.getenv("DEBUG", "1")))
 
 _MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
```
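
The tightened pin means `_POLARS_AVAILABLE` is now falsy for any installed `polars` at or below 1.0.0, not only for a missing install. A minimal sketch of how the flag behaves, assuming `RequirementCache` here is `lightning_utilities.core.imports.RequirementCache` (which evaluates a PEP 440 requirement string and is truthy only when it is satisfied):

```python
from lightning_utilities.core.imports import RequirementCache

_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")

if _POLARS_AVAILABLE:
    # Only reached when a polars release newer than 1.0.0 is importable.
    import polars as pl
else:
    # str() of the cache carries a human-readable hint about what is missing.
    print(str(_POLARS_AVAILABLE))
```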

src/litdata/streaming/item_loader.py (+8 −2)
```diff
@@ -423,7 +423,10 @@ def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> T
 class ParquetLoader(BaseItemLoader):
     def __init__(self) -> None:
         if not _POLARS_AVAILABLE:
-            raise ModuleNotFoundError("Please, run: `pip install polars`")
+            raise ModuleNotFoundError(
+                "You are using the Parquet item loader, which depends on `Polars > 1.0.0`.",
+                "Please, run: `pip install polars>1.0.0`",
+            )
         self._chunk_filepaths: Dict[str, bool] = {}
 
     def setup(
@@ -458,7 +461,10 @@ def generate_intervals(self) -> List[Interval]:
 
     def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
         """Logic to load the chunk in background to gain some time."""
-        pass
+        import polars as pl
+
+        if chunk_filepath not in self._df:
+            self._df[chunk_filepath] = pl.scan_parquet(chunk_filepath).collect()
 
     def load_item_from_chunk(
         self,
```
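
Note that `pl.scan_parquet(chunk_filepath).collect()` builds a lazy frame and materializes it immediately, so `pre_load_chunk` now warms a per-file DataFrame cache instead of being a no-op. A minimal sketch of the resulting preload-then-read pattern, with a hypothetical row-lookup method (the real `load_item_from_chunk` body is not shown in this diff):

```python
from typing import Any, Dict, Tuple

import polars as pl


class ParquetChunkCache:
    """Illustrative only: mirrors the caching pattern added in this commit."""

    def __init__(self) -> None:
        # chunk filepath -> fully materialized DataFrame
        self._df: Dict[str, pl.DataFrame] = {}

    def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
        # Eagerly load the whole Parquet file once so row lookups are cheap.
        if chunk_filepath not in self._df:
            self._df[chunk_filepath] = pl.scan_parquet(chunk_filepath).collect()

    def load_item(self, row_index: int, chunk_filepath: str) -> Tuple[Any, ...]:
        self.pre_load_chunk(-1, chunk_filepath)  # no-op if already cached
        return self._df[chunk_filepath].row(row_index)  # one row as a tuple
```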

tests/streaming/test_writer.py (+17 −7)
```diff
@@ -15,6 +15,7 @@
 import os
 import random
 import sys
+from collections import OrderedDict
 
 import numpy as np
 import pytest
@@ -278,14 +279,16 @@ def test_parquet_index_write(tmpdir):
 
     os.mkdir(os.path.join(tmpdir, "data"))
 
+    pq_data = OrderedDict(
+        {
+            "name": ["Tom", "Jerry", "Micky", "Oggy", "Doraemon"],
+            "weight": [57.9, 72.5, 53.6, 83.1, 69.4],  # (kg)
+            "height": [1.56, 1.77, 1.65, 1.75, 1.63],  # (m)
+        }
+    )
+
     for i in range(5):
-        df = pl.DataFrame(
-            {
-                "name": ["Tom", "Jerry", "Micky", "Oggy", "Doraemon"],
-                "weight": [57.9, 72.5, 53.6, 83.1, 69.4],  # (kg)
-                "height": [1.56, 1.77, 1.65, 1.75, 1.63],  # (m)
-            }
-        )
+        df = pl.DataFrame(pq_data)
         file_path = os.path.join(tmpdir, "data", f"tmp-{i}.parquet")
         df.write_parquet(file_path)
 
@@ -307,3 +310,10 @@ def test_parquet_index_write(tmpdir):
     ds = StreamingDataset(os.path.join(tmpdir, "data"), item_loader=ParquetLoader())
 
     assert len(ds) == 25  # 5 datasets for 5 loops
+
+    for i, _ds in enumerate(ds):
+        idx = i % 5
+        assert len(_ds) == 3
+        assert _ds[0] == pq_data["name"][idx]
+        assert _ds[1] == pq_data["weight"][idx]
+        assert _ds[2] == pq_data["height"][idx]
```
