from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
from litdata.utilities.encryption import Encryption
from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
- from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion
+ from litdata.utilities.shuffle import (
+     _find_chunks_per_workers_on_which_to_skip_deletion,
+     _map_node_worker_rank_to_chunk_indexes_to_not_delete,
+ )

logger = Logger(__name__)
@@ -120,8 +123,10 @@ def __init__(
        self.shuffler: Optional[Shuffle] = None
        self.serializers = serializers
        self._state_dict: Optional[Dict[str, Any]] = None
-         self.num_workers: Optional[int] = None
-         self.batch_size: Optional[int] = None
+         # Has slightly different meaning in the context of the dataset
+         # We consider `num_workers = 0` from `torch.utils.DataLoader` still as 1 worker (the main process)
+         self.num_workers: int = 1
+         self.batch_size: int = 1
        self._encryption = encryption

    def set_shuffle(self, shuffle: bool) -> None:
@@ -179,7 +184,13 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
        return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)

    def __len__(self) -> int:
-         return self.get_len(1, 1)
+         return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
+
+     def set_batch_size(self, batch_size: int) -> None:
+         self.batch_size = batch_size
+
+     def set_num_workers(self, num_workers: int) -> None:
+         self.num_workers = num_workers or 1

    def get_len(self, num_workers: int, batch_size: int) -> int:
        self.num_workers = num_workers
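The new `set_batch_size` / `set_num_workers` hooks above let the consuming dataloader keep `len(dataset)` in line with how it will actually iterate. A minimal sketch of such a caller, assuming only that it exposes `num_workers` and `batch_size` (the helper below is illustrative and not part of this change):

# --- illustrative sketch, not part of the diff ---
def sync_loader_settings(dataset, loader) -> None:
    # `loader` is any dataloader-like object exposing `num_workers` and `batch_size`;
    # the exact wiring point is an assumption, not something this diff shows.
    dataset.set_num_workers(loader.num_workers)  # 0 (main-process loading) is normalized to 1 above
    dataset.set_batch_size(loader.batch_size or 1)  # None (unbatched loading) falls back to 1
# --- end of sketch ---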
@@ -205,35 +216,46 @@ def __iter__(self) -> "StreamingDataset":
            state: Dict[str, Any] = self._state_dict
            self.current_epoch = state["current_epoch"]

-         chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
-             self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch
+         workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers(
+             self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch
        )
-         chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
-         intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
+
+         worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+         self.worker_chunks = workers_chunks[worker_rank]
+         self.worker_intervals = workers_intervals[worker_rank]
+
+         # The max number of samples to return from `__next__` (in worker)
+         self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)

        # Handle restart
        if self._state_dict:
-             self._resume(chunks_replica, intervals_replica)
+             self._resume(workers_chunks, workers_intervals)
        else:
-             # Find the chunks shared across multiple ranks.
-             # For each shared chunk, find the rank to use the chunk last and prevent deletion
-             # for the other ranks.
-             chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion(
-                 self.worker_env.world_size, chunks_per_replica, intervals_per_replica
+             # Find the chunks shared across all workers of the current node.
+             # For each shared chunk, find the rank and worker to use the chunk last and prevent
+             # premature deletion for the other workers.
+             node_size = self.distributed_env.world_size // self.distributed_env.num_nodes
+             first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size
+             num_workers_per_node = node_size * self.num_workers
+             worker_start = first_rank_this_node * num_workers_per_node
+             worker_end = worker_start + num_workers_per_node
+             local_rank = self.distributed_env.global_rank % node_size
+
+             chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion(
+                 self.num_workers,
+                 self.batch_size,
+                 workers_chunks[worker_start:worker_end],
+                 workers_intervals[worker_start:worker_end],
            )
-             if self.distributed_env.global_rank in chunks_indexes_skip_deletion:
-                 self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[
-                     self.distributed_env.global_rank
-                 ]
-
-             workers_chunks, workers_intervals = _associate_chunks_to_workers(
-                 self.worker_env,
-                 chunks_per_replica[self.distributed_env.global_rank],
-                 intervals_per_replica[self.distributed_env.global_rank],
+             worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(
+                 chunks_indexes_skip_deletion
            )

-             self.worker_chunks = workers_chunks[self.worker_env.rank]
-             self.worker_intervals = workers_intervals[self.worker_env.rank]
+             worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank
+             if worker_rank_local_node in worker_node_rank_to_chunk_indexes:
+                 self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[
+                     worker_rank_local_node
+                 ]

        self.num_chunks = len(self.worker_chunks)
        self.current_indexes = []
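To make the per-worker bookkeeping above easier to follow, here is a standalone sketch of the flat worker rank and the `stop_length` computation; the helper names and example values are illustrative, not part of the library:

# --- illustrative sketch, not part of the diff ---
from typing import List, Sequence

def flat_worker_rank(global_rank: int, num_dataloader_workers: int, worker_id: int) -> int:
    # Dataloader workers are laid out contiguously per distributed rank, mirroring
    # `global_rank * worker_env.world_size + worker_env.rank` above.
    return global_rank * num_dataloader_workers + worker_id

def samples_for_worker(worker_intervals: List[Sequence[int]]) -> int:
    # Mirrors `stop_length`: positions 1 and 2 of each interval delimit the slice
    # of the chunk this worker actually reads.
    return sum(interval[2] - interval[1] for interval in worker_intervals)

# Example: rank 1 of 2 with 2 dataloader workers -> its workers own flat ranks 2 and 3.
assert flat_worker_rank(1, 2, 0) == 2 and flat_worker_rank(1, 2, 1) == 3
# Two intervals exposing 50 and 30 samples -> the worker stops after 80 samples.
assert samples_for_worker([(0, 0, 50, 50), (50, 60, 90, 100)]) == 80
# --- end of sketch ---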
@@ -246,7 +268,7 @@ def __iter__(self) -> "StreamingDataset":
        return self

-     def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> None:
+     def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None:
        assert self._state_dict
        assert self.worker_env
        assert self.shuffler
@@ -259,17 +281,22 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
        # TODO: Implement elastic sampling where the number of workers, ranks can change.
        num_samples_yielded = self._state_dict["num_samples_yielded"]

+         worker_start = self.distributed_env.global_rank * num_workers
+         worker_end = worker_start + num_workers
+
        # replay sampling from each worker / chunks using the batch size
-         workers_chunks, workers_intervals = _associate_chunks_to_workers(
-             self.worker_env, chunks_replica, intervals_replica
-         )
        indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
-         chunks_index, indexes = _replay_chunks_sampling(workers_intervals, indexes)
+         chunks_index, indexes = _replay_chunks_sampling(
+             workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)},
+             indexes=indexes,
+         )

        # select the chunks and intervals associated to this worker
-         worker_rank = self.worker_env.rank
+         worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+         worker_local_rank = self.worker_env.rank
+
        self.num_chunks = len(workers_intervals[worker_rank])
-         self.chunk_index = chunks_index[worker_rank]
+         self.chunk_index = chunks_index[worker_local_rank]
        self.worker_chunks = workers_chunks[worker_rank]
        self.worker_intervals = workers_intervals[worker_rank]
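To make the replay step concrete: suppose this rank runs 2 dataloader workers with `batch_size=4` and had already yielded 10 samples before the restart. The worked example below assumes the usual round-robin batch order of `torch.utils.data.DataLoader` (worker 0, worker 1, worker 0, ...), so the split it shows is an expectation under that assumption rather than a documented guarantee:

# --- illustrative worked example, not part of the diff ---
# Under round-robin dispatch, worker 0 emitted batch 0 (4 samples) plus the first
# 2 samples of batch 2, and worker 1 emitted batch 1 (4 samples).
indexes = _replay_sampling(num_samples_yielded=10, batch_size=4, num_workers=2)
# Expected split under that assumption: {0: 6, 1: 4}. `_replay_chunks_sampling`
# then converts each per-worker count into a chunk index plus an offset inside it.
# --- end of sketch ---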
@@ -281,10 +308,10 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
        current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

        # skip any indexes already consumed
-         current_indexes = current_indexes[indexes[worker_rank] :]
+         current_indexes = current_indexes[indexes[worker_local_rank] :]
        self.current_indexes = current_indexes

-         self.global_index = num_samples_yielded
+         self.global_index = indexes[worker_local_rank]

        # bump the chunk_index
        self.chunk_index += 1
@@ -305,7 +332,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:
    def __next__(self) -> Any:
        # Prevent to create more batch on a given process
-         if self.global_index >= len(self):
+         if self.global_index >= self.stop_length:
            self.current_epoch += 1
            raise StopIteration
@@ -454,8 +481,8 @@ def reset(self) -> None:
            "random_state": None,
            "shuffler": None,
            "_state_dict": None,
-             "num_workers": None,
-             "batch_size": None,
+             "num_workers": 1,
+             "batch_size": 1,
        }

        for prop, value in default_properties.items():
@@ -470,28 +497,6 @@ def is_integer(value: str) -> bool:
        return False


- def _associate_chunks_to_workers(
-     worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any]
- ) -> Any:
-     workers_chunks = {}
-     workers_intervals = {}
-
-     for worker_idx in range(worker_env.world_size):
-         worker_chunks = []
-         worker_intervals = []
-         for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)):
-             if i % worker_env.world_size != worker_idx:
-                 continue
-
-             worker_chunks.append(chunk_index)
-             worker_intervals.append(chunk_interval)
-
-         workers_chunks[worker_idx] = worker_chunks
-         workers_intervals[worker_idx] = worker_intervals
-
-     return workers_chunks, workers_intervals
-
-
def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]:
    """This function replays the sampling from the dataloader."""
    divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size)