
Commit 21aa5de

[gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150)
* fix aaa fix fix fix
* fix
* fix
* test ci
* fix ci fix
1 parent b397104 commit 21aa5de

File tree

3 files changed: +59 -2 lines changed


colossalai/booster/plugin/gemini_plugin.py

Lines changed: 54 additions & 0 deletions
@@ -1,16 +1,19 @@
 import gc
 import logging
 import os
+import random
 from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple

+import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler

 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
 from colossalai.checkpoint_io.utils import (
@@ -448,6 +451,57 @@ def control_device(self) -> bool:

     def supported_devices(self) -> List[str]:
         return ["cuda", "npu"]
+
+    def prepare_dataloader(
+        self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
+    ):
+        r"""
+        Prepare a dataloader for distributed training. The dataloader will be wrapped by
+        `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
+
+
+        Args:
+            dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
+            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
+            seed (int, optional): Random worker seed for sampling, defaults to 1024.
+            add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
+            drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
+                is not divisible by the batch size. If False and the size of dataset is not divisible by
+                the batch size, then the last batch will be smaller, defaults to False.
+            pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
+            num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
+            kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
+                `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
+
+        Returns:
+            :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
+        """
+        _kwargs = kwargs.copy()
+        zero_world_size = self.pg_mesh.size(ZERO_AXIS)
+        extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
+        zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
+        extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
+        sampler = DistributedSampler(
+            dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
+        )
+
+        # Deterministic dataloader
+        def seed_worker(worker_id):
+            worker_seed = seed
+            np.random.seed(worker_seed)
+            torch.manual_seed(worker_seed)
+            random.seed(worker_seed)
+
+        return DataLoader(
+            dataset,
+            batch_size=batch_size,
+            sampler=sampler,
+            worker_init_fn=seed_worker,
+            drop_last=drop_last,
+            pin_memory=pin_memory,
+            num_workers=num_workers,
+            **_kwargs,
+        )

     def configure(
         self,
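
The heart of the fix is the sampler-rank arithmetic above: the DistributedSampler is built over the ZeRO and extra data-parallel axes only, so every tensor-parallel rank belonging to the same model replica draws the same batch. A sampler built over all global ranks would instead feed different data to the TP shards of one replica, which appears to be the source of the NaN loss this hotfix targets. Below is a standalone sketch of that arithmetic; it is not ColossalAI code, and the mesh sizes (tp=2, zero=2, extra_dp=2, i.e. 8 ranks over a 16-sample toy dataset) are illustrative assumptions.

# Standalone sketch (not ColossalAI code) of the sampler-rank arithmetic used by
# the new prepare_dataloader. Mesh sizes are illustrative assumptions:
# tp=2, zero=2, extra_dp=2, i.e. 8 ranks in total over a 16-sample toy dataset.
import itertools

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

tp_size, zero_size, extra_dp_size = 2, 2, 2
dataset = TensorDataset(torch.arange(16))
num_replicas = zero_size * extra_dp_size  # data-parallel replicas only, TP excluded

for zero_rank, extra_dp_rank, tp_rank in itertools.product(
    range(zero_size), range(extra_dp_size), range(tp_size)
):
    # Same formula as the patch: the TP coordinate is deliberately left out, so
    # all tensor-parallel shards of one replica consume identical batches.
    sampler_rank = zero_rank * extra_dp_size + extra_dp_rank
    sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=sampler_rank, shuffle=False)
    print(f"zero={zero_rank} extra_dp={extra_dp_rank} tp={tp_rank} -> shard {sampler_rank}: {list(sampler)}")

Ranks that differ only in their tp coordinate print the same index list, which is exactly the property the patch restores.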

examples/language/llama2/benchmark.py

Lines changed: 4 additions & 1 deletion
@@ -72,6 +72,7 @@ def main():
     parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
     parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
     parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
     parser.add_argument("--mbs", type=int, default=1)
     parser.add_argument("--zero", type=int, default=0)
@@ -93,9 +94,11 @@ def empty_init():
             shard_param_frac=args.shard_param_frac,
             offload_optim_frac=args.offload_optim_frac,
             offload_param_frac=args.offload_param_frac,
+            tp_size=args.tp,
+            extra_dp_size=args.extra_dp,
         )
     elif args.plugin == "gemini_auto":
-        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio)
+        plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp)
     elif args.plugin == "fsdp":
         if use_empty_init:
             plugin = TorchFSDPPlugin(
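
With the new flags the benchmark can carve its GPUs into tensor-parallel and extra data-parallel groups for Gemini. A hedged usage sketch follows; the keyword names tp_size and extra_dp_size come from the diff above, while the 2x2 split, the Booster wiring, and the launch call are illustrative assumptions (run under torchrun or colossalai run on enough GPUs).

# Hedged sketch of the plugin setup the benchmark now builds; sizes are assumptions.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})  # expects a distributed launcher, e.g. torchrun

# e.g. 8 GPUs split as tp=2 x extra_dp=2, with the remaining factor used for ZeRO/Gemini
plugin = GeminiPlugin(
    placement_policy="auto",
    precision="bf16",
    tp_size=2,
    extra_dp_size=2,
)
booster = Booster(plugin=plugin)

# The plugin's prepare_dataloader now shards data over the ZeRO x extra-DP mesh only:
# dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8, shuffle=True)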

tests/kit/model_zoo/transformers/gptj.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def data_gen_for_sequence_classification():

 config = transformers.GPTJConfig(
     n_layer=2,
-    n_head=16,
+    n_head=4,
     vocab_size=50258,
     attn_pdrop=0,
     embd_pdrop=0,
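
The test model change only touches n_head, shrinking the toy GPT-J's attention from 16 heads to 4; the commit does not state the motivation. For reference, a short illustrative check of the resulting attention geometry; the fields not shown in the hunk fall back to the transformers defaults here.

# Illustrative check of how the reduced head count divides across tensor-parallel
# sizes; fields not shown in the hunk are left at transformers defaults here.
import transformers

config = transformers.GPTJConfig(n_layer=2, n_head=4, vocab_size=50258, attn_pdrop=0, embd_pdrop=0)

head_dim = config.n_embd // config.n_head  # per-head width with the default n_embd
print(f"n_head={config.n_head}, head_dim={head_dim}")
for tp_size in (1, 2, 4):
    print(f"tp={tp_size}: heads per rank = {config.n_head // tp_size}, even split = {config.n_head % tp_size == 0}")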
