feat(kimi_k25): implement multimodal get_inputs_embeds for Megatron training#106
zhihanliu-collab wants to merge 2 commits into
…raining

KimiK25Vit previously raised NotImplementedError for any batch containing pixel_values, restricting Kimi-K2.5/K2.6 to text-only Megatron training. This implements the vision path, mirroring KimiK25ForConditionalGeneration.forward:

- vision_tower (MoonViT3d) produces un-projected features; mm_projector (PatchMergerMLP) maps them to the language hidden size
- image features are masked-scattered into inputs_embeds at media_placeholder token positions
- plain-text batches run the vision graph on a dummy image (scaled by 0) so DP ranks stay in sync during gradient all-reduce

The language model (text_config.model_type='kimi_k2', MLA + fine-grained MoE) is already handled by the config parser and GPTBridge weight conversion, so no language-side changes are needed. CP/SP slicing of inputs_embeds is handled by the existing MultimodalGPTModel embedding patch (no deepstack features for K25).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
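The masked-scatter step named in the commit message can be illustrated with a minimal sketch. The token id, tensor sizes, and variable names below are assumptions chosen for illustration and are not taken from the PR's code; only `torch.Tensor.masked_scatter` is a real PyTorch call.

```python
import torch

# Hypothetical placeholder id and hidden size, for illustration only.
MEDIA_PLACEHOLDER_TOKEN_ID = 9
hidden_size = 4

# One sequence of five tokens; positions 1 and 2 are image placeholders.
input_ids = torch.tensor([[1, 9, 9, 2, 3]])
# Stand-in text embeddings (all zeros so the scattered values are easy to see).
inputs_embeds = torch.zeros(1, 5, hidden_size)
# Stand-in projected image features: one row per placeholder, values 1..8.
image_features = torch.arange(1, 2 * hidden_size + 1, dtype=torch.float32).reshape(2, hidden_size)

# Boolean mask over token positions, expanded across the hidden dimension.
image_mask = (input_ids == MEDIA_PLACEHOLDER_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)

# The number of masked elements must equal the number of feature elements.
assert int(image_mask.sum()) == image_features.numel()

# Copy feature elements, in order, into the masked positions.
merged = inputs_embeds.masked_scatter(image_mask, image_features.to(inputs_embeds.dtype))
```

In this toy case the two placeholder rows of `merged` receive the two feature rows in order, and the text positions keep their original embeddings.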
Code Review
This pull request enables multimodal input support for the Kimi-K2.5 model by implementing image encoding and updating the embedding retrieval logic to handle image features. For plain-text batches, a dummy forward pass is executed to keep DP ranks synchronized. Feedback suggests using torch.nan_to_num to prevent potential NaN propagation from the dummy pass and implementing defensive attribute access for the vision configuration.
```python
vision_config = self.hf_config.vision_config
patch_size = vision_config.patch_size
merge_kernel_size = vision_config.merge_kernel_size
kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) else merge_kernel_size
h = w = kernel * 2
dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
dummy_grid = input_ids.new_tensor([[1, h, w]])
image_features = self._encode_images(dummy_pixels, dummy_grid)
inputs_embeds = inputs_embeds + image_features.mean().to(
    device=inputs_embeds.device, dtype=inputs_embeds.dtype) * 0.
```
Numerical Instability and Robustness Safeguards
- NaN Propagation Prevention (High Severity): Running a dummy forward pass on an all-zero tensor (`torch.zeros`) can sometimes lead to numerical instabilities (e.g., division by zero or zero variance in normalization layers) in the vision tower, resulting in `NaN` values in `image_features`. If `image_features` contains `NaN`, `image_features.mean() * 0.` will evaluate to `NaN` (since `NaN * 0` is `NaN` in IEEE 754). This will propagate `NaN` to `inputs_embeds`, corrupting the entire plain-text batch and causing training divergence. Using `torch.nan_to_num` on the zero-multiplier term ensures that any numerical instability in the dummy pass is safely zeroed out and does not affect text-only training.
- Defensive Attribute Access (Medium Severity): Accessing `vision_config.patch_size` and `vision_config.merge_kernel_size` directly can raise `AttributeError` or `TypeError` if they are missing or `None`. Using `getattr` with safe fallbacks is more robust. Additionally, checking `len(merge_kernel_size) > 0` prevents `IndexError` if `merge_kernel_size` is an empty list or tuple.
```diff
-vision_config = self.hf_config.vision_config
-patch_size = vision_config.patch_size
-merge_kernel_size = vision_config.merge_kernel_size
-kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) else merge_kernel_size
-h = w = kernel * 2
-dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
-dummy_grid = input_ids.new_tensor([[1, h, w]])
-image_features = self._encode_images(dummy_pixels, dummy_grid)
-inputs_embeds = inputs_embeds + image_features.mean().to(
-    device=inputs_embeds.device, dtype=inputs_embeds.dtype) * 0.
+vision_config = self.hf_config.vision_config
+patch_size = getattr(vision_config, "patch_size", 14)
+merge_kernel_size = getattr(vision_config, "merge_kernel_size", 2) or 2
+kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) and len(merge_kernel_size) > 0 else merge_kernel_size
+h = w = kernel * 2
+dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
+dummy_grid = input_ids.new_tensor([[1, h, w]])
+image_features = self._encode_images(dummy_pixels, dummy_grid)
+zero_loss_term = torch.nan_to_num(image_features.mean() * 0.)
+inputs_embeds = inputs_embeds + zero_loss_term.to(
+    device=inputs_embeds.device, dtype=inputs_embeds.dtype)
```
torch.nan_to_num on the zero-scaled dummy term prevents a non-finite value from the all-zero dummy forward (e.g. zero-variance normalization) leaking into text-only batches, since NaN * 0 == NaN in IEEE-754. Addresses gemini-code-assist review feedback. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
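The arithmetic behind this change can be checked in isolation. The following is a minimal sketch that uses a hand-made NaN scalar as a stand-in for `image_features.mean()` after an unstable dummy pass; the variable names are illustrative and no part of the real vision tower is involved.

```python
import torch

# Stand-in for image_features.mean() when the dummy forward produced NaN.
nan_mean = torch.tensor(float("nan"))

# Multiplying by zero does not clear a NaN under IEEE 754.
naive_term = nan_mean * 0.

# torch.nan_to_num replaces NaN with 0.0 by default.
safe_term = torch.nan_to_num(nan_mean * 0.)

# Stand-in text embeddings for a plain-text batch.
text_embeds = torch.ones(2, 3)
corrupted = text_embeds + naive_term   # every element becomes NaN
preserved = text_embeds + safe_term    # values are unchanged
```

This sketch only covers forward values; it makes no claim about how gradients behave inside a vision tower that actually produces NaN activations.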
Thanks for the review.

Adopted (high severity, NaN propagation): Wrapped the zero-scaled dummy term in `torch.nan_to_num`.

Declined (medium severity, getattr fallbacks): Kept direct access to …

Verified on a CPU smoke test that instantiates the real …
What

`KimiK25Vit.get_inputs_embeds` previously raised `NotImplementedError` for any batch containing `pixel_values`, restricting Kimi-K2.5 / Kimi-K2.6 to text-only Megatron training. This PR implements the vision path so the 1T MoE can be trained with image data under Megatron (EP/PP/TP/CP).

How

Mirrors `KimiK25ForConditionalGeneration.forward` from the HF modeling code:

- `vision_tower` (MoonViT3d) produces un-projected feature tensors (list), and `mm_projector` (PatchMergerMLP) maps them to the language hidden size. `_extract_image_features` alone does not apply the projector, so both must be called.
- Image features are `masked_scatter`-ed into `inputs_embeds` at `media_placeholder_token_id` positions (canonical pattern, same as `internvl` / `qwen3_vl`).
- Plain-text batches run the vision graph on a dummy image scaled by `0.`, so DP ranks stay in sync during gradient all-reduce.

Notes

- The language model (`text_config.model_type='kimi_k2'`, MLA + fine-grained MoE) is already handled by the config parser and `GPTBridge` weight conversion, so no language-side changes are needed; text-only training already worked.
- CP/SP slicing of `inputs_embeds` is handled by the existing `MultimodalGPTModel` embedding patch. K25 has no deepstack features, so `get_inputs_embeds` returns a plain tensor (no `visual_pos_masks` / `deepstack_visual_embeds`).
- … `KimiK25Template` image encoding + collation (`pixel_values` / `grid_thws`) and the non-first-PP-stage MM tensor stripping.

Testing

- `tests/test_mllm.py::test_kimi_k25` added (HF↔mcore convert-precision roundtrip; left commented in `__main__` like other large models).
- … `<|media_pad|>` token count).

🤖 Generated with Claude Code
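The zero-scaled dummy pass for plain-text batches can be sketched on a toy model. Here `vision_stub` is a hypothetical `nn.Linear` standing in for the vision tower plus projector, and all shapes are made up; the sketch only shows that the vision parameters receive (all-zero) gradients while the text embeddings keep their values, which is the property the PR relies on for keeping DP ranks aligned.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for vision_tower + mm_projector.
vision_stub = nn.Linear(6, 4)
# Stand-in text embeddings for a plain-text batch.
text_embeds = torch.randn(2, 5, 4, requires_grad=True)

# Dummy all-zero "image" run through the vision stub.
dummy_pixels = torch.zeros(3, 6)
image_features = vision_stub(dummy_pixels)

# Scaling by 0 leaves the embedding values unchanged but keeps
# vision_stub connected to the autograd graph.
inputs_embeds = text_embeds + image_features.mean() * 0.

# Any scalar loss works for the illustration.
inputs_embeds.sum().backward()

# Every vision parameter now has a gradient tensor, and it is all zeros.
grads = [p.grad for p in vision_stub.parameters()]
```

Without the dummy term, `p.grad` would stay `None` for the vision parameters on a text-only rank, while ranks that saw images would hold real gradients for them.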