Description
Describe the bug
When using the Muon optimizer with ZeRO-1 or ZeRO-2 and `reduce_scatter=true`, parameters that span partition boundaries receive incorrect updates because their gradients are only partially reduced before orthogonalization. Muon's Newton-Schulz orthogonalization then operates on a mix of reduced and unreduced gradient values, producing incorrect parameter updates for cross-partition parameters.
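For context, the Newton-Schulz orthogonalization can be sketched as follows. This is a NumPy sketch of the quintic iteration and coefficients used in the public Muon reference implementation, not DeepSpeed's `muon_update()`, which may differ in dtype, step count, and scaling. What matters for this bug is that the map is nonlinear and depends on every entry of the gradient matrix.

```python
import numpy as np


def newton_schulz_orthogonalize(g: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Push the singular values of g toward 1, approximating U @ V^T from its SVD.

    Sketch of the quintic Newton-Schulz iteration from the Muon reference
    implementation; not DeepSpeed's muon_update().
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + eps)  # Frobenius-normalize so all singular values are <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the wide orientation so x @ x.T is the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


rng = np.random.default_rng(0)
update = newton_schulz_orthogonalize(rng.standard_normal((8, 4)))
print(np.linalg.svd(update, compute_uv=False))  # close to 1, but not exactly 1
```

Because the iteration is nonlinear, feeding it a partially reduced gradient does not just perturb the unreduced entries; it changes the whole output.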
To Reproduce
Steps to reproduce the behavior:
- Set up a training configuration with:
  - ZeRO stage 1 or 2 with `reduce_scatter: true` in the DeepSpeed config
  - Muon optimizer (or MuonWithAuxAdam) for 2D weight matrices
  - Multi-GPU training (e.g., 4 GPUs)
- Run training and observe that some parameters (those crossing partition boundaries) receive inconsistent gradient updates across ranks.
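A configuration matching these steps might look like the sketch below. The `zero_optimization` keys (`stage`, `reduce_scatter`) are standard DeepSpeed options; the `optimizer` block assumes a DeepSpeed build that exposes Muon under the type name `"Muon"`, and the batch-size and learning-rate values are placeholders, so check the names against your DeepSpeed version.

```python
import json

# Sketch of a config that exercises the reported code path.
# Assumption: the installed DeepSpeed registers Muon under optimizer type "Muon";
# batch size and optimizer params are placeholder values.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,              # stage 1 is also affected
        "reduce_scatter": True,  # the setting that triggers the bug
    },
    "optimizer": {
        "type": "Muon",                            # assumed type name
        "params": {"lr": 0.02, "momentum": 0.95},  # placeholder values
    },
}

print(json.dumps(ds_config, indent=2))
```

The dict can be passed as `config=` to `deepspeed.initialize()` in a training script (a placeholder `train.py` here) launched with, e.g., `deepspeed --num_gpus 4 train.py`.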
Expected behavior
All ranks should receive fully reduced gradients before Muon's orthogonalization step, ensuring that:
- The Newton-Schulz iteration operates on complete, correctly reduced gradients
- All ranks compute identical orthogonalized updates for the same parameter
- Cross-partition parameters are updated consistently
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
[WARNING] Please specify CUTLASS location directory as environment variable CUTLASS_PATH
[WARNING] Possible values are: a path, DS_IGNORE_CUTLASS_DETECTION and DS_USE_CUTLASS_PYTHON_BINDINGS
evoformer_attn ......... [NO] ....... [NO]
[WARNING] FP Quantizer is using an untested triton version (3.4.0), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
/opt/miniconda3/envs/zswift/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlvsym'
/opt/miniconda3/envs/zswift/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlopen'
/opt/miniconda3/envs/zswift/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlclose'
/opt/miniconda3/envs/zswift/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlerror'
/opt/miniconda3/envs/zswift/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlsym'
collect2: error: ld returned 1 exit status
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.8
[WARNING] using untested triton version (3.4.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/opt/miniconda3/envs/zswift/lib/python3.10/site-packages/torch']
torch version .................... 2.8.0+cu128
deepspeed install path ........... ['/mnt/code/zwd/project/deepspeed-proj/DeepSpeed/deepspeed']
deepspeed info ................... 0.18.5+374f6d09, 374f6d09, master
torch cuda version ............... 12.8
torch hip version ................ None
nvcc version ..................... 12.9
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 1007.44 GB
System info (please complete the following information):
- OS: Anolis OS release 8.8
- GPU count and types: 1 machine with 8x H200
- Python version: 3.10
Additional context
The issue occurs in the gradient reduction pipeline when reduce_scatter=true:
- In `average_tensor()` (stage_1_and_2.py):
  - When `reduce_scatter=false`: calls `gradient_reduction_w_predivide()`, which performs an in-place all_reduce directly on `bucket.buffer`, so every rank ends up with the complete reduced gradient.
  - When `reduce_scatter=true`: calls `allreduce_and_scatter()` → `allreduce_and_copy_with_multiple_ranks()`, which:
    - Creates a new flattened tensor via `self.flatten(bucket)` (copying the data)
    - Performs all_reduce on this new tensor
    - Copies the result back to the buffer only on the target rank, via `buf.copy_(synced)` when `dist.get_rank() == bucket_rank`
- Memory relationship (stage_1_and_2.py):
  `new_grad_tensor = bucket.buffer[bucket.index].narrow(0, bucket.elements, param.numel())`
  `grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc)`
  - `param.grad_accum.data` points directly into `bucket.buffer`
  - When only the target rank receives the reduced gradient, non-target ranks retain their local (unreduced) gradients
- For cross-partition parameters:
  - Suppose parameter B spans rank0 and rank1:
    - rank0's `grad_accum` = [reduced_B1, local_B2] (only the B1 portion is reduced)
    - rank1's `grad_accum` = [local_B1, reduced_B2] (only the B2 portion is reduced)
- Muon orthogonalization (stage_1_and_2.py):
  `if getattr(tensor, 'use_muon', False): grad_accum = muon_update(grad_accum, buffer, ...)  # Operates on ENTIRE grad_accum`
  - `muon_update()` performs Newton-Schulz orthogonalization on the entire `grad_accum`
  - With partially reduced gradients, the orthogonalization produces incorrect results; because orthogonalization mixes every entry of the matrix, even the correctly reduced portion (e.g. B1 on rank0) receives a wrong update
  - Each rank then extracts its partition portion from the incorrectly orthogonalized result
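The two `average_tensor()` code paths described above can be contrasted with a single-process simulation. Each NumPy array stands in for one rank's `bucket.buffer`; the function names are hypothetical, plain summation replaces DeepSpeed's averaging, and this is a sketch of the described behavior rather than DeepSpeed's implementation.

```python
import numpy as np


def simulate_inplace_allreduce(buffers):
    """reduce_scatter=false: all_reduce directly on each rank's bucket.buffer (averaging omitted)."""
    total = np.sum(buffers, axis=0)
    for buf in buffers:
        buf[:] = total  # every rank's buffer now holds the reduced gradient
    return buffers


def simulate_allreduce_and_copy(buffers, bucket_rank):
    """reduce_scatter=true: all_reduce a flattened copy, copy back only on bucket_rank."""
    synced = np.sum([buf.copy() for buf in buffers], axis=0)  # result lives in a new tensor
    for rank, buf in enumerate(buffers):
        if rank == bucket_rank:  # mirrors `if dist.get_rank() == bucket_rank`
            buf[:] = synced
    return buffers  # non-target ranks keep their local values


local = [np.full(4, 1.0), np.full(4, 2.0)]  # two ranks' local gradients for one bucket
after_inplace = simulate_inplace_allreduce([b.copy() for b in local])
after_copy = simulate_allreduce_and_copy([b.copy() for b in local], bucket_rank=0)
print(after_inplace)  # both ranks: 3.0 everywhere
print(after_copy)     # rank0: 3.0, rank1: still 2.0
```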
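The memory relationship between `grad_accum` and `bucket.buffer` can be illustrated with NumPy views standing in for torch's `narrow()`/`view_as()`. The shapes and offsets below are made up; the point is that `grad_accum` aliases the bucket buffer, so it only sees reductions that are written back into that buffer.

```python
import numpy as np

bucket_buffer = np.zeros(12)  # stands in for bucket.buffer[bucket.index]
param_shape = (2, 3)          # hypothetical parameter shape
offset = 4                    # hypothetical bucket.elements offset
numel = int(np.prod(param_shape))

# Analogue of narrow(0, offset, numel) followed by view_as(param): a view, not a copy.
grad_accum = bucket_buffer[offset:offset + numel].reshape(param_shape)

# An in-place update of the buffer (like the reduce_scatter=false all_reduce)
# is immediately visible through grad_accum.
bucket_buffer += 1.0
print(grad_accum)  # all ones

# Updating a flattened copy (like the reduce_scatter=true path on a non-target rank)
# leaves grad_accum untouched, because nothing is copied back.
synced = bucket_buffer.copy()
synced *= 2.0
print(grad_accum)  # still all ones
```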
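The cross-partition failure can be reproduced numerically on one process. To keep the sketch short it uses the exact orthogonalization `U @ V^T` (via SVD), which Newton-Schulz approximates, as a stand-in for `muon_update()`; the 4x4 parameter, the 2-rank split, and the helper names are all hypothetical.

```python
import numpy as np


def polar_factor(g: np.ndarray) -> np.ndarray:
    """Exact orthogonalization U @ V^T via SVD.

    Stand-in for muon_update(): Newton-Schulz approximates this same map.
    """
    u, _, vt = np.linalg.svd(g, full_matrices=False)
    return u @ vt


rng = np.random.default_rng(0)
shape, split = (4, 4), 8  # hypothetical 4x4 parameter B; rank0 owns B1 = first 8 elements
local = [rng.standard_normal(shape) for _ in range(2)]  # per-rank local gradients
full = local[0] + local[1]  # fully reduced gradient (averaging omitted)


def grad_accum_on(rank: int) -> np.ndarray:
    """Rank's view of B after the reduce_scatter=true path: only its own slice is reduced."""
    g = local[rank].ravel().copy()
    owned = slice(0, split) if rank == 0 else slice(split, g.size)
    g[owned] = full.ravel()[owned]
    return g.reshape(shape)


update_rank0 = polar_factor(grad_accum_on(0))  # built from [reduced_B1, local_B2]
update_rank1 = polar_factor(grad_accum_on(1))  # built from [local_B1, reduced_B2]
reference = polar_factor(full)                 # what every rank should compute

# Even rank0's own, correctly reduced slice B1 gets the wrong update,
# because orthogonalization mixes every entry of the matrix.
print(np.abs(update_rank0.ravel()[:split] - reference.ravel()[:split]).max())
print(np.allclose(update_rank0, update_rank1))  # the two ranks are expected to disagree
```

Orthogonalizing the fully reduced `full` on every rank, then slicing, would give all ranks the same `reference` update, which is the behavior listed under "Expected behavior".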