
Commit c58b673

awaelchli and tchaton authored
Fix uneven batches in distributed dataloading (#237)
This was a big one!

Co-authored-by: tchaton <thomas.chaton.ai@gmail.com>
1 parent 7bad26f commit c58b673

10 files changed: +502 additions, -308 deletions


src/litdata/processing/data_processor.py

Lines changed: 1 addition & 1 deletion
@@ -255,7 +255,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
             output_filepath = remove_uuid_from_filename(output_filepath)  # remove unique id from checkpoints

             os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
-            shutil.move(local_filepath, output_filepath)
+            shutil.copy(local_filepath, output_filepath)
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
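Note: the only functional difference here is that the local cache file now survives the upload (shutil.copy keeps the source, shutil.move removes it). A minimal, self-contained sketch of that difference, using temporary directories rather than litdata's actual cache layout:

    import os
    import shutil
    import tempfile

    with tempfile.TemporaryDirectory() as cache_dir, tempfile.TemporaryDirectory() as output_root:
        local_filepath = os.path.join(cache_dir, "chunk-0-0.bin")
        with open(local_filepath, "wb") as f:
            f.write(b"\x00" * 16)

        output_filepath = os.path.join(output_root, "checkpoints", "chunk-0-0.bin")
        os.makedirs(os.path.dirname(output_filepath), exist_ok=True)

        shutil.copy(local_filepath, output_filepath)       # new behaviour: source file stays in the cache
        assert os.path.exists(local_filepath) and os.path.exists(output_filepath)
        # shutil.move(local_filepath, output_filepath)     # old behaviour: source would be gone afterwards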

src/litdata/streaming/combined.py

Lines changed: 11 additions & 0 deletions
@@ -118,6 +118,17 @@ def set_shuffle(self, shuffle: bool) -> None:
         for dataset in self._datasets:
             dataset.set_shuffle(shuffle)

+    def set_batch_size(self, batch_size: int) -> None:
+        """Set the current batch size to the datasets."""
+        self.batch_size = batch_size
+        for dataset in self._datasets:
+            dataset.set_batch_size(batch_size)
+
+    def set_num_workers(self, num_workers: int) -> None:
+        """Set the current number of workers to the datasets."""
+        for dataset in self._datasets:
+            dataset.set_num_workers(num_workers)
+
     def set_drop_last(self, drop_last: bool) -> None:
         """Set the current drop_last to the datasets."""
         for dataset in self._datasets:
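Note: the two new setters follow the same forwarding pattern as the existing set_shuffle / set_drop_last. A minimal, self-contained sketch of that propagation, using toy stand-ins rather than the real StreamingDataset / CombinedStreamingDataset classes:

    from typing import List


    class _ToyDataset:
        """Stand-in for StreamingDataset; only records the values forwarded to it."""

        def __init__(self) -> None:
            self.batch_size = 1
            self.num_workers = 1

        def set_batch_size(self, batch_size: int) -> None:
            self.batch_size = batch_size

        def set_num_workers(self, num_workers: int) -> None:
            # `num_workers=0` (main-process loading) still counts as one worker.
            self.num_workers = num_workers or 1


    class _ToyCombined:
        """Stand-in for CombinedStreamingDataset: forwards settings to every wrapped dataset."""

        def __init__(self, datasets: List[_ToyDataset]) -> None:
            self._datasets = datasets

        def set_batch_size(self, batch_size: int) -> None:
            for dataset in self._datasets:
                dataset.set_batch_size(batch_size)

        def set_num_workers(self, num_workers: int) -> None:
            for dataset in self._datasets:
                dataset.set_num_workers(num_workers)


    combined = _ToyCombined([_ToyDataset(), _ToyDataset()])
    combined.set_batch_size(8)
    combined.set_num_workers(4)
    assert all(d.batch_size == 8 and d.num_workers == 4 for d in combined._datasets)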

src/litdata/streaming/dataloader.py

Lines changed: 4 additions & 1 deletion
@@ -555,7 +555,7 @@ def __init__(
         profile_dir: Optional[str] = None,
         prefetch_factor: Optional[int] = None,
         shuffle: Optional[bool] = None,
-        drop_last: Optional[bool] = False,
+        drop_last: Optional[bool] = None,
         collate_fn: Optional[Callable] = None,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
@@ -571,6 +571,9 @@ def __init__(
         if drop_last is not None:
             dataset.set_drop_last(drop_last)

+        dataset.set_batch_size(batch_size)
+        dataset.set_num_workers(num_workers)
+
         shuffle = None

         if profile_batches and not _VIZ_TRACKER_AVAILABLE:
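Note: with the loader now pushing batch_size and num_workers into the dataset, len(dataset) reflects the per-rank worker partitioning instead of always calling get_len(1, 1). A hedged usage sketch (the data location is a placeholder; the import path is assumed to be litdata's public API):

    from litdata import StreamingDataset, StreamingDataLoader  # assumed public imports

    dataset = StreamingDataset("s3://my-bucket/optimized-shards")  # placeholder location
    loader = StreamingDataLoader(dataset, batch_size=8, num_workers=4, drop_last=False)

    # The constructor above now calls dataset.set_batch_size(8) and dataset.set_num_workers(4),
    # so len(dataset) delegates to get_len(num_workers=4, batch_size=8) and matches the number
    # of samples this rank's workers will actually yield.
    print(len(dataset), len(loader))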

src/litdata/streaming/dataset.py

Lines changed: 64 additions & 59 deletions
@@ -32,7 +32,10 @@
 from litdata.utilities.dataset_utilities import _should_replace_path, _try_create_cache_dir, subsample_streaming_dataset
 from litdata.utilities.encryption import Encryption
 from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
-from litdata.utilities.shuffle import _find_chunks_per_ranks_on_which_to_skip_deletion
+from litdata.utilities.shuffle import (
+    _find_chunks_per_workers_on_which_to_skip_deletion,
+    _map_node_worker_rank_to_chunk_indexes_to_not_delete,
+)

 logger = Logger(__name__)

@@ -120,8 +123,10 @@ def __init__(
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
         self._state_dict: Optional[Dict[str, Any]] = None
-        self.num_workers: Optional[int] = None
-        self.batch_size: Optional[int] = None
+        # Has slightly different meaning in the context of the dataset
+        # We consider `num_workers = 0` from `torch.utils.DataLoader` still as 1 worker (the main process)
+        self.num_workers: int = 1
+        self.batch_size: int = 1
         self._encryption = encryption

     def set_shuffle(self, shuffle: bool) -> None:
@@ -179,7 +184,13 @@ def _create_shuffler(self, cache: Cache) -> Shuffle:
         return FullShuffle(cache, seed, drop_last) if self.shuffle else NoShuffle(cache, seed, drop_last)

     def __len__(self) -> int:
-        return self.get_len(1, 1)
+        return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
+
+    def set_batch_size(self, batch_size: int) -> None:
+        self.batch_size = batch_size
+
+    def set_num_workers(self, num_workers: int) -> None:
+        self.num_workers = num_workers or 1

     def get_len(self, num_workers: int, batch_size: int) -> int:
         self.num_workers = num_workers
@@ -205,35 +216,46 @@ def __iter__(self) -> "StreamingDataset":
             state: Dict[str, Any] = self._state_dict
             self.current_epoch = state["current_epoch"]

-        chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
-            self.distributed_env, self.worker_env.world_size, self.batch_size or 1, self.current_epoch
+        workers_chunks, workers_intervals = self.shuffler.get_chunks_and_intervals_per_workers(
+            self.distributed_env, self.worker_env.world_size, self.batch_size, self.current_epoch
         )
-        chunks_replica = chunks_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
-        intervals_replica = intervals_per_replica[self.distributed_env.global_rank % self.distributed_env.world_size]
+
+        worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+        self.worker_chunks = workers_chunks[worker_rank]
+        self.worker_intervals = workers_intervals[worker_rank]
+
+        # The max number of samples to return from `__next__` (in worker)
+        self.stop_length = sum(interval[2] - interval[1] for interval in self.worker_intervals)

         # Handle restart
         if self._state_dict:
-            self._resume(chunks_replica, intervals_replica)
+            self._resume(workers_chunks, workers_intervals)
         else:
-            # Find the chunks shared across multiple ranks.
-            # For each shared chunk, find the rank to use the chunk last and prevent deletion
-            # for the other ranks.
-            chunks_indexes_skip_deletion = _find_chunks_per_ranks_on_which_to_skip_deletion(
-                self.worker_env.world_size, chunks_per_replica, intervals_per_replica
+            # Find the chunks shared across all workers of the current node.
+            # For each shared chunk, find the rank and worker to use the chunk last and prevent
+            # premature deletion for the other workers.
+            node_size = self.distributed_env.world_size // self.distributed_env.num_nodes
+            first_rank_this_node = (self.distributed_env.global_rank // node_size) * node_size
+            num_workers_per_node = node_size * self.num_workers
+            worker_start = first_rank_this_node * num_workers_per_node
+            worker_end = worker_start + num_workers_per_node
+            local_rank = self.distributed_env.global_rank % node_size
+
+            chunks_indexes_skip_deletion = _find_chunks_per_workers_on_which_to_skip_deletion(
+                self.num_workers,
+                self.batch_size,
+                workers_chunks[worker_start:worker_end],
+                workers_intervals[worker_start:worker_end],
             )
-            if self.distributed_env.global_rank in chunks_indexes_skip_deletion:
-                self.cache._reader.config.skip_chunk_indexes_deletion = chunks_indexes_skip_deletion[
-                    self.distributed_env.global_rank
-                ]
-
-            workers_chunks, workers_intervals = _associate_chunks_to_workers(
-                self.worker_env,
-                chunks_per_replica[self.distributed_env.global_rank],
-                intervals_per_replica[self.distributed_env.global_rank],
+            worker_node_rank_to_chunk_indexes = _map_node_worker_rank_to_chunk_indexes_to_not_delete(
+                chunks_indexes_skip_deletion
             )

-        self.worker_chunks = workers_chunks[self.worker_env.rank]
-        self.worker_intervals = workers_intervals[self.worker_env.rank]
+            worker_rank_local_node = local_rank * self.num_workers + self.worker_env.rank
+            if worker_rank_local_node in worker_node_rank_to_chunk_indexes:
+                self.cache._reader.config.skip_chunk_indexes_deletion = worker_node_rank_to_chunk_indexes[
+                    worker_rank_local_node
+                ]

         self.num_chunks = len(self.worker_chunks)
         self.current_indexes = []
@@ -246,7 +268,7 @@ def __iter__(self) -> "StreamingDataset":

         return self

-    def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> None:
+    def _resume(self, workers_chunks: List[List[int]], workers_intervals: List[Any]) -> None:
         assert self._state_dict
         assert self.worker_env
         assert self.shuffler
@@ -259,17 +281,22 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
         # TODO: Implement elastic sampling where the number of workers, ranks can change.
         num_samples_yielded = self._state_dict["num_samples_yielded"]

+        worker_start = self.distributed_env.global_rank * num_workers
+        worker_end = worker_start + num_workers
+
         # replay sampling from each worker / chunks using the batch size
-        workers_chunks, workers_intervals = _associate_chunks_to_workers(
-            self.worker_env, chunks_replica, intervals_replica
-        )
         indexes = _replay_sampling(num_samples_yielded, batch_size, num_workers)
-        chunks_index, indexes = _replay_chunks_sampling(workers_intervals, indexes)
+        chunks_index, indexes = _replay_chunks_sampling(
+            workers_intervals={i: workers_intervals[i] for i in range(worker_start, worker_end)},
+            indexes=indexes,
+        )

         # select the chunks and intervals associated to this worker
-        worker_rank = self.worker_env.rank
+        worker_rank = self.distributed_env.global_rank * self.worker_env.world_size + self.worker_env.rank
+        worker_local_rank = self.worker_env.rank
+
         self.num_chunks = len(workers_intervals[worker_rank])
-        self.chunk_index = chunks_index[worker_rank]
+        self.chunk_index = chunks_index[worker_local_rank]
         self.worker_chunks = workers_chunks[worker_rank]
         self.worker_intervals = workers_intervals[worker_rank]

@@ -281,10 +308,10 @@ def _resume(self, chunks_replica: List[int], intervals_replica: List[Any]) -> No
         current_indexes = self.shuffler(current_indexes, self.num_chunks, self.current_epoch, self.chunk_index)

         # skip any indexes already consumed
-        current_indexes = current_indexes[indexes[worker_rank] :]
+        current_indexes = current_indexes[indexes[worker_local_rank] :]
         self.current_indexes = current_indexes

-        self.global_index = num_samples_yielded
+        self.global_index = indexes[worker_local_rank]

         # bump the chunk_index
         self.chunk_index += 1
@@ -305,7 +332,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

     def __next__(self) -> Any:
         # Prevent to create more batch on a given process
-        if self.global_index >= len(self):
+        if self.global_index >= self.stop_length:
             self.current_epoch += 1
             raise StopIteration

@@ -454,8 +481,8 @@ def reset(self) -> None:
             "random_state": None,
             "shuffler": None,
             "_state_dict": None,
-            "num_workers": None,
-            "batch_size": None,
+            "num_workers": 1,
+            "batch_size": 1,
         }

         for prop, value in default_properties.items():
@@ -470,28 +497,6 @@ def is_integer(value: str) -> bool:
         return False


-def _associate_chunks_to_workers(
-    worker_env: _WorkerEnv, chunks_replica: List[int], intervals_replica: List[Any]
-) -> Any:
-    workers_chunks = {}
-    workers_intervals = {}
-
-    for worker_idx in range(worker_env.world_size):
-        worker_chunks = []
-        worker_intervals = []
-        for i, (chunk_index, chunk_interval) in enumerate(zip(chunks_replica, intervals_replica)):
-            if i % worker_env.world_size != worker_idx:
-                continue
-
-            worker_chunks.append(chunk_index)
-            worker_intervals.append(chunk_interval)
-
-        workers_chunks[worker_idx] = worker_chunks
-        workers_intervals[worker_idx] = worker_intervals
-
-    return workers_chunks, workers_intervals
-
-
 def _replay_sampling(num_samples_yielded: int, batch_size: int, num_workers: int) -> Dict[int, int]:
     """This function replays the sampling from the dataloader."""
     divisible_num_batches_yielded = num_samples_yielded // (num_workers * batch_size)
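Note: the core of this change is the indexing scheme in __iter__: the shuffler now returns one chunk list per dataloader worker (flattened across all ranks), and each worker stops after its own interval budget (stop_length) rather than after len(self). A worked example of that arithmetic with illustrative values (2 nodes, 2 ranks per node, num_workers=2, and toy intervals whose positions 1 and 2 hold the begin/end sample indices):

    world_size, num_nodes, num_dataloader_workers = 4, 2, 2
    node_size = world_size // num_nodes  # ranks per node -> 2

    for global_rank in range(world_size):
        for worker_env_rank in range(num_dataloader_workers):
            # Flat index into workers_chunks / workers_intervals (one entry per worker, globally).
            worker_rank = global_rank * num_dataloader_workers + worker_env_rank
            # Node-local index used for the skip-deletion lookup.
            local_rank = global_rank % node_size
            worker_rank_local_node = local_rank * num_dataloader_workers + worker_env_rank
            print(f"rank={global_rank} worker={worker_env_rank} -> flat={worker_rank} node-local={worker_rank_local_node}")

    # stop_length: how many samples this worker yields before __next__ raises StopIteration.
    worker_intervals = [(0, 0, 50, 50), (3, 50, 75, 75)]  # toy (chunk, begin, end, size) tuples
    stop_length = sum(interval[2] - interval[1] for interval in worker_intervals)
    assert stop_length == 75  # 50 + 25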

src/litdata/streaming/shuffle.py

Lines changed: 22 additions & 24 deletions
@@ -19,7 +19,10 @@

 from litdata.streaming import Cache
 from litdata.utilities.env import _DistributedEnv
-from litdata.utilities.shuffle import _associate_chunks_and_internals_to_ranks, _intra_node_chunk_shuffle
+from litdata.utilities.shuffle import (
+    _associate_chunks_and_intervals_to_workers,
+    _intra_node_chunk_shuffle,
+)


 class Shuffle(ABC):
@@ -32,23 +35,19 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):

     @lru_cache(maxsize=10)
     def get_len(self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int) -> int:
-        _, intervals_per_ranks = self.get_chunks_and_intervals_per_ranks(
+        _, workers_intervals = self.get_chunks_and_intervals_per_workers(
             distributed_env, num_workers, batch_size, current_epoch
         )
-
-        if self.drop_last:
-            items_per_process = [
-                sum((interval[2] - interval[1]) for interval in intervals) for intervals in intervals_per_ranks
-            ]
-            # Validate each processes gets the exact number of elements
-            if len(items_per_process) > 1:
-                assert all(items_per_process[0] == items_to_process for items_to_process in items_per_process[:1])
-            return items_per_process[0]
-
-        return sum((interval[2] - interval[1]) for interval in intervals_per_ranks[distributed_env.global_rank])
+        worker_start = distributed_env.global_rank * num_workers
+        worker_end = worker_start + num_workers
+        return sum(
+            (interval[2] - interval[1])
+            for intervals in workers_intervals[worker_start:worker_end]
+            for interval in intervals
+        )

     @abstractmethod
-    def get_chunks_and_intervals_per_ranks(
+    def get_chunks_and_intervals_per_workers(
         self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
     ) -> Any:
         pass
@@ -63,19 +62,18 @@ class NoShuffle(Shuffle):
     is True."""

     @lru_cache(maxsize=10)
-    def get_chunks_and_intervals_per_ranks(
+    def get_chunks_and_intervals_per_workers(
         self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
     ) -> Any:
         # 1. Get the intervals
         chunk_intervals = self.cache.get_chunk_intervals()
         indexes = range(len(chunk_intervals))

         # 2. Compute the items budget of each rank
-        chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
+        workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
             distributed_env, indexes, chunk_intervals, self.drop_last, num_workers, batch_size
         )
-
-        return chunks_per_ranks, intervals_per_ranks
+        return workers_chunks, workers_intervals

     def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
         return array.tolist()
@@ -100,7 +98,7 @@ class FullShuffle(Shuffle):
     """

     @lru_cache(maxsize=10)
-    def get_chunks_and_intervals_per_ranks(
+    def get_chunks_and_intervals_per_workers(
         self, distributed_env: _DistributedEnv, num_workers: int, batch_size: int, current_epoch: int
     ) -> Any:
         # 1. Get the intervals
@@ -120,24 +118,24 @@ def get_chunks_and_intervals_per_ranks(
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

         # 3. Compute the items budget of each rank
-        chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
+        workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
             distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
         )

         # For the first epoch, no need of further shuffling
         if current_epoch == 1 or distributed_env.num_nodes == 1:
-            return chunks_per_ranks, intervals_per_ranks
+            return workers_chunks, workers_intervals

         # Perform shuffle within the nodes to avoid cache miss.
         # Note: It is possible for the overlapping chunks to change due to the changing order.
-        shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, chunks_per_ranks, self.seed, current_epoch)
+        shuffled_indexes = _intra_node_chunk_shuffle(distributed_env, workers_chunks, self.seed, current_epoch)
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes].tolist()

-        chunks_per_ranks, intervals_per_ranks = _associate_chunks_and_internals_to_ranks(
+        workers_chunks, workers_intervals = _associate_chunks_and_intervals_to_workers(
             distributed_env, shuffled_indexes, shuffled_chunk_intervals, self.drop_last, num_workers, batch_size
         )

-        return chunks_per_ranks, intervals_per_ranks
+        return workers_chunks, workers_intervals

     def __call__(self, array: np.ndarray, num_chunks: int, current_epoch: int, chunk_index: int) -> List[int]:
         return np.random.RandomState([self.seed, num_chunks * current_epoch, chunk_index]).permutation(array).tolist()
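Note: get_len now slices the flat per-worker list back into per-rank groups: rank r owns entries [r * num_workers, (r + 1) * num_workers). A minimal sketch with made-up intervals, again treating positions 1 and 2 as the begin/end sample indices:

    num_workers = 2
    # One entry per worker, flattened across ranks (here: 2 ranks x 2 workers).
    workers_intervals = [
        [(0, 0, 10, 10)], [(1, 10, 25, 25)],   # rank 0, workers 0 and 1
        [(2, 25, 30, 30)], [(3, 30, 42, 42)],  # rank 1, workers 0 and 1
    ]


    def per_rank_len(global_rank: int) -> int:
        worker_start = global_rank * num_workers
        worker_end = worker_start + num_workers
        return sum(
            (interval[2] - interval[1])
            for intervals in workers_intervals[worker_start:worker_end]
            for interval in intervals
        )


    assert per_rank_len(0) == 25  # 10 + 15
    assert per_rank_len(1) == 17  # 5 + 12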
