feat(kimi_k25): implement multimodal get_inputs_embeds for Megatron training #106

Draft
zhihanliu-collab wants to merge 2 commits into modelscope:main from zhihanliu-collab:feat/kimi-k25-multimodal

@zhihanliu-collab

What

KimiK25Vit.get_inputs_embeds previously raised NotImplementedError for any batch containing pixel_values, restricting Kimi-K2.5 / Kimi-K2.6 to text-only Megatron training. This PR implements the vision path so the 1T MoE can be trained with image data under Megatron (EP/PP/TP/CP).

How

Mirrors KimiK25ForConditionalGeneration.forward from the HF modeling code:

  1. vision_tower (MoonViT3d) produces un-projected feature tensors (list), and mm_projector (PatchMergerMLP) maps them to the language hidden size — _extract_image_features alone does not apply the projector, so both must be called.
  2. Image features are masked_scatter-ed into inputs_embeds at media_placeholder_token_id positions (canonical pattern, same as internvl / qwen3_vl).
  3. Plain-text batches still run the vision graph on a dummy image scaled by 0., so DP ranks stay in sync during gradient all-reduce.
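The scatter in step 2 can be illustrated without torch. The following is a minimal pure-Python sketch of the row-wise effect of masked_scatter, not the PR's implementation: the function name `scatter_image_features`, the placeholder id `9`, and the strict count check are all assumptions made for illustration (the real id comes from media_placeholder_token_id in the model config).

```python
from typing import List

# Hypothetical id for illustration; the real value is read from the model config.
MEDIA_PLACEHOLDER_ID = 9


def scatter_image_features(
    input_ids: List[int],
    inputs_embeds: List[List[float]],
    image_features: List[List[float]],
    placeholder_id: int = MEDIA_PLACEHOLDER_ID,
) -> List[List[float]]:
    """Replace the embedding row at each placeholder position with the next
    image-feature row, in order (the row-wise effect of masked_scatter)."""
    positions = [i for i, tok in enumerate(input_ids) if tok == placeholder_id]
    # Strict check added in this sketch: one feature row per placeholder token.
    if len(positions) != len(image_features):
        raise ValueError(
            f"{len(positions)} placeholder tokens but "
            f"{len(image_features)} image feature rows"
        )
    merged = [row[:] for row in inputs_embeds]  # copy; the input is not mutated
    for pos, feat in zip(positions, image_features):
        merged[pos] = feat[:]
    return merged


input_ids = [1, 9, 9, 2]
text_embeds = [[0.1, 0.1], [0.0, 0.0], [0.0, 0.0], [0.2, 0.2]]
image_feats = [[5.0, 5.0], [6.0, 6.0]]
merged = scatter_image_features(input_ids, text_embeds, image_feats)
print(merged)  # [[0.1, 0.1], [5.0, 5.0], [6.0, 6.0], [0.2, 0.2]]
```

Text rows keep their embeddings and placeholder rows take the image features in order, which is why the number of feature rows has to equal the number of expanded placeholder tokens.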

Notes

  • The language model (text_config.model_type='kimi_k2', MLA + fine-grained MoE) is already handled by the config parser and GPTBridge weight conversion — no language-side changes needed; text-only training already worked.
  • CP/SP slicing of inputs_embeds is handled by the existing MultimodalGPTModel embedding patch. K25 has no deepstack features, so get_inputs_embeds returns a plain tensor (no visual_pos_masks / deepstack_visual_embeds).
  • Companion ms-swift PR adds the KimiK25Template image encoding + collation (pixel_values/grid_thws) and the non-first-PP-stage MM tensor stripping.

Testing

  • tests/test_mllm.py::test_kimi_k25 added (HF↔mcore convert-precision roundtrip; left commented out in __main__, like the other large models).
  • Vision forward verified against the HF reference implementation shapes (vision_tower → list, projector → hidden_size, scatter count == expanded <|media_pad|> token count).

🤖 Generated with Claude Code

…raining

KimiK25Vit previously raised NotImplementedError for any batch containing
pixel_values, restricting Kimi-K2.5/K2.6 to text-only Megatron training.

This implements the vision path, mirroring KimiK25ForConditionalGeneration.forward:
- vision_tower (MoonViT3d) produces un-projected features; mm_projector
  (PatchMergerMLP) maps them to the language hidden size
- image features are masked-scattered into inputs_embeds at media_placeholder
  token positions
- plain-text batches run the vision graph on a dummy image (scaled by 0) so
  DP ranks stay in sync during gradient all-reduce

The language model (text_config.model_type='kimi_k2', MLA + fine-grained MoE)
is already handled by the config parser and GPTBridge weight conversion, so no
language-side changes are needed. CP/SP slicing of inputs_embeds is handled by
the existing MultimodalGPTModel embedding patch (no deepstack features for K25).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request enables multimodal input support for the Kimi-K2.5 model by implementing image encoding and updating the embedding retrieval logic to handle image features. For plain-text batches, a dummy forward pass is executed to keep DP ranks synchronized. Feedback suggests using torch.nan_to_num to prevent potential NaN propagation from the dummy pass and implementing defensive attribute access for the vision configuration.

Comment on lines +113 to +122
vision_config = self.hf_config.vision_config
patch_size = vision_config.patch_size
merge_kernel_size = vision_config.merge_kernel_size
kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) else merge_kernel_size
h = w = kernel * 2
dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
dummy_grid = input_ids.new_tensor([[1, h, w]])
image_features = self._encode_images(dummy_pixels, dummy_grid)
inputs_embeds = inputs_embeds + image_features.mean().to(
device=inputs_embeds.device, dtype=inputs_embeds.dtype) * 0.

high

Numerical Instability and Robustness Safeguards

  1. NaN Propagation Prevention (High Severity): Running a dummy forward pass on an all-zero tensor (torch.zeros) can sometimes lead to numerical instabilities (e.g., division by zero or zero-variance in normalization layers) in the vision tower, resulting in NaN values in image_features. If image_features contains NaN, image_features.mean() * 0. will evaluate to NaN (since NaN * 0 is NaN in IEEE 754). This will propagate NaN to inputs_embeds, corrupting the entire plain-text batch and causing training divergence. Using torch.nan_to_num on the zero-multiplier term ensures that any numerical instability in the dummy pass is safely zeroed out and does not affect text-only training.

  2. Defensive Attribute Access (Medium Severity): Accessing vision_config.patch_size and vision_config.merge_kernel_size directly can raise AttributeError or TypeError if they are missing or None. Using getattr with safe fallbacks is more robust. Additionally, checking len(merge_kernel_size) > 0 prevents IndexError if merge_kernel_size is an empty list or tuple.
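The arithmetic behind point 1 can be checked with plain Python floats. This is a minimal sketch, assuming a scalar stand-in for torch.nan_to_num; the helper `nan_to_num_scalar` is hypothetical and only mimics the default behaviour (NaN becomes 0, infinities are clamped to the largest finite value).

```python
import math
import sys


def nan_to_num_scalar(x: float, nan: float = 0.0) -> float:
    """Scalar stand-in for torch.nan_to_num with default arguments."""
    if math.isnan(x):
        return nan
    if math.isinf(x):
        # Clamp to the largest finite float, keeping the sign.
        return math.copysign(sys.float_info.max, x)
    return x


bad_mean = float("nan")  # pretend the dummy vision pass produced NaN
text_embed = 0.25        # one element of a text-only embedding

unguarded = bad_mean * 0.0                   # NaN * 0 is still NaN
guarded = nan_to_num_scalar(bad_mean * 0.0)  # NaN is mapped to 0.0

print(math.isnan(text_embed + unguarded))  # True
print(text_embed + guarded)                # 0.25
```

An infinite mean behaves the same way, since inf * 0 is also NaN, so the guard has to wrap the product rather than the mean alone.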

Suggested change
- vision_config = self.hf_config.vision_config
- patch_size = vision_config.patch_size
- merge_kernel_size = vision_config.merge_kernel_size
- kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) else merge_kernel_size
- h = w = kernel * 2
- dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
- dummy_grid = input_ids.new_tensor([[1, h, w]])
- image_features = self._encode_images(dummy_pixels, dummy_grid)
- inputs_embeds = inputs_embeds + image_features.mean().to(
- device=inputs_embeds.device, dtype=inputs_embeds.dtype) * 0.
+ vision_config = self.hf_config.vision_config
+ patch_size = getattr(vision_config, "patch_size", 14)
+ merge_kernel_size = getattr(vision_config, "merge_kernel_size", 2) or 2
+ kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) and len(merge_kernel_size) > 0 else merge_kernel_size
+ h = w = kernel * 2
+ dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
+ dummy_grid = input_ids.new_tensor([[1, h, w]])
+ image_features = self._encode_images(dummy_pixels, dummy_grid)
+ zero_loss_term = torch.nan_to_num(image_features.mean() * 0.)
+ inputs_embeds = inputs_embeds + zero_loss_term.to(
+ device=inputs_embeds.device, dtype=inputs_embeds.dtype)

@zhihanliu-collab zhihanliu-collab marked this pull request as draft June 3, 2026 01:49
torch.nan_to_num on the zero-scaled dummy term prevents a non-finite value
from the all-zero dummy forward (e.g. zero-variance normalization) leaking
into text-only batches, since NaN * 0 == NaN in IEEE-754. Addresses
gemini-code-assist review feedback.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@zhihanliu-collab
Author

Thanks for the review.

Adopted (high severity, NaN propagation): Wrapped the zero-scaled dummy term in torch.nan_to_num so that a non-finite value from the all-zero dummy forward can't leak into text-only batches (NaN * 0 evaluates to NaN, not 0). Applied the same guard to the companion ms-swift _post_encode dummy path for parity.

Declined (medium severity — getattr fallbacks): Kept direct access to vision_config.patch_size / merge_kernel_size. These are required fields of the Kimi-K2.5/K2.6 vision config (confirmed in the model's config.json: patch_size=14, merge_kernel_size=[2,2]), and direct access matches the existing qwen3_vl / internvl vits in this repo. A getattr(..., 14) / getattr(..., 2) magic-number fallback would silently mask a genuinely malformed config instead of failing loudly, which is the worse failure mode here. The len(...) > 0 guard is also unnecessary since the field is always a 2-tuple.
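The difference between the two failure modes can be shown with a stand-in config object. This is a minimal sketch, not repository code: the function names `kernel_strict` and `kernel_with_fallback` are hypothetical, and `SimpleNamespace` stands in for the real vision config.

```python
from types import SimpleNamespace


def kernel_strict(vision_config):
    """Direct access: a missing field raises AttributeError immediately."""
    merge_kernel_size = vision_config.merge_kernel_size
    if isinstance(merge_kernel_size, (list, tuple)):
        return merge_kernel_size[0]
    return merge_kernel_size


def kernel_with_fallback(vision_config):
    """getattr fallback: a missing field is silently replaced by 2."""
    merge_kernel_size = getattr(vision_config, "merge_kernel_size", 2) or 2
    if isinstance(merge_kernel_size, (list, tuple)) and len(merge_kernel_size) > 0:
        return merge_kernel_size[0]
    return merge_kernel_size


good = SimpleNamespace(patch_size=14, merge_kernel_size=[2, 2])
broken = SimpleNamespace(patch_size=14)  # merge_kernel_size missing

print(kernel_strict(good), kernel_with_fallback(good))  # 2 2
print(kernel_with_fallback(broken))                     # 2 (config error hidden)
try:
    kernel_strict(broken)
except AttributeError as exc:
    print("strict access failed loudly:", exc)
```

With a well-formed config both variants agree; they differ only when the config is malformed, where the fallback keeps training on a guessed value.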

Verified on a CPU smoke test that instantiates the real vision_tower (MoonViT3d) + mm_projector (PatchMergerMLP) from the K2.6 vision config and runs a real 2-image batch end-to-end: token count matches the template's _encode expansion, projected features are (N, 7168), and the dummy path leaves text embeds unchanged.
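For reference, the dummy-input shapes implied by the quoted config values (patch_size=14, merge_kernel_size=[2,2]) can be worked out by hand. The sketch below repeats the shape arithmetic from the reviewed snippet in plain Python; the function name `dummy_image_shapes` is hypothetical and no tensors are allocated.

```python
def dummy_image_shapes(patch_size, merge_kernel_size):
    """Return (pixel tensor shape, grid list) for the dummy image."""
    if isinstance(merge_kernel_size, (list, tuple)):
        kernel = merge_kernel_size[0]
    else:
        kernel = merge_kernel_size
    h = w = kernel * 2  # two merge windows per side
    pixel_shape = (h * w, 3, patch_size, patch_size)  # one row per patch
    grid_thw = [[1, h, w]]  # single frame, h x w patches
    return pixel_shape, grid_thw


print(dummy_image_shapes(14, [2, 2]))  # ((16, 3, 14, 14), [[1, 4, 4]])
```

So the dummy pass runs the vision tower on 16 all-zero patches of 14 x 14 pixels, which keeps the extra cost on text-only batches small.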
