
Feat: Add support for parquet files #443


Merged on Feb 3, 2025 (28 commits)

Commits
6bde708  started working on adding parquet support in litdata (deependujha, Jan 6, 2025)
34c8a64  write_parquet_index fn working (deependujha, Jan 11, 2025)
c1c0f13  streaming_dataset and streaming_dataset can read optimized parquet files (deependujha, Jan 12, 2025)
ff40eba  fixed mypy issues (deependujha, Jan 12, 2025)
d53dc7d  Merge branch 'main' into feat/add-hf-parquet-support (deependujha, Jan 12, 2025)
3db6033  update (deependujha, Jan 12, 2025)
239e093  update (deependujha, Jan 12, 2025)
76efafb  Merge branch 'main' into feat/add-hf-parquet-support (deependujha, Jan 28, 2025)
09d1000  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 28, 2025)
8b880ac  need to test it on s3 (deependujha, Jan 28, 2025)
520ef21  update (deependujha, Jan 28, 2025)
211c987  update (deependujha, Jan 28, 2025)
4640756  fixed test (deependujha, Jan 28, 2025)
53e8995  fixed mypy error (deependujha, Jan 28, 2025)
75b3163  hip-hip hurray. working on google-storage (deependujha, Jan 28, 2025)
5a4e83b  update readme (deependujha, Jan 29, 2025)
7da25e0  remove assert (deependujha, Jan 29, 2025)
7e23490  cache parquet reads (deependujha, Jan 31, 2025)
95fb340  update (deependujha, Jan 31, 2025)
3376e0a  update (deependujha, Jan 31, 2025)
caabd03  made required changes (deependujha, Feb 1, 2025)
8791f1a  update (deependujha, Feb 1, 2025)
1d0b446  update (deependujha, Feb 1, 2025)
363529b  add type annotations (deependujha, Feb 1, 2025)
c55269c  update (deependujha, Feb 1, 2025)
48f4a45  update (deependujha, Feb 1, 2025)
26e9a14  remove module binding to attribute (deependujha, Feb 3, 2025)
59d9d14  remove state_dict method from ParquetLoader (deependujha, Feb 3, 2025)

Files changed
31 changes: 31 additions & 0 deletions README.md
@@ -642,6 +642,37 @@ The `overwrite` mode will delete the existing data and start from fresh.

</details>

<details>
<summary> ✅ Index parquet datasets</summary>
Review comment (Collaborator): "Index" may not be immediately clear to users, imo. Ultimately what users get is the ability to "Stream Parquet datasets"; I'd have this as the title. Index is a technical detail.

Review comment (Collaborator): I'd also add a line or two explaining how big of a deal this is :) e.g. "Stream Parquet files directly without converting them to the LitData optimized binary format", or something of this nature.

Reply (Collaborator, author): I've made the changes. Since this PR was already merged, the new changes are in PR #460.

&nbsp;

If your dataset is already in Parquet format, you can index it directly and use it with StreamingDataset & DataLoader.

Assumption:
Your dataset directory contains one or more Parquet files.

```python
import litdata as ld

pq_data_uri = "gs://deep-litdata-parquet/my-parquet-data"

ld.index_parquet_dataset(pq_data_uri)
```

When creating a `StreamingDataset`, make sure to use the `ParquetLoader` item loader:

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

ds = ld.StreamingDataset('gs://deep-litdata-parquet/my-parquet-data', item_loader=ParquetLoader())

for _ds in ds:
print(f"{_ds=}")
```
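
For batched iteration, the dataset can also be wrapped in `StreamingDataLoader`. A minimal sketch, assuming the same hypothetical `gs://` bucket as above and columns that the default collate can batch:

```python
import litdata as ld
from litdata.streaming.item_loader import ParquetLoader

# Hypothetical bucket, reusing the dataset from the example above.
ds = ld.StreamingDataset('gs://deep-litdata-parquet/my-parquet-data', item_loader=ParquetLoader())
dl = ld.StreamingDataLoader(ds, batch_size=4)

for batch in dl:
    print(batch)  # each batch collates 4 Parquet rows
```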

</details>

<details>
<summary> ✅ Use compression</summary>
&nbsp;
1 change: 1 addition & 0 deletions requirements/extras.txt
@@ -5,3 +5,4 @@ pyarrow
tqdm
lightning-sdk==0.1.46 # Must be pinned to ensure compatibility
google-cloud-storage
polars
2 changes: 2 additions & 0 deletions src/litdata/__init__.py
@@ -18,6 +18,7 @@
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader
from litdata.streaming.writer import index_parquet_dataset
from litdata.utilities.breakpoint import breakpoint
from litdata.utilities.train_test_split import train_test_split

@@ -31,6 +32,7 @@
"walk",
"train_test_split",
"merge_datasets",
"index_parquet_dataset",
"breakpoint",
]
if RequirementCache("lightning_sdk"):
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -36,6 +36,7 @@
_AZURE_STORAGE_AVAILABLE = RequirementCache("azure.storage.blob")
_TQDM_AVAILABLE = RequirementCache("tqdm")
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
_POLARS_AVAILABLE = RequirementCache("polars")
_DEBUG = bool(int(os.getenv("DEBUG", "1")))

_MAX_WAIT_TIME = int(os.getenv("MAX_WAIT_TIME", "120"))
5 changes: 3 additions & 2 deletions src/litdata/streaming/config.py
@@ -70,7 +70,8 @@ def __init__(
else:
self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)

self._config["data_spec"] = treespec_loads(self._config["data_spec"])
if self._config["data_spec"] is not None:
self._config["data_spec"] = treespec_loads(self._config["data_spec"])

assert self._chunks is not None
self._item_loader.setup(self._config, self._chunks, serializers, region_of_interest)
@@ -229,7 +230,7 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:

filesize_bytes = chunk["chunk_bytes"]

if self._config and self._config.get("encryption") is None:
if self._config and self._config.get("encryption") is None and (not local_chunkpath.endswith(".parquet")):
filesize_bytes += (1 + chunk["chunk_size"]) * 4

return local_chunkpath, begin, filesize_bytes
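
For context on the `.parquet` guard above, a sketch of the size arithmetic: litdata's optimized binary chunks appear to prepend an offset table of `(1 + chunk_size)` uint32 entries, which the `(1 + chunk["chunk_size"]) * 4` term accounts for, whereas a Parquet chunk is used as-is. Values below are hypothetical:

```python
# Hypothetical metadata for one optimized (non-Parquet) binary chunk.
chunk_bytes = 1_000_000                         # payload size recorded for the chunk
chunk_size = 1_000                              # number of items in the chunk
header_bytes = (1 + chunk_size) * 4             # uint32 offset table, 4 bytes per entry
expected_filesize = chunk_bytes + header_bytes  # 1_004_004 bytes expected on disk

# For a `.parquet` chunk there is no such header, so the expected size is chunk_bytes alone.
```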
89 changes: 88 additions & 1 deletion src/litdata/streaming/item_loader.py
@@ -24,7 +24,13 @@
import numpy as np
import torch

from litdata.constants import _FORCE_DOWNLOAD_TIME, _MAX_WAIT_TIME, _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from litdata.constants import (
_FORCE_DOWNLOAD_TIME,
_MAX_WAIT_TIME,
_NUMPY_DTYPES_MAPPING,
_POLARS_AVAILABLE,
_TORCH_DTYPES_MAPPING,
)
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import PyTree, tree_unflatten
from litdata.utilities.encryption import Encryption, EncryptionLevel
@@ -412,3 +418,84 @@ def close(self, chunk_index: int) -> None:
@classmethod
def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
return data[0], flattened[0].shape[0]


class ParquetLoader(BaseItemLoader):
def __init__(self) -> None:
if not _POLARS_AVAILABLE:
raise ModuleNotFoundError("Please, run: `pip install polars`")
Review comment (Collaborator): Might be good to prepend "You are using the Parquet item loader, which depends on Polars. Please run: pip install polars".

Review comment (Collaborator): Do we need to make bound checks on the version?

Reply (Collaborator, author, Feb 4, 2025): I'm not sure about the exact version bound. To be on the safer side, for now I've simply updated the requirement to polars>1.0.0. Is this fine, or should I refine it further?
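
A hedged sketch of how such a version bound could be expressed with the `RequirementCache` helper that `litdata.constants` already uses; the `polars>1.0.0` bound follows the reply above, and the error message reuses the reviewer's suggested wording:

```python
from lightning_utilities.core.imports import RequirementCache

# Availability check with a version bound instead of a bare "polars" requirement.
_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")

if not _POLARS_AVAILABLE:
    raise ModuleNotFoundError(
        "You are using the Parquet item loader, which depends on Polars (>1.0.0). "
        "Please run: `pip install polars`"
    )
```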

self._chunk_filepaths: Dict[str, bool] = {}

def setup(
self,
config: Dict,
chunks: List,
serializers: Dict[str, Serializer],
region_of_interest: Optional[List[Tuple[int, int]]] = None,
) -> None:
self._config = config
self._chunks = chunks
self._serializers = {**serializers}
self._data_format = self._config["data_format"]
self._shift_idx = len(self._data_format) * 4
self.region_of_interest = region_of_interest
self._df: Dict[str, Any] = {}

def generate_intervals(self) -> List[Interval]:
intervals = []
begin = 0
end = 0
for idx, curr_chunk in enumerate(self._chunks):
end += curr_chunk["chunk_size"]
start_idx, end_idx = begin, end
if self.region_of_interest is not None:
start_idx = begin + self.region_of_interest[idx][0]
end_idx = begin + self.region_of_interest[idx][1]

intervals.append(Interval(begin, start_idx, end_idx, end))
begin += curr_chunk["chunk_size"]
return intervals

def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
Review comment (Collaborator): Is there a fundamental reason why we're not pre-loading, or is it just for sequencing?

Reply (Collaborator, author): Apologies for the oversight, and thanks for pointing that out! I've made the necessary changes now to include pre-loading as suggested.

"""Logic to load the chunk in background to gain some time."""
pass

def load_item_from_chunk(
self,
index: int,
chunk_index: int,
chunk_filepath: str,
begin: int,
filesize_bytes: int,
) -> Any:
"""Returns an item loaded from a chunk."""
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
del self._chunk_filepaths[chunk_filepath]

if chunk_filepath not in self._chunk_filepaths:
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes

while not exists:
sleep(0.1)
exists = os.path.exists(chunk_filepath) and os.stat(chunk_filepath).st_size >= filesize_bytes

self._chunk_filepaths[chunk_filepath] = True

return self.get_df(chunk_filepath).row(index - begin)

def get_df(self, chunk_filepath: str) -> Any:
import polars as pl

if chunk_filepath not in self._df:
self._df[chunk_filepath] = pl.scan_parquet(chunk_filepath).collect()
return self._df[chunk_filepath]

def delete(self, chunk_index: int, chunk_filepath: str) -> None:
"""Delete a chunk from the local filesystem."""
if os.path.exists(chunk_filepath):
os.remove(chunk_filepath)
if chunk_filepath in self._df:
del self._df[chunk_filepath]

def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any:
pass
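
For intuition, this is roughly what `ParquetLoader` does per chunk, shown with polars directly; the file name below is hypothetical:

```python
import polars as pl

# get_df(): the whole Parquet chunk is materialized once and then cached per filepath.
df = pl.scan_parquet("chunk-0.parquet").collect()

# load_item_from_chunk(): an item is the row at (global index - chunk begin).
row = df.row(3)  # returns a tuple of column values
print(row)
```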
20 changes: 20 additions & 0 deletions src/litdata/streaming/sampler.py
@@ -22,6 +22,26 @@

@dataclass
class ChunkedIndex:
"""Represents an index within a chunked dataset.

Attributes:
index (int): The global index of the data point across all chunks.
chunk_index (int): The index of the chunk where the data point resides.
chunk_indexes (Optional[List[int]]): A list specifying the range of indexes
allowed to read for this data point, in the form
[start, can_read_from, can_read_till, end]. Defaults to None.
is_last_index (bool): Indicates whether this is the last index in the dataset.
Defaults to False.

Suppose there are 3 chunk files:
- chunk-0.bin: Contains data points with global indexes 0-4.
- chunk-1.bin: Contains data points with global indexes 5-9.
- chunk-2.bin: Contains data points with global indexes 10-14.

A `ChunkedIndex` instance for the 6th data point (global index 5) would look like:
ChunkedIndex(index=5, chunk_index=1, chunk_indexes=[4,4,8,8], is_last_index=False)
"""

index: int
chunk_index: int
chunk_indexes: Optional[List[int]] = None
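
As a minimal illustration of the dataclass, constructing the instance from the docstring example above (field values taken verbatim from that example):

```python
from litdata.streaming.sampler import ChunkedIndex

# The 6th data point (global index 5) from the 3-chunk example in the docstring.
idx = ChunkedIndex(index=5, chunk_index=1, chunk_indexes=[4, 4, 8, 8], is_last_index=False)
print(idx.index, idx.chunk_index)  # 5 1
```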
50 changes: 48 additions & 2 deletions src/litdata/streaming/writer.py
@@ -21,15 +21,16 @@

import numpy as np

from litdata.constants import _INDEX_FILENAME
from litdata.constants import _INDEX_FILENAME, _POLARS_AVAILABLE
from litdata.processing.utilities import get_worker_rank
from litdata.streaming.compression import _COMPRESSORS, Compressor
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader
from litdata.streaming.item_loader import BaseItemLoader, ParquetLoader, PyTreeLoader
from litdata.streaming.serializers import Serializer, _get_serializers
from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps
from litdata.utilities.encryption import Encryption, EncryptionLevel
from litdata.utilities.env import _DistributedEnv, _WorkerEnv
from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes
from litdata.utilities.parquet import get_parquet_indexer_cls


@dataclass
@@ -530,3 +531,48 @@ def save_checkpoint(self, checkpoint_dir: str = ".checkpoints") -> Optional[str]
json.dump(checkPoint, f)

return checkpoint_filepath


def index_parquet_dataset(
pq_dir_url: str, cache_dir: Optional[str] = None, storage_options: Optional[Dict] = {}
) -> None:
if not _POLARS_AVAILABLE:
raise ModuleNotFoundError("Please, run: `pip install polars`")

import polars as pl

pq_chunks_info = []
config: Dict[str, Any] = {
"compression": None,
"chunk_size": None,
"chunk_bytes": None,
"data_format": [],
"data_spec": None,
"encryption": None,
"item_loader": ParquetLoader.__name__,
}

pq_dir_class = get_parquet_indexer_cls(pq_dir_url, cache_dir, storage_options)
# iterate the directory and for all files ending in `.parquet` index them
for file_name, file_path in pq_dir_class:
file_size = os.path.getsize(file_path)
pq_polars = pl.scan_parquet(file_path)
chunk_dtypes = pq_polars.collect_schema().dtypes()
chunk_dtypes = [str(dt) for dt in chunk_dtypes]
chunk_size = pq_polars.select(pl.count()).collect().item()

if len(config["data_format"]) != 0 and config["data_format"] != chunk_dtypes:
raise Exception(
"The config isn't consistent between chunks. This shouldn't have happened."
f"Found {config}; {chunk_dtypes}."
)
config["data_format"] = chunk_dtypes
chunk_info = {
"chunk_bytes": file_size,
"chunk_size": chunk_size,
"filename": file_name,
"dim": None,
}
pq_chunks_info.append(chunk_info)

pq_dir_class.write_index(pq_chunks_info, config)
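
For illustration, roughly what the indexing loop above accumulates for a directory with two Parquet files; file names and sizes are hypothetical, and the final on-disk `index.json` layout is produced by `write_index`, which is not part of this diff:

```python
# Hypothetical result of index_parquet_dataset's loop for two Parquet files.
config = {
    "compression": None,
    "chunk_size": None,
    "chunk_bytes": None,
    "data_format": ["Int64", "String"],  # polars dtypes rendered as strings
    "data_spec": None,
    "encryption": None,
    "item_loader": "ParquetLoader",
}
pq_chunks_info = [
    {"chunk_bytes": 5_242_880, "chunk_size": 10_000, "filename": "part-000.parquet", "dim": None},
    {"chunk_bytes": 4_915_200, "chunk_size": 9_500, "filename": "part-001.parquet", "dim": None},
]
```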