67 changes: 67 additions & 0 deletions benchmarks/benchmark_zarr_streaming.py
@@ -0,0 +1,67 @@
"""
Benchmark Zarr loading via 🤗 Datasets.

This benchmark creates a small local Zarr store and compares iteration speed for:
- streaming=True (iterates directly from the Zarr store)
- streaming=False (builds an Arrow dataset first, then iterates)

Note: This is a best-effort benchmark intended for development/profiling.
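
Example:
    python benchmarks/benchmark_zarr_streaming.py --rows 200000 --cols 64 --take 10000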
"""


import argparse
import tempfile
import time
from pathlib import Path

import numpy as np

from datasets import load_dataset


def make_zarr_store(root: Path, n_rows: int, n_cols: int) -> Path:
    import zarr

    store_dir = root / "bench.zarr"
    g = zarr.open_group(store=str(store_dir), mode="w")
    # One 1-D int32 column and one 2-D float32 column, chunked along the row axis.
    g.create_array("x", data=np.arange(n_rows, dtype=np.int32), chunks=(min(8192, n_rows),))
    g.create_array("y", data=np.random.randn(n_rows, n_cols).astype(np.float32), chunks=(min(1024, n_rows), n_cols))
    # Zarr v3 writes the root metadata to `zarr.json`, which is what we pass to `load_dataset`.
    return store_dir / "zarr.json"


def bench(streaming: bool, zarr_json: str, n_take: int) -> float:
    ds = load_dataset("zarr", data_files=[zarr_json], split="train", streaming=streaming)
    # Timing starts after load_dataset, so dataset preparation (including the
    # Arrow conversion when streaming=False) is excluded: we measure iteration only.
    t0 = time.perf_counter()
    it = iter(ds)
    for _ in range(n_take):
        next(it)
    t1 = time.perf_counter()
    return t1 - t0


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--rows", type=int, default=200_000)
    parser.add_argument("--cols", type=int, default=64)
    parser.add_argument("--take", type=int, default=10_000)
    args = parser.parse_args()

    try:
        import zarr  # noqa: F401
    except Exception as e:
        raise SystemExit("This benchmark requires `zarr` (pip install zarr).") from e

    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        zarr_json = make_zarr_store(tmp, n_rows=args.rows, n_cols=args.cols)

        t_stream = bench(streaming=True, zarr_json=str(zarr_json), n_take=args.take)
        t_nonstream = bench(streaming=False, zarr_json=str(zarr_json), n_take=args.take)

        print(f"rows={args.rows} cols={args.cols} take={args.take}")
        print(f"streaming=True : {t_stream:.3f}s ({args.take / max(t_stream, 1e-9):.1f} ex/s)")
        print(f"streaming=False : {t_nonstream:.3f}s ({args.take / max(t_nonstream, 1e-9):.1f} ex/s)")


if __name__ == "__main__":
    main()
44 changes: 44 additions & 0 deletions docs/source/stream.mdx
@@ -55,6 +55,50 @@ Parquet is a columnar format that allows you to stream and load only a subset of
'language_score': 0.9900368452072144, 'token_count': 716}
```

## Streaming scientific formats (HDF5 and Zarr)

Scientific datasets are often stored in formats like HDF5 or Zarr (chunked, N-dimensional arrays). 🤗 Datasets can stream these formats directly, from local files as well as from remote filesystems supported by `fsspec`, without converting the full dataset to Arrow first.

### HDF5

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("hdf5", data_files=["/path/to/data.h5"], split="train", streaming=True)
>>> print(next(iter(ds)))
{'temperature': 280.12, 'pressure': 101325.0, ...}
```
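
Remote files work the same way. For example, to stream an HDF5 file hosted on the Hub (a sketch assuming the repository contains a `data.h5` file):

```py
>>> ds = load_dataset("hdf5", data_files=["hf://datasets/<user>/<repo>/data.h5"], split="train", streaming=True)
```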

### Zarr

Zarr support is currently experimental. Install it with `pip install "datasets[zarr]"` (or `pip install zarr`).

Zarr stores are directory-based. You can point `data_files` to either the Zarr store root directory (recommended for convenience) or the Zarr root metadata file:

- Zarr store root directory: `.../store.zarr` (auto-detects metadata)
- Zarr v3: `.../store.zarr/zarr.json`
- Zarr v2 (consolidated): `.../store.zarr/.zmetadata`

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("zarr", data_files=["/path/to/store.zarr"], split="train", streaming=True)
>>> print(next(iter(ds)))
{'int32': 0, 'float32': 0.0, 'matrix_2d': [[...], ...]}
```
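
Zarr stores can also be loaded without streaming, in which case the full store is converted to an Arrow-backed [`Dataset`] first:

```py
>>> ds = load_dataset("zarr", data_files=["/path/to/store.zarr"], split="train")  # streaming=False is the default
```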

You can also load from the Hub via the `hf://` protocol:

```py
>>> from datasets import DownloadConfig, load_dataset
>>> download_config = DownloadConfig(storage_options={"hf": {"token": None}}) # set token for private/gated repos
>>> ds = load_dataset(
... "zarr",
... data_files=["hf://datasets/<user>/<repo>@main/path/to/store.zarr/zarr.json"],
... split="train",
... streaming=True,
... download_config=download_config,
... )
```

Loading a dataset in streaming mode creates a new dataset type instance (instead of the classic [`Dataset`] object), known as an [`IterableDataset`].
This special type of dataset has its own set of processing methods shown below.
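
For example, a minimal sketch reusing the local Zarr example above (`map` and `take` are standard [`IterableDataset`] methods):

```py
>>> ds = ds.map(lambda example: {"double": example["float32"] * 2})
>>> list(ds.take(2))  # pulls only the first two examples from the stream
```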

2 changes: 2 additions & 0 deletions setup.py
@@ -167,6 +167,7 @@
"elasticsearch>=7.17.12,<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551.
"faiss-cpu>=1.8.0.post1", # Pins numpy < 2
"h5py",
"zarr",
"pylance",
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
@@ -220,6 +221,7 @@
"tensorflow_gpu": ["tensorflow>=2.6.0"],
"torch": ["torch"],
"jax": ["jax>=0.3.14", "jaxlib>=0.3.14"],
"zarr": ["zarr"],
"streaming": [], # for backward compatibility
"dev": TESTS_REQUIRE + QUALITY_REQUIRE + DOCS_REQUIRE,
"tests": TESTS_REQUIRE,
5 changes: 5 additions & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -22,6 +22,7 @@
from .videofolder import videofolder
from .webdataset import webdataset
from .xml import xml
from .zarr import zarr


def _hash_python_lines(lines: list[str]) -> str:
@@ -55,6 +56,7 @@ def _hash_python_lines(lines: list[str]) -> str:
"hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
"eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())),
"lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())),
"zarr": (zarr.__name__, _hash_python_lines(inspect.getsource(zarr).splitlines())),
}

# get importable module names and hash for caching
@@ -88,6 +90,9 @@ def _hash_python_lines(lines: list[str]) -> str:
".h5": ("hdf5", {}),
".eval": ("eval", {}),
".lance": ("lance", {}),
# Zarr stores are directory-based; `data_files` can point to the store root directory (matched via the
# `.zarr` suffix) or directly to the root metadata file (Zarr v3: `zarr.json`, Zarr v2 consolidated: `.zmetadata`).
".zarr": ("zarr", {}),
}
_EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/zarr/__init__.py
@@ -0,0 +1,3 @@
"""Zarr packaged module for 🤗 Datasets."""

from .zarr import Zarr, ZarrConfig # noqa: F401