[BUG] Crash with a minimal ZeRO stage 3 NVMe checkpointing example #4565
Closed
opened on Oct 25, 2023
Describe the bug
The simplest possible training script using ZeRO stage 3 with NVMe offload for the optimizer crashes on model.step()
with the following error:
File "/home/eeisenst/workspace/DeepSpeed/deepspeed/runtime/zero/stage3.py", line 2002, in unscale_and_clip_grads
self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale)
To Reproduce
import os
import deepspeed
import deepspeed.comm as dist
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
import torch
class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super(SimpleModel, self).__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for i in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad
    def forward(self, x, y):
        if len(self.linears) == 1:
            x = self.linears[0](x)
        else:
            for i, l in enumerate(self.linears):
                x = self.linears[i // 2](x) + l(x)
        return self.cross_entropy_loss(x, y)
def random_dataset(total_samples, hidden_dim, device, dtype=torch.half):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset
def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader
tmpdir = "/home/eeisenst/workspace/temp/temp" # CHANGE THIS TO SOMETHING CONVENIENT
zero_dir, ckpt_dir = os.path.join(tmpdir, "zero"), os.path.join(tmpdir, "checkpoint")
torch.manual_seed(12345)
config_dict = {
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-6
        }
    },
    "fp16": {
        "enabled": True,
        "initial_scale_power": 2
    },
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": OffloadDeviceEnum.cpu
            # "device": OffloadDeviceEnum.nvme,
            # "nvme_path": str(zero_dir)
        },
        "offload_optimizer": {
            # "device": OffloadDeviceEnum.cpu
            "device": OffloadDeviceEnum.nvme,
            "nvme_path": str(zero_dir)
        },
        "sub_group_size": 100,
        "stage3_max_live_parameters": 100,
        "stage3_param_persistence_threshold": 0,
    },
    "aio": {
        "block_size": 1048576  # Minimum AIO bytes, anything smaller than this will not be offloaded
    }
}
hidden_dim, nlayers = 2048, 5
with deepspeed.zero.Init(config_dict_or_path=config_dict):
    model = SimpleModel(hidden_dim, nlayers=nlayers, empty_grad=False)
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
                                total_samples=10,
                                hidden_dim=hidden_dim,
                                device=model.device,
                                dtype=torch.float16)
dist.barrier()
for n, batch in enumerate(data_loader):
    loss = model(batch[0], batch[1])
    model.backward(loss)
    model.step()
Expected behavior
This script should exit with no error.
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [YES] ...... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.0
[WARNING] using untested triton version (2.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/eeisenst/miniconda3/envs/deepspeed-test/lib/python3.11/site-packages/torch']
torch version .................... 2.0.1
deepspeed install path ........... ['/home/eeisenst/workspace/DeepSpeed/deepspeed']
deepspeed info ................... 0.10.4+6c6a1ec0, 6c6a1ec0, nvme_ckpt
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 11.7
deepspeed wheel compiled w. ...... torch 2.0, cuda 11.8
shared memory (/dev/shm) size .... 15.57 GB
System info (please complete the following information):
- OS: Fedora 37
- GPU count and types: 1x 3080 Ti on test machine, but it doesn't seem to matter
- Python version: 3.11.5
Environment:
mamba install -c conda-forge pip python pytest pytorch gcc=11 libaio rust cmake
Build command:
CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" DS_BUILD_CPU_ADAM=1 DS_BUILD_AIO=1 DS_BUILD_UTILS=1 pip install -e . --global-option="build_ext" --global-option="-j8" --no-cache -v
Launcher context
deepspeed launcher
Docker context
No docker.
Additional context
The problem seems to be caused by the two lines marked below, lines 1334-1335 of deepspeed/runtime/zero/stage3.py, in DeepSpeedZeroOptimizer_Stage3.partition_grads:
# offload the gradient partition if applicable
if self.offload_optimizer:
    i, dest_offset, _ = self.grad_position[self.get_param_id(param)]
    offload_fp32_gradients = {}  # THIS IS THE BUG???
    offload_fp32_offsets = {}  # THIS IS THE BUG???
This resets the dictionaries of offloaded gradients and their offsets, so that lines 1357-1361, later in the same function, do nothing:
if self.offload_optimizer and self.swap_optimizer:
    for i in offload_fp32_gradients.keys():
        self.optimizer_swapper.swap_out_gradients(parameter=self.fp32_partitioned_groups_flat[i],
                                                  gradient_offsets=offload_fp32_offsets[i],
                                                  gradient_tensors=offload_fp32_gradients[i])
Commenting out the lines marked BUG makes the script work as expected.
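The suspected pattern can be illustrated outside DeepSpeed with a small standalone sketch. It is not DeepSpeed code: collect_offloads and its arguments are hypothetical names, and it assumes the reset runs once per parameter inside a loop, which the excerpt above does not show. Under that assumption, rebinding the dictionaries on every iteration throws away whatever earlier iterations collected, so the later swap-out loop sees little or nothing to write.

```python
def collect_offloads(grads, reset_inside_loop):
    """Group (sub_group_id, offset, value) triples by sub-group id.

    Hypothetical stand-in for the accumulation pattern described above;
    this is an illustration, not DeepSpeed's implementation.
    """
    offload_gradients = {}
    offload_offsets = {}
    for sub_group_id, offset, value in grads:
        if reset_inside_loop:
            # Mirrors the two lines marked BUG: rebinding the names here
            # discards everything collected in earlier iterations.
            offload_gradients = {}
            offload_offsets = {}
        offload_gradients.setdefault(sub_group_id, []).append(value)
        offload_offsets.setdefault(sub_group_id, []).append(offset)
    return offload_gradients, offload_offsets


# Three made-up gradient pieces spread over two sub-groups.
grads = [(0, 0, 1.0), (0, 4, 2.0), (1, 0, 3.0)]
buggy_grads, buggy_offsets = collect_offloads(grads, reset_inside_loop=True)
fixed_grads, fixed_offsets = collect_offloads(grads, reset_inside_loop=False)

print(buggy_grads)  # entries from earlier iterations are gone
print(fixed_grads)  # every sub-group keeps all of its entries
```

With the reset removed, both dictionaries keep one list per sub-group, which is what the swap-out loop needs in order to iterate over every offloaded sub-group.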