
[BUG] Crash with a minimal ZeRO stage 3 NVMe checkpointing example #4565

Closed
@eisene

Description

Describe the bug

The simplest possible training script using ZeRO stage 3 with NVMe offload for the optimizer crashes in model.step() with the following error:

  File "/home/eeisenst/workspace/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2002, in unscale_and_clip_grads
    self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)

To Reproduce

import os
import deepspeed
import deepspeed.comm as dist
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
import torch


class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)


def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader

tmpdir = "/home/eeisenst/workspace/temp/temp"    # CHANGE THIS TO SOMETHING CONVENIENT
zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")

torch.manual_seed(12345)

config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-6
        }
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 2
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": OffloadDeviceEnum.cpu
            # "device": OffloadDeviceEnum.nvme,
            # "nvme_path": str(zero_dir)
        },
        "offload_optimizer": {
            # "device": OffloadDeviceEnum.cpu
            "device": OffloadDeviceEnum.nvme,
            "nvme_path": str(zero_dir)
        },
        "sub_group_size": 100,
        "stage3_max_live_parameters": 100,
        "stage3_param_persistence_threshold": 0,
    },
    "aio": {
        "block_size": 1048576       # Minimum AIO bytes, anything smaller than this will not be offloaded
    }
}

hidden_dim, nlayers = 2048, 5
with deepspeed.zero.Init(config_dict_or_path=config_dict):
    model = SimpleModel(hidden_dim, nlayers=nlayers, empty_grad=False)

model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
                                total_samples=10,
                                hidden_dim=hidden_dim,
                                device=model.device,
                                dtype=torch.float16)
dist.barrier()
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    model.backward(loss)
    model.step()
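The tmpdir path in the script is specific to the reporter's machine and has to be edited before the repro will run elsewhere. A portable way to create the two directories is sketched below. It is an illustration, not part of the original report: the helper make_repro_dirs, its base parameter, and the DS_REPRO_TMPDIR environment variable are all hypothetical names. Point the base at a fast local filesystem (ideally NVMe), because nvme_path is where the optimizer state is swapped.

```python
import os
import tempfile


def make_repro_dirs(base=None):
    """Create fresh 'zero' and 'checkpoint' directories for the repro script.

    `base` and the DS_REPRO_TMPDIR environment variable are hypothetical
    knobs for this sketch. Point them at a fast local filesystem, ideally
    on an NVMe drive, since the optimizer state is swapped to nvme_path.
    If neither is given, the system temporary directory is used.
    """
    base = base or os.environ.get("DS_REPRO_TMPDIR") or tempfile.gettempdir()
    # mkdtemp gives each run its own directory, so stale swap files from an
    # earlier run cannot leak into this one.
    tmpdir = tempfile.mkdtemp(prefix="ds_nvme_repro_", dir=base)
    zero_dir = os.path.join(tmpdir, "zero")
    ckpt_dir = os.path.join(tmpdir, "checkpoint")
    for path in (zero_dir, ckpt_dir):
        os.makedirs(path, exist_ok=True)
    return zero_dir, ckpt_dir


# Drop-in replacement for the hard-coded tmpdir / zero_dir / ckpt_dir lines.
zero_dir, ckpt_dir = make_repro_dirs()
```

zero_dir can then be passed as "nvme_path" exactly as in the original config_dict.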

Expected behavior
This script should exit with no error.

ds_report output

--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [YES] ...... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
 [WARNING]  using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/eeisenst/miniconda3/envs/deepspeed-test/lib/python3.11/site-packages/torch']
torch version .................... 2.0.1
deepspeed install path ........... ['/home/eeisenst/workspace/DeepSpeed/deepspeed']
deepspeed info ................... 0.10.4+6c6a1ec0, 6c6a1ec0, nvme_ckpt
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8
shared memory (/dev/shm) size .... 15.57 GB

System info:

  • OS: Fedora 37
  • GPU count and types: 1x RTX 3080 Ti on the test machine, but the GPU type does not seem to matter
  • Python version: 3.11.5

Environment:

mamba install -c conda-forge pip python pytest pytorch gcc=11 libaio rust cmake

Build command:

CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1  DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v

Launcher context
deepspeed launcher

Docker context
No docker.

Additional context

The problem appears to be caused by the following two lines (1334-1335 of deepspeed/runtime/zero/stage3.py, in DeepSpeedZeroOptimizer_Stage3.partition_grads):

            # offload the gradient partition if applicable
            if self.offload_optimizer:
                i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
                offload_fp32_gradients = {}                   # THIS IS THE BUG???
                offload_fp32_offsets = {}                     # THIS IS THE BUG???

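These lines re-create the dictionaries of offloaded gradients and offsets for every parameter processed. That discards whatever had been collected for earlier parameters, so later in the same function lines 1357-1361 do not swap those gradients out: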

        if self.offload_optimizer and self.swap_optimizer:
            for i in offload_fp32_gradients.keys():
                self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i],
                                                          gradient_offsets=offload_fp32_offsets[i],
                                                          gradient_tensors=offload_fp32_gradients[i])

Commenting out the two lines marked BUG makes the script run as expected.
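To make the failure mode concrete, here is a toy, pure-Python model of the pattern. It assumes, as the indentation of the quoted code suggests, that the two lines sit inside the per-parameter loop of partition_grads and that the dictionaries are filled later in the same loop iteration. The function collect_for_swap_out and its tuple inputs are hypothetical stand-ins, not DeepSpeed code.

```python
def collect_for_swap_out(params, reset_inside_loop):
    """Toy model of how partition_grads groups gradients by sub-group.

    `params` is a list of hypothetical (sub_group_id, dest_offset, grad)
    tuples standing in for the real per-parameter values. Returns what the
    later swap-out loop would see: {sub_group_id: (offsets, gradients)}.
    """
    # Initialized once, before the loop (the behavior after the fix).
    offload_fp32_gradients = {}
    offload_fp32_offsets = {}
    for sub_group_id, dest_offset, grad in params:
        if reset_inside_loop:
            # Equivalent of the two lines marked BUG: the dictionaries are
            # re-created for every parameter, dropping earlier entries.
            offload_fp32_gradients = {}
            offload_fp32_offsets = {}
        offload_fp32_gradients.setdefault(sub_group_id, []).append(grad)
        offload_fp32_offsets.setdefault(sub_group_id, []).append(dest_offset)
    # Stand-in for the swap_out_gradients loop over offload_fp32_gradients.keys().
    return {i: (offload_fp32_offsets[i], offload_fp32_gradients[i])
            for i in offload_fp32_gradients}


params = [(0, 0, "grad_a"), (0, 4, "grad_b"), (1, 0, "grad_c")]
print(collect_for_swap_out(params, reset_inside_loop=True))
# {1: ([0], ['grad_c'])}
print(collect_for_swap_out(params, reset_inside_loop=False))
# {0: ([0, 4], ['grad_a', 'grad_b']), 1: ([0], ['grad_c'])}
```

In the toy, only the last parameter's gradient survives the per-iteration reset, so the swap-out step never sees the others. That is consistent with fp32 gradients being missing by the time unscale_and_clip_grads runs. Commenting out the two lines works because the dictionaries must already be initialized earlier in the function (otherwise the later loop would raise a NameError), which the toy mirrors by creating them once before the loop.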

Labels: bug (Something isn't working), training
