Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,18 @@ def _enforce_cpu_offload():

self.low_precision_master_weights_and_grads = self.master_weights_and_grads_dtype != torch.float32

# Check for Muon optimizer usage
self.uses_muon = any(getattr(param, 'use_muon', False) for group in self.optimizer.param_groups for param in group['params'])

if self.reduce_scatter and self.partition_gradients:
valid_reduce_scatter_dtypes = (torch.float16, torch.bfloat16, torch.float32)
assert self.communication_data_type in valid_reduce_scatter_dtypes, f"{self.zero_stage_string} supports {valid_reduce_scatter_dtypes} communication_data_type with reduce scatter enabled. Got: '{self.communication_data_type}'"
assert self.gradient_predivide_factor == 1.0, f"gradient_predivide_factor != 1.0 is not yet supported with {self.zero_stage_string} with reduce scatter enabled"
assert self.postscale_gradients, f"pre-scale gradients is not yet supported with {self.zero_stage_string} with reduce scatter enabled"

# Check for Muon optimizer compatibility with reduce_scatter (applies to both ZeRO-1 and ZeRO-2)
if self.reduce_scatter and self.uses_muon:
assert False, f"{self.zero_stage_string} with reduce_scatter=True is incompatible with Muon optimizer. Please disable reduce_scatter or use a different optimizer."
Comment on lines +297 to +298

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Allow Muon fallback instead of aborting when `reduce_scatter` is enabled

This new `assert False` makes ZeRO-1/2 initialization fail whenever `reduce_scatter=True` and any parameter has `use_muon=True`, which blocks the exact training configuration this change is trying to handle. The later `average_tensor` change already adds a fallback path (a full all-reduce when Muon is present), but this assertion prevents that fallback from ever running unless Python is started with `-O` (which strips `assert` statements). As written, it turns a correctness bug into a hard runtime failure for Muon users.

Useful? React with 👍 / 👎.


# param flattened by groups
self.bit16_groups = []
Expand Down Expand Up @@ -1212,7 +1219,9 @@ def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dt
stream = get_accelerator().current_stream()

with get_accelerator().stream(stream):
if not self.reduce_scatter:
# Use pre-detected Muon flag from initialization
if not self.reduce_scatter or self.uses_muon:
# Fall back to a full all-reduce for this tensor whenever any parameter uses Muon, even if reduce_scatter is enabled
self.gradient_reduction_w_predivide(tensor, communication_data_type)
return

Expand Down
Loading