
[BUG] In deepspeed Zero3, RuntimeError: still have inflight params #5828

Closed
@XuyaoWang

Description

Describe the bug

When using ZeRO stage 3, training fails if the set of model parameters used in the forward pass depends on the input data. In the script below, fc2 is skipped when use_fc2=False. The error message is: RuntimeError: still have inflight params.
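The report does not say why the error occurs. One plausible reading, which is an assumption here and not confirmed in this issue, is this: on later steps ZeRO-3 prefetches parameters according to the module order it recorded earlier. When step 2 skips fc2, the prefetched fc2 weights are never consumed and are still "in flight" when the step ends. The toy model below illustrates that idea in plain Python. run_step and check_no_inflight are hypothetical names, not DeepSpeed APIs.

```python
# Hypothetical, simplified model of trace-based prefetching under ZeRO-3.
# It is NOT DeepSpeed's implementation; module names mirror the repro script.

def run_step(executed, trace=None, prefetch_depth=1):
    """Simulate one forward pass and return the modules left "in flight".

    executed: names of modules that actually run this step, in order.
    trace: module order recorded on an earlier step (None on the first step).
    """
    inflight = set()
    for module in executed:
        # The module's parameters are consumed now, so they are no longer in flight.
        inflight.discard(module)
        if trace is not None and module in trace:
            # Prefetch the module(s) that the recorded trace says come next.
            pos = trace.index(module)
            inflight.update(trace[pos + 1 : pos + 1 + prefetch_depth])
    return inflight


def check_no_inflight(inflight):
    # Raise an error shaped like the one in the report if anything is left over.
    if inflight:
        raise RuntimeError(f"still have inflight params {sorted(inflight)}")


trace = ["fc1", "fc2", "fc3"]                   # recorded on step 1 (use_fc2=True)
print(sorted(run_step(trace, trace)))           # prints [] (every prefetch is consumed)
print(sorted(run_step(["fc1", "fc3"], trace)))  # prints ['fc2'] (prefetched, never used)
```

In this model, only a change in the executed module order leaves anything behind. Calling check_no_inflight on the step 2 result raises the RuntimeError.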

To Reproduce
Steps to reproduce the behavior:

  1. Create a Python file:
touch test_inflight.py
  2. Paste the following code into test_inflight.py:
import argparse
import deepspeed
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_scheduler
from transformers.deepspeed import HfDeepSpeedConfig


class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 1)
        self.selu = nn.SELU()
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x, use_fc2):
        x = self.fc1(x)
        x = self.selu(x)
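        # Data-dependent branch: fc2 runs only when use_fc2 is True,
        # so the parameters ZeRO-3 must gather differ between steps.
        if use_fc2:
            x = self.fc2(x)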
        x = self.selu(x)
        x = self.fc3(x)
        x = self.sigmoid(x)
        return x

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1)
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    model = MyNetwork()

    deepspeed.init_distributed()

    torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda', args.local_rank)
    args.device = device
    args.global_rank = dist.get_rank()

    dist.barrier()

    ds_config = {
        'train_batch_size': None,
        'train_micro_batch_size_per_gpu': 8,
        'gradient_accumulation_steps': 1,
        'steps_per_print': 10,
        'zero_optimization': {
            'stage': 3,
            'offload_param': {
                'device': 'none',
            },
            'offload_optimizer': {
                'device': 'none',
            },
            'param_persistence_threshold': 1e4,
            'max_live_parameters': 3e7,
            'prefetch_bucket_size': 3e7,
            'memory_efficient_linear': False,
            'gather_16bit_weights_on_model_save': True,
        },
        'gradient_clipping': 1.0,
        'prescale_gradients': False,
        'wall_clock_breakdown': False,
    }


    _dstchf = HfDeepSpeedConfig(ds_config)
    
    optimizer = AdamW(
        [{'params': list(model.parameters()), 'weight_decay': 0.0}],
        lr=1e-3,
        betas=(0.9, 0.95),
    )

    lr_scheduler = get_scheduler(
        name='cosine',
        optimizer=optimizer,
        num_warmup_steps=5,
        num_training_steps=100,
    )

    model, *_ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True,
    )
    
    # Step 1: the forward pass uses fc1, fc2 and fc3.
    inputs = torch.randn(8, 1024).to(device)
    predicts = torch.randn(8, 1).to(device)
    outputs = model(inputs, use_fc2=True)
    loss = nn.MSELoss()(outputs, predicts)
    model.backward(loss)
    model.step()
    
    # Step 2: same model, but fc2 is skipped (use_fc2=False).
    inputs = torch.randn(8, 1024).to(device)
    predicts = torch.randn(8, 1).to(device)
    outputs = model(inputs, use_fc2=False)
    loss = nn.MSELoss()(outputs, predicts)
    model.backward(loss)
    model.step()

if __name__ == '__main__':
    main()
  3. Run the following command:
deepspeed --module test_inflight
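The config sets train_batch_size to None. In that case DeepSpeed derives it as train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size. Below is a small sketch of that arithmetic, assuming the launcher starts one process per GPU on the 8-GPU machine described under System info. derived_train_batch_size is a hypothetical helper, not a DeepSpeed function.

```python
def derived_train_batch_size(micro_batch_per_gpu: int, grad_accum_steps: int, world_size: int) -> int:
    # Global batch = per-GPU micro-batch * accumulation steps * number of ranks.
    return micro_batch_per_gpu * grad_accum_steps * world_size


# Values from ds_config in the repro; world_size=8 is an assumption (one rank per H800).
print(derived_train_batch_size(8, 1, 8))  # 64
```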

Expected behavior
Both training steps should finish without error; in the second step fc2 is simply unused.

ds_report output

[2024-08-05 21:31:27,078] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Warning: The default cache directory for DeepSpeed Triton autotune, /home/yangyaodong/.triton/autotune, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path.
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
/aifs4su/yaodong/miniconda3/envs/xuyao-multi-node/lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py:47: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  @autocast_custom_fwd
/aifs4su/yaodong/miniconda3/envs/xuyao-multi-node/lib/python3.11/site-packages/deepspeed/runtime/zero/linear.py:66: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  @autocast_custom_bwd
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/aifs4su/yaodong/miniconda3/envs/xuyao-multi-node/lib/python3.11/site-packages/torch']
torch version .................... 2.4.0+cu121
deepspeed install path ........... ['/aifs4su/yaodong/miniconda3/envs/xuyao-multi-node/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.14.4, unknown, unknown
torch cuda version ............... 12.1
torch hip version ................ None
nvcc version ..................... 12.2
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 1007.78 GB

Screenshots
[Screenshot not preserved]

System info:

  • OS: Ubuntu 22.04.2 LTS
  • GPU count and types: one machine with 8x H800
  • Python version: Python 3.11.0

Launcher context
Are you launching your experiment with the deepspeed launcher, MPI, or something else?

deepspeed --module test_inflight

Docker context
Are you using a specific docker image that you can share?

No

Additional context
Add any other context about the problem here.

No


Metadata

  • Assignees: none
  • Labels: bug (Something isn't working), training
  • Type: none
  • Projects: none
  • Milestone: none
  • Relationships: none
  • Development: no branches or pull requests