@@ -62,6 +62,7 @@ def __init__(
6262 index_path : Optional [str ] = None ,
6363 force_override_state_dict : bool = False ,
6464 transform : Optional [Union [Callable , list [Callable ]]] = None ,
65+ is_multisample : bool = False ,
6566 ) -> None :
6667 """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.
6768
@@ -89,6 +90,7 @@ def __init__(
8990 If `index_path` is a full file path, it will use that directly.
9091 force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
9192 transform: Optional transformation function or list of functions to apply to each item in the dataset.
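93+ is_multisample: If True, `transform` must be a list of callables and every stored item is exposed once per callable: the dataset length is multiplied by `len(transform)`, and each index returns a single sample produced by exactly one transform from the list.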
9294 """
9395 _check_version_and_prompt_upgrade (__version__ )
9496
@@ -209,6 +211,9 @@ def __init__(
209211 raise ValueError (f"Transform should be a callable. Found { t } " )
210212 self .transform = transform
211213 self ._on_demand_bytes = True # true by default, when iterating, turn this off to store the chunks in the cache
214+ self .is_multisample = is_multisample
215+ if self .is_multisample and not transform :
216+ raise ValueError ("When using `is_multisample=True`, `transform` must be a list of callables." )
212217
213218 @property
214219 def on_demand_bytes (self ) -> bool :
@@ -282,7 +287,8 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
282287 return FullShuffle (cache , seed , drop_last ) if self .shuffle else NoShuffle (cache , seed , drop_last )
283288
def __len__(self) -> int:
    """Return the number of samples this dataset reports.

    The base length is whatever `get_len` returns for the configured number of
    workers and batch size (a batch size of 1 is used when none is set). When
    `is_multisample` is enabled, that length is multiplied by the number of
    entries in `transform`, because every stored item is exposed once per
    transform.
    """
    base_len = self.get_len(self.num_workers, self.batch_size or 1)
    if self.is_multisample:
        # One logical sample per (item, transform) pair.
        return base_len * len(self.transform)
    return base_len
286292
287293 def set_batch_size (self , batch_size : int ) -> None :
288294 self .batch_size = batch_size
@@ -323,8 +329,13 @@ def __iter__(self) -> "StreamingDataset":
323329 self .worker_chunks = workers_chunks [worker_rank ]
324330 self .worker_intervals = workers_intervals [worker_rank ]
325331
332+ # number of samples each stored item expands to: one per transform when multisampling is enabled, otherwise 1
333+ self .multisample_factor = len (self .transform ) if self .is_multisample else 1
334+
326335 # The max number of samples to return from `__next__` (in worker)
327- self .stop_length = sum (interval [2 ] - interval [1 ] for interval in self .worker_intervals )
336+ self .stop_length = (
337+ sum (interval [2 ] - interval [1 ] for interval in self .worker_intervals ) * self .multisample_factor
338+ )
328339
329340 # Handle restart
330341 if self ._state_dict :
@@ -407,7 +418,8 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
407418
408419 # replay the indexes for the current chunks
409420 interval = self .worker_intervals [self .worker_next_chunk_index ]
410- current_indexes = np .arange (interval [1 ], interval [2 ])
421+ # multiply the interval by the multisample factor if multisampling is enabled
422+ current_indexes = np .arange (interval [1 ] * self .multisample_factor , interval [2 ] * self .multisample_factor )
411423
412424 # re-shuffle the indexes
413425 current_indexes = self .shuffler (
@@ -424,6 +436,21 @@ def _resume(self, workers_chunks: list[list[int]], workers_intervals: list[Any])
424436 self .worker_next_chunk_index += 1
425437
426438 def __getitem__ (self , index : Union [ChunkedIndex , int , slice ]) -> Any :
439+ # Deflate index for multisample case
440+ if self .is_multisample :
441+ if not self .transform :
442+ raise ValueError ("When using `is_multisample=True`, `transform` must be a list of callables." )
443+ if not all (callable (fn ) for fn in self .transform ):
444+ raise ValueError ("All elements in `transform` must be callable when using `is_multisample=True`." )
445+ if isinstance (index , int ):
446+ sample_idx = index % len (self .transform )
447+ index = index // len (self .transform )
448+ elif isinstance (index , ChunkedIndex ):
449+ sample_idx = index .index % len (self .transform )
450+ index .index = index .index // len (self .transform )
451+ else :
452+ raise ValueError ("Slices are not supported when using `is_multisample=True`." )
453+
427454 if self .cache is None :
428455 self .worker_env = _WorkerEnv .detect ()
429456 self .cache = self ._create_cache (worker_env = self .worker_env )
@@ -437,16 +464,21 @@ def __getitem__(self, index: Union[ChunkedIndex, int, slice]) -> Any:
437464 _my_cache_indices = [ChunkedIndex (* self .cache ._get_chunk_index_from_index (idx )) for idx in _my_indices ]
438465 return [self .cache [chnk_idx ] for chnk_idx in _my_cache_indices ]
439466 item = self .cache [index ]
467+
440468 if hasattr (self , "transform" ):
441469 if isinstance (self .transform , list ):
442- for transform_fn in self .transform :
443- item = transform_fn (item )
470+ if not self .is_multisample :
471+ for transform_fn in self .transform :
472+ item = transform_fn (item )
473+ else :
474+ item = self .transform [sample_idx ](item ) # apply the specific transform for multisample
444475 else :
445476 item = self .transform (item )
446477
447478 return item
448479
449480 def __next__ (self ) -> Any :
481+ # print(self.worker_next_chunk_index, self.num_chunks)
450482 # check if we have reached the end of the dataset (i.e., all the chunks have been processed)
451483 if self .global_index >= self .stop_length :
452484 # global_index: total number of samples processed by the current worker across all chunks
@@ -476,7 +508,8 @@ def __next__(self) -> Any:
476508
477509 # `next_worker_chunks_index` is the index of the chunk that we will be working on now
478510 interval = self .worker_intervals [self .worker_next_chunk_index ]
479- current_indexes = np .arange (interval [1 ], interval [2 ])
511+
512+ current_indexes = np .arange (interval [1 ] * self .multisample_factor , interval [2 ] * self .multisample_factor )
480513
481514 assert self .shuffler is not None
482515 assert self .num_chunks is not None
0 commit comments