Skip to content

DeepSpeed stage 3 and mixed precision cause an error  #10510

Open
@ktrapeznikov

Description

@ktrapeznikov

🐛 Bug

Using `strategy="deepspeed_stage_3"` together with `precision=16` raises a `RuntimeError` during `trainer.fit`.

To Reproduce

import os
import torch
from torch.utils.data import DataLoader, Dataset
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Map-style dataset of random vectors generated once at construction.

    Args:
        size: Number of features in each sample.
        length: Number of samples held by the dataset.
    """

    def __init__(self, size, length):
        # Draw all samples up front; the tensor has shape (length, size).
        samples = torch.randn(length, size)
        self.data = samples
        self.len = length

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.data[index]


class BoringModel(LightningModule):
    """Minimal LightningModule built around a single ``Linear(32, 2)`` layer.

    Each step hook sums the layer's output for the batch and logs that sum;
    only ``training_step`` returns it (under the ``"loss"`` key).
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        """Log the summed output as ``train_loss`` and return it as the loss."""
        prediction = self(batch)
        loss = prediction.sum()
        self.log("train_loss", loss)
        return dict(loss=loss)

    def validation_step(self, batch, batch_idx):
        """Log the summed output as ``valid_loss``."""
        prediction = self(batch)
        loss = prediction.sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        """Log the summed output as ``test_loss``."""
        prediction = self(batch)
        loss = prediction.sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        """Return a ``FusedAdam`` optimizer over the layer's parameters (lr=0.1)."""
        # Alternative kept commented out in the original script:
        #   return torch.optim.Adam(self.parameters(),lr = .1)
        layer_params = self.layer.parameters()
        return FusedAdam(layer_params, lr=0.1)

def run():
    """Fit and then test ``BoringModel`` on random data with the reported settings."""
    # Three independent random datasets, created in train/val/test order:
    # 64 samples of 32 features each, served in batches of 2.
    train_data, val_data, test_data = (
        DataLoader(RandomDataset(32, 64), batch_size=2) for _ in range(3)
    )

    model = BoringModel()

    # Keep the run as small as possible: one train batch, one val batch,
    # a single epoch, and no logging, checkpointing or sanity validation.
    trainer_options = dict(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        logger=False,
        enable_checkpointing=False,
        # The combination under report: 4 GPUs, 16-bit precision, DeepSpeed stage 3.
        gpus=4,
        precision=16,
        strategy="deepspeed_stage_3",
    )
    trainer = Trainer(**trainer_options)

    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


# Execute the reproduction only when this file is run as a script.
if __name__ == "__main__":
    run()

I get the following error:

Traceback (most recent call last):
  File "bug.py", line 69, in <module>
    run()
  File "bug.py", line 64, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
    self._call_and_handle_interrupt(
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1182, in _run
    self._pre_dispatch()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in _pre_dispatch
    self.accelerator.pre_dispatch(self)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 136, in pre_dispatch
    self.training_type_plugin.pre_dispatch()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 397, in pre_dispatch
    self.init_deepspeed()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 474, in init_deepspeed
    self._initialize_deepspeed_train(model)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 507, in _initialize_deepspeed_train
    model, deepspeed_optimizer = self._setup_model_and_optimizer(model, optimizer, scheduler)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/deepspeed.py", line 431, in _setup_model_and_optimizer
    deepspeed_engine, deepspeed_optimizer, _, _ = deepspeed.initialize(
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/__init__.py", line 131, in initialize
    engine = DeepSpeedEngine(args=args,
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 223, in __init__
    self._configure_optimizer(optimizer, model_parameters)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 905, in _configure_optimizer
    self.optimizer = self._configure_zero_optimizer(basic_optimizer)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1152, in _configure_zero_optimizer
    optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 905, in __init__
    self.create_reduce_and_remove_grad_hooks()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 1885, in create_reduce_and_remove_grad_hooks
    param.all_gather()
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 590, in all_gather
    return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 699, in _all_gather
    ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy)
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 947, in _allgather_params_coalesced
    h = dist._all_gather_base(allgather_params[param_idx],
  File "/home/kirill.trapeznikov/miniconda3/envs/semafor_nlg/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 2070, in _all_gather_base
    work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: output tensor must have the same type as input tensor

Expected behavior

Training with `strategy="deepspeed_stage_3"` and `precision=16` should run without raising an error.

Environment

* CUDA:
        - GPU:
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
                - GeForce RTX 2080 Ti
        - available:         True
        - version:           11.3
* Packages:
        - numpy:             1.21.1
        - pyTorch_debug:     False
        - pyTorch_version:   1.10.0+cu113
        - pytorch-lightning: 1.5.1
        - tqdm:              4.62.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.11
        - version:           #1 SMP Wed Feb 3 15:06:38 UTC 2021
  
  • Any other relevant information: deepspeed 0.5.6

Additional context

cc @SeanNaren @awaelchli @rohitgr7

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions