mediapipe/tasks/python/genai/converter/safetensors_converter.py (17 changes: 15 additions & 2 deletions)
@@ -392,6 +392,7 @@ def __init__(
backend: str,
reader: _SafetensorsReader,
is_v2: bool,
is_nested: bool = False,
):
super().__init__(
is_symmetric=is_symmetric,
@@ -402,6 +403,7 @@
)
self._reader = reader
self._is_v2 = is_v2
self._is_nested = is_nested

def map_to_actions(
self, layer_name: str
@@ -458,7 +460,8 @@ def update_target_name(self, target_name: str) -> str:
"""Updates the target name to match the tensor name convention."""

# Remove the 'language_model.' wrapper prefix from nested multimodal checkpoints (e.g. Gemma3-4B, Gemma 3n).
target_name = target_name.replace("language_model.", "")
if self._is_nested:
target_name = target_name.replace("language_model.", "")

target_name = target_name.replace("base_model.model.", "")
target_name = target_name.replace(
@@ -609,15 +612,25 @@ def __init__(
"GEMMA3_12B",
"GEMMA3_27B",
"GEMMA3_300M",
"GEMMA3N_2B",
"GEMMA3N_4B",
"GEMMA3N_8B",
"GEMMA_3N_E2B_IT",
"GEMMA_3N_E4B_IT",
]:
# Identify all models that have the nested 'language_model.' prefix
nested_models = ["GEMMA3-4B"]
is_nested_model = special_model in nested_models or "3N" in special_model.upper()

self.mapper = GemmaMapper(
is_symmetric,
attention_quant_bits,
feedforward_quant_bits,
embedding_quant_bits,
backend,
self._reader,
False if special_model in ["GEMMA_2B", "GEMMA_7B"] else True,
is_v2=(special_model not in ["GEMMA_2B", "GEMMA_7B"]),
is_nested=is_nested_model,  # Strip the nested 'language_model.' prefix.
)
else:
raise ValueError(f"Unknown special model: {special_model}")
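As a quick way to sanity-check the renaming logic outside the converter, the sketch below mirrors the gated 'language_model.' prefix stripping added above. strip_nested_prefix is a hypothetical standalone helper written only for illustration; it is not part of the MediaPipe converter API, and the example tensor names follow the checkpoint fixture used in the test further down.

def strip_nested_prefix(target_name: str, is_nested: bool) -> str:
  """Mimics the gated prefix removal added to update_target_name."""
  if is_nested:
    # Multimodal (nested) checkpoints such as Gemma3-4B and Gemma 3n store the
    # language model under a 'language_model.' wrapper; text-only models do not.
    target_name = target_name.replace("language_model.", "")
  return target_name.replace("base_model.model.", "")


if __name__ == "__main__":
  nested = "language_model.model.layers.0.self_attn.o_proj.weight"
  flat = "model.layers.0.self_attn.o_proj.weight"
  assert strip_nested_prefix(nested, is_nested=True) == flat
  # With is_nested=False the wrapper prefix is left alone.
  assert strip_nested_prefix(nested, is_nested=False) == nested
  print("prefix stripping behaves as expected")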
mediapipe/tasks/python/genai/converter/safetensors_converter_test.py
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.


"""Unit tests for safetensors_converter."""

import os
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np

from mediapipe.tasks.python.genai.converter import safetensors_converter
from mediapipe.tasks.python.test import test_utils
@@ -78,6 +81,60 @@ def test_load_to_actions(self, quant_bits):
actions = loader.load_to_actions()
self.assertLen(list(actions), 15)

@parameterized.named_parameters(
('gemma_3n_nested', 'GEMMA3N_4B'),
('gemma_3_4b_nested', 'GEMMA3-4B'),
)
@mock.patch.object(safetensors_converter, '_SafetensorsReader')
def test_nested_gemma_conversion(self, model_name, mock_reader_cls):
"""Tests that nested Gemma models have their prefixes stripped."""
mock_reader_instance = mock_reader_cls.return_value
gemma_nested_variable_names = [
# Standard language model layers with the 'language_model.' prefix
'language_model.model.embed_tokens.weight',
'language_model.model.layers.0.input_layernorm.weight',
'language_model.model.layers.0.mlp.down_proj.weight',
'language_model.model.layers.0.self_attn.o_proj.weight',
'language_model.model.norm.weight',
# Vision tower layers that should be skipped
'vision_tower.vision_tower.encoder.layers.0.blocks.0.attn.qkv.weight',
'multi_modal_projector.linear_1.weight',
]
mock_reader_instance.get_tensor_names.return_value = gemma_nested_variable_names
mock_reader_instance.read_tensor_as_numpy.return_value = np.zeros(
(1, 1), dtype=np.float32
)

loader = safetensors_converter.SafetensorsCkptLoader(
ckpt_path='/fake/path',
is_symmetric=True,
attention_quant_bits=8,
feedforward_quant_bits=8,
embedding_quant_bits=8,
special_model=model_name, # Use the parameterized model name
backend='gpu',
)
actions_list = list(loader.load_to_actions())

# The vision tower and projector tensors are skipped, so only the 5 language model tensors produce actions.
self.assertLen(actions_list, 5)

# Check that the 'language_model.' prefix was correctly removed
# Each entry of actions_list is the list of actions produced for one tensor.
target_names = [acts[0].target_name for acts in actions_list]
self.assertIn(
'params.lm.softmax.logits_ffn.w', target_names
)
self.assertIn(
'params.lm.transformer.x_layers_0.pre_layer_norm.scale', target_names
)
self.assertIn(
'params.lm.transformer.x_layers_0.ff_layer.ffn_layer2.w', target_names
)
self.assertIn(
'params.lm.transformer.x_layers_0.self_attention.post.w', target_names
)
self.assertIn('params.lm.final_ln.scale', target_names)


if __name__ == '__main__':
absltest.main()
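For context, here is a minimal sketch of how a caller might drive the converter for one of the newly supported variants. It mirrors the constructor arguments used in the test above; the checkpoint path is a placeholder, and "GEMMA3N_4B" is just one of the model names added in this change.

from mediapipe.tasks.python.genai.converter import safetensors_converter

# Placeholder checkpoint path; a real run needs a downloaded safetensors checkpoint.
loader = safetensors_converter.SafetensorsCkptLoader(
    ckpt_path='/path/to/gemma3n_checkpoint',
    is_symmetric=True,
    attention_quant_bits=8,
    feedforward_quant_bits=8,
    embedding_quant_bits=8,
    special_model='GEMMA3N_4B',
    backend='gpu',
)
# Each yielded item is the list of actions produced for one source tensor.
for tensor_actions in loader.load_to_actions():
  for action in tensor_actions:
    print(action.target_name)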