Conversation
Summary of Changes

Hello @zianglih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the model quantization pipeline by introducing the ability to perform partial MXFP8 quantization. Specifically, it allows users to designate a certain number of the last decoder layers to retain their original BF16 precision, while the rest of the model is converted to MXFP8. This feature provides finer-grained control over model precision, potentially balancing performance and accuracy by preserving higher precision in critical layers.
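The precision split described in the summary can be illustrated with a small standalone sketch. The function name and label strings below are illustrative only, not the PR's actual API; the split logic (MXFP8 for the leading layers, BF16 for the last N decoder layers) follows the summary.

```python
# Hypothetical illustration of the partial-quantization split: the last
# `num_layers_at_end_in_bf16` decoder layers keep BF16, the rest use MXFP8.
def layer_precisions(num_hidden_layers: int, num_layers_at_end_in_bf16: int) -> list:
    """Return a per-layer precision label for each decoder layer."""
    tail_start = num_hidden_layers - num_layers_at_end_in_bf16
    return ["bf16" if i >= tail_start else "mxfp8" for i in range(num_hidden_layers)]

# e.g. a 6-layer model keeping the last 2 layers in BF16:
# layer_precisions(6, 2) -> ['mxfp8', 'mxfp8', 'mxfp8', 'mxfp8', 'bf16', 'bf16']
```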
Code Review
This pull request introduces a feature to skip quantization for the last N layers of a model, keeping them in BF16 format. The changes span model conversion scripts and training execution scripts to support this new functionality, primarily for MXFP8 quantization. My review identifies a logical contradiction in the feature's activation condition, a magic number that hurts maintainability, and some redundant and non-robust file handling. Addressing these points will improve the code's correctness and quality.
```python
if (args.train_fp8 or args.train_mxfp8) and args.num_layers_at_end_in_bf16 > 0:
    misc_args += (
        "--first-last-layers-bf16 "
        "--num-layers-at-start-in-bf16 0 "
        f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16} "
    )
```
There's a logical contradiction regarding when this feature is enabled. The `__post_init__` check on line 40 asserts that `num_layers_at_end_in_bf16` is only supported when `rollout_mxfp8` is enabled. However, this block enables the feature for both `train_fp8` and `train_mxfp8`. If `train_fp8` is used (which usually implies `rollout_fp8` and not `rollout_mxfp8`), the assertion on line 40 will fail, making the feature unusable with `train_fp8`.

Given that the other changes in this PR are specific to `mxfp8`, this feature appears to be intended only for `mxfp8`. If so, the condition should be narrowed to resolve the contradiction.
Suggested change:

```diff
-if (args.train_fp8 or args.train_mxfp8) and args.num_layers_at_end_in_bf16 > 0:
+if args.train_mxfp8 and args.num_layers_at_end_in_bf16 > 0:
     misc_args += (
         "--first-last-layers-bf16 "
         "--num-layers-at-start-in-bf16 0 "
         f"--num-layers-at-end-in-bf16 {args.num_layers_at_end_in_bf16} "
     )
```
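To make the contradiction concrete, here is a minimal, hypothetical reconstruction of the interaction. The field names come from the review comment; the class name and exact structure are assumptions, not the PR's real dataclass.

```python
from dataclasses import dataclass


@dataclass
class Args:
    """Sketch of the argument dataclass discussed in the review (names assumed)."""
    train_fp8: bool = False
    train_mxfp8: bool = False
    rollout_mxfp8: bool = False
    num_layers_at_end_in_bf16: int = 0

    def __post_init__(self):
        # The validation the review refers to: the BF16-tail feature is only
        # allowed together with rollout_mxfp8.
        if self.num_layers_at_end_in_bf16 > 0:
            assert self.rollout_mxfp8, (
                "num_layers_at_end_in_bf16 is only supported with rollout_mxfp8"
            )
```

With this check in place, a `train_fp8` run (which would not set `rollout_mxfp8`) can never construct `Args` with `num_layers_at_end_in_bf16 > 0`, so the `train_fp8` branch of the quoted condition is unreachable.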
```python
num_maybe_mtp_layers = 1
dynamic_skip_layer_prefixes: set[str] = {
    f"model.layers.{i}." for i in range(tail_start_idx, num_hidden_layers + num_maybe_mtp_layers)
}
```
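For concreteness, here is a standalone sketch of how these prefixes expand. The numeric values are assumed example inputs (a 61-layer model keeping the last 2 decoder layers in BF16); only the comprehension itself comes from the snippet above.

```python
# Assumed example values, not taken from the PR:
num_hidden_layers = 61
num_layers_at_end_in_bf16 = 2
num_maybe_mtp_layers = 1  # the magic number the review flags below

tail_start_idx = num_hidden_layers - num_layers_at_end_in_bf16  # 59
dynamic_skip_layer_prefixes = {
    f"model.layers.{i}." for i in range(tail_start_idx, num_hidden_layers + num_maybe_mtp_layers)
}
# Covers layers 59 and 60 (the last decoders) plus layer 61 (the possible
# trailing MTP layer beyond num_hidden_layers).
```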
The use of the magic number `1` for `num_maybe_mtp_layers` makes the code less readable and harder to maintain. It's not immediately clear why this value is 1 and whether it's model-specific.

To improve clarity and maintainability, please define this as a named constant with a comment explaining its purpose. For example:
```python
# Number of Multi-Token Prediction (MTP) layers to account for, which might not
# be included in `num_hidden_layers`. This can be model-specific.
NUM_MAYBE_MTP_LAYERS = 1
# ...
# ... range(tail_start_idx, num_hidden_layers + NUM_MAYBE_MTP_LAYERS)
```

```python
config_path = os.path.join(input_path, "config.json")
with open(config_path) as f:
    cfg = json.load(f)
num_hidden_layers = int(cfg["num_hidden_layers"])
```
This block reads `config.json`, but the same file is read again at line 198, which is redundant. There are also inconsistencies in file handling throughout the function:

- The `open` call here at line 155 is not protected by `os.path.exists`, which will cause a crash if the file is missing. The later read at line 198 is protected.
- File operations at lines 198 and 200 do not use a `with` statement, which is not best practice.

Please consider refactoring to load the config once at the start of the function, and use `with` statements for all file I/O to ensure robustness and proper resource management.
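A minimal sketch of the suggested refactor. The helper name and the load-once shape are assumptions; the point is a single `with`-managed read whose result is reused instead of re-opening the file.

```python
import json
import os


def load_model_config(input_path):
    """Read config.json exactly once, using a context manager for the handle.

    Hypothetical helper illustrating the review suggestion, not the PR's code.
    """
    config_path = os.path.join(input_path, "config.json")
    # Let a missing file raise FileNotFoundError (which includes the path),
    # rather than guarding some reads with os.path.exists and not others.
    with open(config_path) as f:
        return json.load(f)


# Callers then reuse the parsed dict instead of re-reading the file later:
# cfg = load_model_config(input_path)
# num_hidden_layers = int(cfg["num_hidden_layers"])
```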
Force-pushed from d5236a8 to 56515e8.
Motivation
As mentioned in https://arxiv.org/abs/2509.25149:

> … is helpful for training convergence.
Dependency
sgl-project/sglang#18742 for serving mixed-precision checkpoints