Skip to content

Conversation

@vx120
Copy link
Contributor

@vx120 vx120 commented Dec 24, 2025

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

  • Add the MuonClip optimizer to support Q/K projection clipping during training.
  • Implement max_logits collection from attention forward passes to enable QK clipping.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @vx120, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the MuonClip optimizer, a new optimization strategy specifically tailored for models utilizing attention mechanisms. Its primary goal is to enhance training stability and potentially performance by dynamically clipping Query (Q) and Key (K) projection weights. This is achieved through a novel _MaxLogitsTracker that intelligently monitors maximum attention logits across different attention implementations (eager softmax, SDPA, and FlashAttention) and uses this information to apply a scaled clipping factor to QK projections during the optimization step. Additionally, the optimizer employs a Newton-Schulz method for orthogonalizing parameter updates.

Highlights

  • Introduction of MuonClip Optimizer: A new optimizer, MuonClip, has been added, designed to apply specific clipping strategies to attention Query (Q) and Key (K) projections.
  • Dynamic Max Logits Tracking: Implements _MaxLogitsTracker to automatically collect maximum attention logits from various attention implementations (eager softmax, SDPA, FlashAttention) to inform the clipping process.
  • QK Projection Clipping: Integrates a mechanism within the MuonClip optimizer to clip Query (Q) and Key (K) projection weights based on the tracked maximum logits and a configurable qk_clip_tau.
  • Orthogonalization for Updates: The MuonClip optimizer incorporates a Newton-Schulz iteration for orthogonalizing parameter updates, particularly for 2D parameters.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the MuonClip optimizer, which adds support for Q/K projection clipping during training. The implementation is well-structured, adding a new optimizer in swift/plugin/muonclip.py and integrating it via a factory function in swift/plugin/optimizer.py. The use of monkey-patching to track attention logits is a notable aspect of the implementation. My review has identified a high-severity bug in the optimizer's argument handling that could lead to runtime errors, as well as some medium-severity issues related to exception handling, code redundancy, and maintainability. The provided feedback includes specific suggestions to address these points and improve the overall quality of the code.

all_params = [p for _, p in model.named_parameters() if p.requires_grad]
param_groups = [{"params": all_params, "lr": args.learning_rate, "is_qk": False}]

allowed = {"lr", "momentum", "weight_decay", "nesterov", "newton_schulz_steps", "qk_clip_tau"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The allowed set for optimizer arguments is configured incorrectly.

  1. It includes lr and weight_decay, which are already passed as explicit keyword arguments to MuonClip. This will cause a TypeError if a user also provides these in optim_args.
  2. It's missing qk_clip_enabled, which prevents this important parameter from being configured via optim_args.

This should be corrected to allow proper configuration and prevent runtime errors.

Suggested change
allowed = {"lr", "momentum", "weight_decay", "nesterov", "newton_schulz_steps", "qk_clip_tau"}
allowed = {"momentum", "nesterov", "newton_schulz_steps", "qk_clip_tau", "qk_clip_enabled"}

return
try:
v = float(v)
except Exception:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a broad except Exception: can hide unexpected errors and make debugging more difficult. It's better to catch specific exceptions that you expect to handle, such as ValueError or TypeError in this case of float conversion.

Suggested change
except Exception:
except (ValueError, TypeError):

try:
import flash_attn.flash_attn_interface as _fai
flash_attn_func = _fai.flash_attn_func
except Exception:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a broad Exception for an import is not ideal. It's better to specifically catch ImportError to make the intent clear and avoid masking other potential issues during initialization.

Suggested change
except Exception:
except ImportError:

Comment on lines +310 to +315
def build_muon_param_groups(
model,
lr=0.02,
weight_decay=0.0,
qk_ratio=0.1
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The qk_ratio parameter is defined but not used within the build_muon_param_groups function. This unused parameter should be removed to improve code clarity.

Additionally, this entire function appears to be redundant, as a more detailed parameter grouping logic is implemented in create_muon_clip_optimizer in swift/plugin/optimizer.py. Consider removing this function if it's not used elsewhere to avoid code duplication and potential confusion.

def build_muon_param_groups(
    model,
    lr=0.02,
    weight_decay=0.0
):

continue

is_muon_candidate = (p.ndim >= 2 and embed_key not in name and lm_head_key not in name)
is_qk_name = any(x in name.lower() for x in ['wq', 'wk', 'q_proj', 'k_proj', 'query', 'key'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of strings ['wq', 'wk', 'q_proj', 'k_proj', 'query', 'key'] used to identify Q/K projection weights is hardcoded within the loop. For better readability and maintainability, it's good practice to define this list as a constant at the beginning of the function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant