38 changes: 38 additions & 0 deletions README.md
@@ -1082,6 +1082,44 @@ dataset = StreamingDatasetWithTransform(data_dir, cache_dir=str(cache_dir), shuf

</details>


<details>
<summary> ✅ Multi-sample transforms while streaming <a id="multi-sample" href="#multi-sample">🔗</a> </summary>
&nbsp;

Sometimes you need to return several variations of each sample, e.g. the same image under slightly different augmentations. The multi-sample feature lets you apply such a transformation while streaming, without storing any intermediate results.

```python
from torchvision import transforms

from litdata import StreamingDataset


def transform_fn(x, sample_idx):
"""
Apply different rotation for each sample based on sample_idx.
"""

angles = [0, 15, -15, 30]
angle = angles[sample_idx % len(angles)]

torch_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.Lambda(lambda x: transforms.functional.rotate(x, angle)), # apply rotation
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
return torch_transform(x)

dataset = StreamingDataset(
data_dir,
cache_dir=str(cache_dir),
shuffle=False,
transform=[transform_fn],
sample_count=4 # Generate 4 transformed samples per input
)
```
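
With `sample_count=4`, `len(dataset)` is four times the number of stored items, and consecutive indices map to the same underlying item with `sample_idx` values 0-3. Below is a minimal usage sketch; the loader and parameters mirror the test suite, so adapt them to your setup:

```python
from litdata import StreamingDataLoader

dataloader = StreamingDataLoader(dataset, batch_size=8, num_workers=2)

for batch in dataloader:
    # consecutive samples are variations of the same underlying item,
    # each produced with a different sample_idx (0, 1, 2, 3)
    ...
```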

</details>

<details>
<summary> ✅ Split datasets for train, val, test</summary>

44 changes: 39 additions & 5 deletions src/litdata/streaming/dataset.py
@@ -13,6 +13,7 @@

import logging
import os
from inspect import signature
from time import time
from typing import Any, Callable, Optional, Union

@@ -62,6 +63,7 @@ def __init__(
index_path: Optional[str] = None,
force_override_state_dict: bool = False,
transform: Optional[Union[Callable, list[Callable]]] = None,
sample_count: int = 1,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -89,6 +91,7 @@ def __init__(
If `index_path` is a full file path, it will use that directly.
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
transform: Optional transformation function or list of functions to apply to each item in the dataset.
sample_count: Number of samples to return per underlying item. Requires a single transform that accepts a `sample_idx` parameter.
"""
_check_version_and_prompt_upgrade(__version__)

@@ -202,12 +205,27 @@ def __init__(
self.storage_options = storage_options
self.session_options = session_options
self.max_pre_download = max_pre_download
self.sample_count = sample_count
if transform is not None:
transform = transform if isinstance(transform, list) else [transform]
for t in transform:
if not callable(t):
raise ValueError(f"Transform should be a callable. Found {t}")
self.transform = transform

# define invalid transform conditions for multisample case
invalid_transform = self.sample_count > 1 and (
not hasattr(self, "transform")
or len(self.transform) > 1
or "sample_idx" not in signature(self.transform[0]).parameters
)
if invalid_transform:
logger.warning(
"Invalid transform configuration detected. "
"Either no transform, multiple transforms, or missing `sample_idx` parameter. "
"Reverting `sample_count` to 1 and returning data as-is."
)
self.sample_count = 1
self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache

@property
@@ -282,7 +300,7 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)

def __len__(self) -> int:
return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1) * self.sample_count

def set_batch_size(self, batch_size: int) -> None:
self.batch_size = batch_size
@@ -324,7 +342,7 @@ def __iter__(self) -> "StreamingDataset":
self.worker_intervals = workers_intervals[worker_rank]

# The max number of samples to return from `__next__` (in worker)
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.sample_count

# Handle restart
if self._state_dict:
@@ -407,7 +425,9 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])

# replay the indexes for the current chunks
interval = self.worker_intervals[self.worker_next_chunk_index]
current_indexes = np.arange(interval[1], interval[2])

# multiply the interval by the sample_count for multisample case
current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count)

# re-shuffle the indexes
current_indexes = self.shuffler(
@@ -424,6 +444,17 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
self.worker_next_chunk_index += 1

def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
# Deflate index for multisample case
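        # e.g. with sample_count=3, global index 7 maps to base index 2 (7 // 3) and sample_idx 1 (7 % 3)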
if self.sample_count > 1:
if isinstance(index, int):
sample_idx = index % self.sample_count
index = index // self.sample_count
elif isinstance(index, ChunkedIndex):
sample_idx = index.index % self.sample_count
index.index = index.index // self.sample_count
else:
raise ValueError("Slices are not supported when using `sample_count > 1`.")

if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
@@ -437,10 +468,11 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
_my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices]
item = self.cache[index]

if hasattr(self, "transform"):
if isinstance(self.transform, list):
for transform_fn in self.transform:
item = transform_fn(item)
item = transform_fn(item) if self.sample_count == 1 else transform_fn(item, sample_idx)
else:
item = self.transform(item)

@@ -476,7 +508,9 @@ def __next__(self) -> Any:

# `next_worker_chunks_index` is the index of the chunk that we will be working on now
interval = self.worker_intervals[self.worker_next_chunk_index]
current_indexes = np.arange(interval[1], interval[2])

# scale the interval by sample_count for the multisample case
current_indexes = np.arange(interval[1] * self.sample_count, interval[2] * self.sample_count)

assert self.shuffler is not None
assert self.num_chunks is not None
124 changes: 124 additions & 0 deletions tests/streaming/test_dataloader.py
@@ -1,3 +1,4 @@
import logging
import os
import sys

@@ -496,3 +497,126 @@ def test_dataloader_dataset_transform_inheritance(tmpdir, shuffle):
# Verify that the transform is applied correctly
for i, item in enumerate(complete_data):
assert item == i * 2, f"Expected {i * 2}, got {item}"


# Define a simple transform function
def multisample_transform_fn(x, sample_idx, *args, **kwargs):
"""A simple transform function that doubles the input."""
return x * sample_idx


def test_dataloader_dataset_transform_multisample(tmpdir):
"""Test if the dataset's transform is applied correctly with dataloader."""
# Create a simple dataset
# Create directories for cache and data
cache_dir = os.path.join(tmpdir, "cache_dir")
data_dir = os.path.join(tmpdir, "data_dir")
os.makedirs(cache_dir)
os.makedirs(data_dir)

# Create a dataset with 100 items, 20 items per chunk
cache = Cache(str(data_dir), chunk_size=20)
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

dataset = StreamingDataset(
data_dir, cache_dir=str(cache_dir), shuffle=False, transform=multisample_transform_fn, sample_count=3
)
dataset_length = len(dataset)
assert dataset_length == 300

# ACT
dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)

complete_data = []
for batch in dl:
complete_data.extend(batch)

# ASSERT
# Verify that the multisample transform is applied correctly
for i, item in enumerate(complete_data):
        if i % 3 == 0:
            assert item == (i // 3) * 0, f"Expected {(i // 3) * 0}, got {item}"
        elif i % 3 == 1:
            assert item == (i // 3) * 1, f"Expected {(i // 3) * 1}, got {item}"
        else:
            assert item == (i // 3) * 2, f"Expected {(i // 3) * 2}, got {item}"


# Define simple transform functions
def transform_fn_sq(x, sample_idx):
"""A simple transform function that doubles the input."""
return x * sample_idx


def transform_fn_add(x, sample_idx):
"""A simple transform function that adds the sample_idx to the input."""
return x + sample_idx


def transform_fn_no_sample_idx(x):
"""A simple transform function that doubles the input."""
return x


def test_dataloader_dataset_transform_invalid_config(tmpdir, caplog):
"""Test if the dataset's transform is applied correctly with dataloader."""
# Create a simple dataset
# Create directories for cache and data
cache_dir = os.path.join(tmpdir, "cache_dir")
data_dir = os.path.join(tmpdir, "data_dir")
os.makedirs(cache_dir)
os.makedirs(data_dir)

# Create a dataset with 100 items, 20 items per chunk
cache = Cache(str(data_dir), chunk_size=20)
for i in range(100):
cache[i] = i
cache.done()
cache.merge()

# Verify that logger warning happens when transform is not given
with caplog.at_level(logging.WARNING):
dataset = StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4)

assert "Invalid transform configuration detected." in caplog.text
dataset_length = len(dataset)
assert dataset_length == 100

# Verify that logger warning happens when multiple transforms are given
with caplog.at_level(logging.WARNING):
dataset = StreamingDataset(
data_dir,
cache_dir=str(cache_dir),
shuffle=False,
sample_count=4,
transform=[transform_fn_sq, transform_fn_add],
)

assert "Invalid transform configuration detected." in caplog.text
dataset_length = len(dataset)
assert dataset_length == 100

# Verify that logger warning happens when sample_idx parameter is missing
with caplog.at_level(logging.WARNING):
dataset = StreamingDataset(
data_dir, cache_dir=str(cache_dir), shuffle=False, sample_count=4, transform=transform_fn_no_sample_idx
)

assert "Invalid transform configuration detected." in caplog.text
dataset_length = len(dataset)
assert dataset_length == 100

# ACT
dl = StreamingDataLoader(dataset, batch_size=10, num_workers=1, shuffle=False)

complete_data = []
for batch in dl:
complete_data.extend(batch)

# ASSERT
    # Verify that the data is returned as-is since the invalid configuration reverted `sample_count` to 1
for i, item in enumerate(complete_data):
assert item == i, f"Expected {i}, got {item}"