[merger] fix: support vision_model keys in Megatron merger for VL models #4701
base: main
Conversation
Add vision_model to the skip_checking_keys list to allow merging Vision-Language model checkpoints. Previously, keys like vision_model.patch_embed.proj.weight would fail validation because they don't start with 'decoder'.

Fixes volcengine#4498

🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
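For context, a minimal sketch of what the one-line change likely looks like in both merger files. Only the added "vision_model" entry is confirmed by this PR; the surrounding list entries are assumptions for illustration:

```python
# Sketch only: the PR confirms just the addition of "vision_model" to the
# skip list consulted by _check_megatron_state_key; other entries are assumed.
skip_checking_keys = [
    "embedding.word_embeddings",  # input embeddings (assumed entry)
    "output_layer",               # LM head (assumed entry)
    "vision_model",               # added by this PR for VL checkpoints
]
```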
Code Review
This pull request adds support for vision_model keys in the Megatron model merger, which is necessary for handling Vision-Language models. The changes are applied consistently to both legacy_model_merger.py and megatron_model_merger.py. The core logic of adding vision_model to the skip_checking_keys list is correct.
My main feedback is on improving the clarity of the docstrings and error messages within the _check_megatron_state_key function in both files. The current descriptions are a bit misleading about the key validation logic, and I've provided suggestions to make them more accurate and maintainable.
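To make that point concrete, here is a minimal sketch of the validation flow the review comments below describe (skip-list membership is a substring check; everything else must start with 'decoder'). This is an illustration assuming that description, not the verbatim verl implementation:

```python
def _check_megatron_state_key(key: str) -> None:
    """Sketch of the validation described in this review; not verbatim verl code."""
    if key.startswith("model."):
        raise ValueError(
            f"Invalid key {key} in Megatron state_dict. Keys starting with 'model.' are not allowed."
        )

    # Keys containing any of these substrings skip the 'decoder' prefix check.
    # "vision_model" is the entry this PR adds; the other entries are assumed.
    skip_checking_keys = ["embedding.word_embeddings", "output_layer", "vision_model"]
    if any(skip_key in key for skip_key in skip_checking_keys):
        return

    if not key.startswith("decoder"):
        raise ValueError(
            f"Invalid key {key} in Megatron state_dict. "
            f"Expected keys for 'decoder', 'embedding', 'output_layer', or 'vision_model'."
        )
```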
```diff
     Now the model merger only supports keys that start with "decoder/embedding/output_layer/vision_model" in TransformerLayer.
     Shall not use key starts with "model."
     """
     if key.startswith("model."):
         raise ValueError(
-            f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer' in TransformerLayer."
+            f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer/vision_model' in TransformerLayer."
         )
```
The updated docstring and error message are a bit misleading. They state that keys are expected to start with a specific set of prefixes, but the logic is more nuanced. The code actually checks if keys contain certain substrings (like vision_model) to skip them, or if they start with decoder. This discrepancy can be confusing for future maintenance.
To improve clarity and prevent potential bugs, I suggest updating the docstring and error message to more accurately reflect the validation logic.
Suggested change:

```diff
-    Now the model merger only supports keys that start with "decoder/embedding/output_layer/vision_model" in TransformerLayer.
-    Shall not use key starts with "model."
+    Now the model merger supports keys for decoder, embedding, output_layer, and vision_model components.
+    Keys starting with "model." are disallowed.
     """
     if key.startswith("model."):
         raise ValueError(
-            f"Invalid key {key} in Megatron state_dict. Expected keys to start with 'decoder/embedding/output_layer/vision_model' in TransformerLayer."
+            f"Invalid key {key} in Megatron state_dict. Keys starting with 'model.' are not allowed. "
+            f"Expected keys for 'decoder', 'embedding', 'output_layer', or 'vision_model'."
         )
```
```diff
     Now the model merger only supports keys that start with "decoder/embedding/output_layer/vision_model" in TransformerLayer.
     Shall not use key starts with "model."
     """
     if key.startswith("model."):
         raise ValueError(
             f"Invalid key {key} in Megatron state_dict. Expected keys to start with "
-            f"'decoder/embedding/output_layer' in TransformerLayer."
+            f"'decoder/embedding/output_layer/vision_model' in TransformerLayer."
         )
```
Same feedback as the comment above: the docstring and error message should reflect the actual validation logic (substring checks for the skip keys, a prefix check for decoder) rather than a strict prefix expectation.
Suggested change:

```diff
     Now the model merger only supports keys that start with "decoder/embedding/output_layer/vision_model" in TransformerLayer.
     Shall not use key starts with "model."
     """
     if key.startswith("model."):
         raise ValueError(
-            f"Invalid key {key} in Megatron state_dict. Expected keys to start with "
-            f"'decoder/embedding/output_layer/vision_model' in TransformerLayer."
+            f"Invalid key {key} in Megatron state_dict. Keys starting with 'model.' are not allowed. "
+            f"Expected keys for 'decoder', 'embedding', 'output_layer', or 'vision_model'."
         )
```
Summary
- Add vision_model to the skip_checking_keys list in the _check_megatron_state_key() method
- Applies to both megatron_model_merger.py and legacy_model_merger.py

This allows merging Vision-Language model checkpoints (like Qwen2.5-VL-7B) back to HuggingFace format. Previously, keys like vision_model.patch_embed.proj.weight would fail validation because the merger expected all keys to start with 'decoder'.

Fixes #4498
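As an illustration (not taken from the PR), here are the kinds of keys a Qwen2.5-VL-style Megatron checkpoint contains, run through the sketched check above. Only vision_model.patch_embed.proj.weight is confirmed by this PR; the other keys are representative examples:

```python
# Illustrative keys only; with "vision_model" in the skip list, all of them pass.
sample_keys = [
    "decoder.layers.0.self_attention.linear_qkv.weight",
    "embedding.word_embeddings.weight",
    "output_layer.weight",
    "vision_model.patch_embed.proj.weight",  # rejected before this PR
]
for key in sample_keys:
    _check_megatron_state_key(key)  # raises ValueError on invalid keys
```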
Test plan
🤖 Generated with Claude Code