tuple index out of range in _exec_send_grads p2p.send #884

@drcege

Description

Describe the bug

When both pipe-parallel-size and model-parallel-size are set to 2, training crashes. However, setting either one to 2 (and keeping the other at 1) works fine.
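For reference, my own arithmetic (an assumption, not taken from the logs): with 8 GPUs, pipe-parallel-size = 2 and model-parallel-size = 2 leave a data-parallel degree of 2.

# Topology arithmetic for the failing run (assumes all 8 GPUs participate).
gpus = 8
pipe_parallel = 2
model_parallel = 2
data_parallel = gpus // (pipe_parallel * model_parallel)
print(data_parallel)  # -> 2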

The stack trace:

...
done with setups ...                                                                                                                                                                 
time (ms) | model and optimizer: 5943.63 | train/valid/test data iterators: 1088.62                                                                                                  
training ...                                                                                                                                                                         
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:553:forward] Activation Checkpointing Information                                                                                 
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:554:forward] ----Partition Activations True, CPU CHECKPOINTING False                                                               
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:557:forward] ----contiguous Memory Checkpointing False with 24 total layers                                                        
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:560:forward] ----Synchronization True                                                                                              
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:561:forward] ----Profiling time in checkpointing False 
Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "/gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "/gpt-neox/megatron/training.py", line 778, in train
    loss_dict, skipped_iter = train_step(
  File "/gpt-neox/megatron/training.py", line 684, in train_step
    reduced_loss = train_step_pipe(
  File "/gpt-neox/megatron/training.py", line 734, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1378, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1025, in _exec_send_grads
    p2p.send(inputs[1], self.prev_stage)
IndexError: tuple index out of range
...

To Reproduce

Steps to reproduce the behavior:

  1. Use the latest Docker image leogao2/gpt-neox:sha-61b5eee
  2. Start two containers, each with four visible GPUs, to simulate two-node distributed training
  3. Set up passwordless SSH login between the containers
  4. Write /job/hostfile as follows:
container1-ip slots=4
container2-ip slots=4
  5. Execute python ./deepy.py train.py -d configs 1-3B.yml local_setup.yml after modifying data-path, vocab-file, and merge-file

Expected behavior

Should train without error.

Proposed solution

After debugging, I believe the error was triggered here:
https://github.com/EleutherAI/DeeperSpeed/blob/457850dc5ad72960f0e8a8f1597914d682a7792c/deepspeed/runtime/pipe/engine.py#L1023-L1025

It seems that the length of inputs is less than 2, so the indexing is out of range. Does this mean the gradients are not properly partitioned when model-parallel-size > 1?

I know this code lives inside DeepSpeed, but these lines were written several years ago and are used by many other projects, which suggests the error is more likely caused by GPT-NeoX passing something incorrect into DeepSpeed.
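For illustration only, here is a minimal, standalone sketch of the kind of defensive check I have in mind around the failing call. The function name, the p2p argument, and the error message are placeholders of mine, not DeepSpeed's API:

# Standalone sketch, not the actual DeepSpeed source: a defensive length check
# around the indexing that currently raises the bare IndexError.
def send_grads_guarded(inputs, prev_stage, p2p):
    if not isinstance(inputs, tuple) or len(inputs) < 2:
        got = (f"tuple of length {len(inputs)}"
               if isinstance(inputs, tuple) else type(inputs).__name__)
        raise RuntimeError(
            f"SendGrad expected a tuple of at least 2 tensors but got {got}; "
            "this suggests the previous stage did not partition the grads as expected."
        )
    # Mirrors the failing line: forward the second element to the previous stage.
    p2p.send(inputs[1], prev_stage)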

Environment (please complete the following information):

  • GPUs: 8x V100
  • Configs: 1-3B.yml and local_setup.yml with the modifications described above
