Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions src/mcore_bridge/model/mm_gpts/kimi_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class KimiK25Vit(HuggingFaceVit):
module_mapping = {'vision_tower': 'vision_tower', 'mm_projector': 'mm_projector'}
_vision_tower = ['vision_tower']
_aligner = ['mm_projector']
test_mm_type = 'text'

def prepare_model(self, hf_config: PretrainedConfig):
output = []
Expand All @@ -85,10 +84,44 @@ def prepare_model(self, hf_config: PretrainedConfig):
self.vision_tower = MoonViT3dPretrainedModel._from_config(vit_config)
self.mm_projector = PatchMergerMLP(proj_config).to(self.vision_tower.dtype)

def _encode_images(self, pixel_values, grid_thws):
# vision_tower returns a list of un-projected feature tensors; mm_projector
# (PatchMergerMLP) maps them to the language hidden size. Mirrors
# KimiK25ForConditionalGeneration.forward.
image_features = self.vision_tower(pixel_values, grid_thws)
image_features = self.mm_projector(image_features)
return torch.cat(image_features, dim=0)

def get_inputs_embeds(self, inputs_embeds, **kwargs):
pixel_values = kwargs.pop('pixel_values', None)
if pixel_values is not None:
raise NotImplementedError('Kimi-K25 currently only supports plain text training.')
input_ids = kwargs['input_ids']
pixel_values = kwargs.get('pixel_values')
grid_thws = kwargs.get('grid_thws')
dtype = next(self.vision_tower.parameters()).dtype
if pixel_values is not None and pixel_values.size(0) > 0:
if grid_thws is None:
raise KeyError('pixel_values present in inputs but grid_thws is missing')
pixel_values = pixel_values.to(device=inputs_embeds.device, dtype=dtype)
grid_thws = grid_thws.to(inputs_embeds.device)
image_features = self._encode_images(pixel_values, grid_thws)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
media_token_id = self.hf_config.media_placeholder_token_id
image_mask = (input_ids == media_token_id).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)
else:
# plain-text batch: still run the vision graph on a dummy image so that
# DP ranks stay in sync during gradient all-reduce.
vision_config = self.hf_config.vision_config
patch_size = vision_config.patch_size
merge_kernel_size = vision_config.merge_kernel_size
kernel = merge_kernel_size[0] if isinstance(merge_kernel_size, (list, tuple)) else merge_kernel_size
h = w = kernel * 2
dummy_pixels = torch.zeros((h * w, 3, patch_size, patch_size), dtype=dtype, device=inputs_embeds.device)
dummy_grid = input_ids.new_tensor([[1, h, w]])
image_features = self._encode_images(dummy_pixels, dummy_grid)
# nan_to_num guards against a non-finite value from the all-zero dummy pass
# leaking into the text batch (NaN * 0 == NaN in IEEE-754).
zero_term = torch.nan_to_num(image_features.mean() * 0.)
inputs_embeds = inputs_embeds + zero_term.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
return inputs_embeds


Expand Down
5 changes: 5 additions & 0 deletions tests/test_mllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ def test_kimi_vl():
_test_model('moonshotai/Kimi-VL-A3B-Thinking-2506')


def test_kimi_k25():
_test_model('moonshotai/Kimi-K2.6')


def test_qwen3_vl():
_test_model('Qwen/Qwen3-VL-4B-Instruct')

Expand Down Expand Up @@ -134,6 +138,7 @@ def test_gemma4():
# test_glm4_6v_flash()
# test_ovis2_5()
# test_kimi_vl()
# test_kimi_k25()
# test_qwen3_vl()
# test_qwen3_vl_moe()
# test_qwen3_omni()
Expand Down