1 | 1 | import gc |
2 | 2 | import logging |
3 | 3 | import os |
| 4 | +import random |
4 | 5 | from pathlib import Path |
5 | 6 | from typing import Callable, Iterator, List, Optional, Tuple |
6 | 7 |
| 8 | +import numpy as np |
7 | 9 | import torch |
8 | 10 | import torch.distributed as dist |
9 | 11 | import torch.nn as nn |
10 | 12 | from torch.distributed.distributed_c10d import _get_default_group |
11 | 13 | from torch.optim import Optimizer |
12 | 14 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler |
13 | 15 | from torch.utils.data import DataLoader |
| 16 | +from torch.utils.data.distributed import DistributedSampler |
14 | 17 |
15 | 18 | from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO |
16 | 19 | from colossalai.checkpoint_io.utils import ( |
@@ -448,6 +451,57 @@ def control_device(self) -> bool: |
448 | 451 |
449 | 452 | def supported_devices(self) -> List[str]: |
450 | 453 | return ["cuda", "npu"] |
| 454 | + |
| 455 | + def prepare_dataloader( |
| 456 | + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs |
| 457 | + ): |
| 458 | + r""" |
| 459 | + Prepare a dataloader for distributed training. The dataset is wrapped with
| 460 | + `torch.utils.data.DistributedSampler` and returned as a `torch.utils.data.DataLoader`.
| 461 | +
| 462 | +
| 463 | + Args: |
| 464 | + dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
| 465 | + batch_size (int): The number of samples in each batch.
| 466 | + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
| 467 | + seed (int, optional): Random worker seed for sampling. Defaults to 1024.
| 468 | + drop_last (bool, optional): Set to True to drop the last incomplete batch if the dataset size
| 469 | + is not divisible by the batch size. If False and the dataset size is not divisible by the
| 470 | + batch size, the last batch will be smaller. Defaults to False.
| 471 | + pin_memory (bool, optional): Whether to copy loaded batches into pinned (page-locked) CPU memory. Defaults to False.
| 472 | + num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
| 473 | + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``; more details can be found in
| 474 | + `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_. |
| 475 | +
| 476 | + Returns: |
| 477 | + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. |
| 478 | + """ |
| 479 | + _kwargs = kwargs.copy() |
| 480 | + zero_world_size = self.pg_mesh.size(ZERO_AXIS) |
| 481 | + extra_dp_world_size = self.pg_mesh.size(DP_AXIS) |
| 482 | + zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) |
| 483 | + extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) |
| 484 | + sampler = DistributedSampler( |
| 485 | + dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle |
| 486 | + ) |
| 487 | + |
| 488 | + # Deterministic dataloader |
| 489 | + def seed_worker(worker_id): |
| 490 | + worker_seed = seed |
| 491 | + np.random.seed(worker_seed) |
| 492 | + torch.manual_seed(worker_seed) |
| 493 | + random.seed(worker_seed) |
| 494 | + |
| 495 | + return DataLoader( |
| 496 | + dataset, |
| 497 | + batch_size=batch_size, |
| 498 | + sampler=sampler, |
| 499 | + worker_init_fn=seed_worker, |
| 500 | + drop_last=drop_last, |
| 501 | + pin_memory=pin_memory, |
| 502 | + num_workers=num_workers, |
| 503 | + **_kwargs, |
| 504 | + ) |
451 | 505 |
452 | 506 | def configure( |
453 | 507 | self, |
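The sampler construction in prepare_dataloader flattens the two-dimensional (ZERO, DP) coordinate into a single DistributedSampler rank via rank = zero_rank * extra_dp_world_size + extra_dp_rank. A minimal sketch of that mapping, assuming an illustrative 2x2 ZERO x DP mesh (the mesh shape below is an assumption, not taken from this diff):

# Sketch only: enumerate flattened sampler ranks for a hypothetical 2x2 ZERO x DP mesh.
zero_world_size, extra_dp_world_size = 2, 2
for zero_rank in range(zero_world_size):
    for extra_dp_rank in range(extra_dp_world_size):
        sampler_rank = zero_rank * extra_dp_world_size + extra_dp_rank
        print(f"(zero={zero_rank}, dp={extra_dp_rank}) -> sampler rank {sampler_rank}")
# Prints ranks 0, 1, 2, 3: each (ZERO, DP) replica reads a distinct shard out of
# num_replicas = zero_world_size * extra_dp_world_size = 4.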
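For context, a hypothetical call site for the new method; the plugin instance, dataset, and argument values below are illustrative assumptions, not part of this PR:

# Usage sketch (names and values are assumptions; `plugin` is an instance of the
# plugin class this diff extends, `train_dataset` is any torch.utils.data.Dataset).
train_loader = plugin.prepare_dataloader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

As with any DistributedSampler-backed loader, per-epoch reshuffling still requires the training loop to call train_loader.sampler.set_epoch(epoch) at the start of each epoch.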