Skip to content

Commit 6a77302

Browse files
VijayVignesh1VijayVignesh2025
authored and committed
Adding multisample feature along with testcases
1 parent 359bbf1 commit 6a77302

File tree

2 files changed

+150
-6
lines changed

2 files changed

+150
-6
lines changed

src/litdata/streaming/dataset.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
index_path: Optional[str] = None,
6363
force_override_state_dict: bool = False,
6464
transform: Optional[Union[Callable, list[Callable]]] = None,
65+
is_multisample: bool = False,
6566
) -> None:
6667
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
6768
@@ -89,6 +90,7 @@ def __init__(
8990
If `index_path` is a full file path, it will use that directly.
9091
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
9192
transform: Optional transformation function or list of functions to apply to each item in the dataset.
93+
is_multisample: If True, each index access returns multiple samples transformed by the list of functions.
9294
"""
9395
_check_version_and_prompt_upgrade(__version__)
9496

@@ -209,6 +211,9 @@ def __init__(
209211
raise ValueError(f"Transform should be a callable. Found {t}")
210212
self.transform = transform
211213
self._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache
214+
self.is_multisample = is_multisample
215+
if self.is_multisample and not transform:
216+
raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.")
212217

213218
@property
214219
def on_demand_bytes(self) -> bool:
@@ -282,7 +287,8 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
282287
return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)
283288

284289
def __len__(self) -> int:
    """Return the number of samples exposed by this dataset.

    In multisample mode every underlying sample is served once per transform,
    so the base length is inflated by the number of transform callables.
    """
    base_length = self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
    if self.is_multisample:
        return base_length * len(self.transform)
    return base_length
286292

287293
def set_batch_size(self, batch_size: int) -> None:
288294
self.batch_size = batch_size
@@ -323,8 +329,13 @@ def __iter__(self) -> "StreamingDataset":
323329
self.worker_chunks = workers_chunks[worker_rank]
324330
self.worker_intervals = workers_intervals[worker_rank]
325331

332+
# multiply the interval by the multisample factor if multisampling is enabled
333+
self.multisample_factor = len(self.transform) if self.is_multisample else 1
334+
326335
# The max number of samples to return from `__next__` (in worker)
327-
self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)
336+
self.stop_length = (
337+
sum(interval[2] - interval[1] for interval in self.worker_intervals) * self.multisample_factor
338+
)
328339

329340
# Handle restart
330341
if self._state_dict:
@@ -407,7 +418,8 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
407418

408419
# replay the indexes for the current chunks
409420
interval = self.worker_intervals[self.worker_next_chunk_index]
410-
current_indexes = np.arange(interval[1], interval[2])
421+
# multiply the interval by the multisample factor if multisampling is enabled
422+
current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor)
411423

412424
# re-shuffle the indexes
413425
current_indexes = self.shuffler(
@@ -424,6 +436,21 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
424436
self.worker_next_chunk_index += 1
425437

426438
def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
439+
# Deflate index for multisample case
440+
if self.is_multisample:
441+
if not self.transform:
442+
raise ValueError("When using `is_multisample=True`, `transform` must be a list of callables.")
443+
if not all(callable(fn) for fn in self.transform):
444+
raise ValueError("All elements in `transform` must be callable when using `is_multisample=True`.")
445+
if isinstance(index, int):
446+
sample_idx = index % len(self.transform)
447+
index = index // len(self.transform)
448+
elif isinstance(index, ChunkedIndex):
449+
sample_idx = index.index % len(self.transform)
450+
index.index = index.index // len(self.transform)
451+
else:
452+
raise ValueError("Slices are not supported when using `is_multisample=True`.")
453+
427454
if self.cache is None:
428455
self.worker_env = _WorkerEnv.detect()
429456
self.cache = self._create_cache(worker_env=self.worker_env)
@@ -437,16 +464,21 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
437464
_my_cache_indices = [ChunkedIndex(*self.cache._get_chunk_index_from_index(idx)) for idx in _my_indices]
438465
return [self.cache[chnk_idx] for chnk_idx in _my_cache_indices]
439466
item = self.cache[index]
467+
440468
if hasattr(self, "transform"):
441469
if isinstance(self.transform, list):
442-
for transform_fn in self.transform:
443-
item = transform_fn(item)
470+
if not self.is_multisample:
471+
for transform_fn in self.transform:
472+
item = transform_fn(item)
473+
else:
474+
item = self.transform[sample_idx](item) # apply the specific transform for multisample
444475
else:
445476
item = self.transform(item)
446477

447478
return item
448479

449480
def __next__(self) -> Any:
481+
# print(self.worker_next_chunk_index, self.num_chunks)
450482
# check if we have reached the end of the dataset (i.e., all the chunks have been processed)
451483
if self.global_index >= self.stop_length:
452484
# global_index: total number of samples processed by the current worker across all chunks
@@ -476,7 +508,8 @@ def __next__(self) -> Any:
476508

477509
# `next_worker_chunks_index` is the index of the chunk that we will be working on now
478510
interval = self.worker_intervals[self.worker_next_chunk_index]
479-
current_indexes = np.arange(interval[1], interval[2])
511+
512+
current_indexes = np.arange(interval[1] * self.multisample_factor, interval[2] * self.multisample_factor)
480513

481514
assert self.shuffler is not None
482515
assert self.num_chunks is not None

tests/streaming/test_dataset.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1813,3 +1813,114 @@ def transform(self, x, *args, **kwargs):
18131813
# Verify that the transform is applied correctly
18141814
for i, item in enumerate(complete_data):
18151815
assert item == i * 2, f"Expected {i * 2}, got {item}"
1816+
1817+
1818+
def test_dataset_multisample(tmpdir):
    """Multisample mode: each stored sample yields one item per transform.

    With 100 stored items and 3 transforms, the dataset reports length 300 and
    iteration interleaves the transformed variants of every underlying sample.
    """
    # Create directories for cache and data
    cache_dir = os.path.join(tmpdir, "cache_dir")
    data_dir = os.path.join(tmpdir, "data_dir")
    os.makedirs(cache_dir)
    os.makedirs(data_dir)

    # Create a dataset with 100 items, 20 items per chunk
    cache = Cache(str(data_dir), chunk_size=20)
    for i in range(100):
        cache[i] = i
    cache.done()
    cache.merge()

    # Simple transforms. NOTE: renamed from `transform_fn_sq`, which wrongly
    # suggested squaring — the function doubles its input.
    def transform_fn_double(x, *args, **kwargs):
        """Double the input."""
        return x * 2

    def transform_fn_add(x):
        """Add 3 to the input."""
        return x + 3

    def transform_fn_identity(x):
        """Return the input unchanged."""
        return x

    dataset = StreamingDataset(
        data_dir,
        cache_dir=str(cache_dir),
        shuffle=False,
        transform=[transform_fn_double, transform_fn_add, transform_fn_identity],
        is_multisample=True,
    )
    # Length is inflated by the number of transforms: 100 * 3.
    assert len(dataset) == 300

    # ASSERT
    # Each consecutive group of 3 items corresponds to one underlying sample,
    # transformed by each function in order: double, add 3, identity.
    num_transforms = len(dataset.transform)
    for i, item in enumerate(dataset):
        assert item is not None
        base = i // num_transforms
        expected = (base * 2, base + 3, base)[i % num_transforms]
        assert item == expected, f"Expected {expected}, got {item}"
1872+
1873+
def test_dataset_multisample_single_transform(tmpdir):
    """A single (non-list) transform with `is_multisample=True` behaves like a plain transform.

    The length is not inflated (stays 100) and the callable is applied to
    every item exactly once.
    """
    # Create directories for cache and data
    cache_dir = os.path.join(tmpdir, "cache_dir")
    data_dir = os.path.join(tmpdir, "data_dir")
    os.makedirs(cache_dir)
    os.makedirs(data_dir)

    # Create a dataset with 100 items, 20 items per chunk
    cache = Cache(str(data_dir), chunk_size=20)
    for i in range(100):
        cache[i] = i
    cache.done()
    cache.merge()

    # NOTE: renamed from `transform_fn_sq`, which wrongly suggested squaring —
    # the function doubles its input.
    def transform_fn_double(x, *args, **kwargs):
        """Double the input."""
        return x * 2

    dataset = StreamingDataset(
        data_dir, cache_dir=str(cache_dir), shuffle=False, transform=transform_fn_double, is_multisample=True
    )
    # A single callable does not inflate the dataset length.
    assert len(dataset) == 100

    # ASSERT
    # The transform is applied to every item.
    for i, item in enumerate(dataset):
        assert item is not None
        assert item == i * 2, f"Expected {i * 2}, got {item}"
1906+
1907+
def test_dataset_multisample_nonlist_transform_error(tmpdir):
    """`is_multisample=True` without any transform raises ValueError at construction.

    (The original docstring said "transform is not a list", but the test
    actually omits the transform entirely.)
    """
    # Create directories for cache and data
    cache_dir = os.path.join(tmpdir, "cache_dir")
    data_dir = os.path.join(tmpdir, "data_dir")
    os.makedirs(cache_dir)
    os.makedirs(data_dir)

    # Create a dataset with 100 items, 20 items per chunk
    cache = Cache(str(data_dir), chunk_size=20)
    for i in range(100):
        cache[i] = i
    cache.done()
    cache.merge()

    # ASSERT
    # Constructing with is_multisample=True but no transform must fail fast.
    with pytest.raises(ValueError, match="When using `is_multisample=True`, `transform` must be a list of callables."):
        StreamingDataset(data_dir, cache_dir=str(cache_dir), shuffle=False, is_multisample=True)

0 commit comments

Comments
 (0)