-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Support muonclip optimizer #7191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Removed debug print statements for attention implementation.
Summary of ChangesHello @vx120, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces the Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces the MuonClip optimizer, which adds support for Q/K projection clipping during training. The implementation is well-structured, adding a new optimizer in swift/plugin/muonclip.py and integrating it via a factory function in swift/plugin/optimizer.py. The use of monkey-patching to track attention logits is a notable aspect of the implementation. My review has identified a high-severity bug in the optimizer's argument handling that could lead to runtime errors, as well as some medium-severity issues related to exception handling, code redundancy, and maintainability. The provided feedback includes specific suggestions to address these points and improve the overall quality of the code.
| all_params = [p for _, p in model.named_parameters() if p.requires_grad] | ||
| param_groups = [{"params": all_params, "lr": args.learning_rate, "is_qk": False}] | ||
|
|
||
| allowed = {"lr", "momentum", "weight_decay", "nesterov", "newton_schulz_steps", "qk_clip_tau"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The allowed set for optimizer arguments is configured incorrectly.
- It includes
lrandweight_decay, which are already passed as explicit keyword arguments toMuonClip. This will cause aTypeErrorif a user also provides these inoptim_args. - It's missing
qk_clip_enabled, which prevents this important parameter from being configured viaoptim_args.
This should be corrected to allow proper configuration and prevent runtime errors.
| allowed = {"lr", "momentum", "weight_decay", "nesterov", "newton_schulz_steps", "qk_clip_tau"} | |
| allowed = {"momentum", "nesterov", "newton_schulz_steps", "qk_clip_tau", "qk_clip_enabled"} |
| return | ||
| try: | ||
| v = float(v) | ||
| except Exception: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| try: | ||
| import flash_attn.flash_attn_interface as _fai | ||
| flash_attn_func = _fai.flash_attn_func | ||
| except Exception: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def build_muon_param_groups( | ||
| model, | ||
| lr=0.02, | ||
| weight_decay=0.0, | ||
| qk_ratio=0.1 | ||
| ): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The qk_ratio parameter is defined but not used within the build_muon_param_groups function. This unused parameter should be removed to improve code clarity.
Additionally, this entire function appears to be redundant, as a more detailed parameter grouping logic is implemented in create_muon_clip_optimizer in swift/plugin/optimizer.py. Consider removing this function if it's not used elsewhere to avoid code duplication and potential confusion.
def build_muon_param_groups(
model,
lr=0.02,
weight_decay=0.0
):| continue | ||
|
|
||
| is_muon_candidate = (p.ndim >= 2 and embed_key not in name and lm_head_key not in name) | ||
| is_qk_name = any(x in name.lower() for x in ['wq', 'wk', 'q_proj', 'k_proj', 'query', 'key']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR type
PR information