
Commit 84ab795

Text + Vision Part 2 (#23)

* Updates for ConditionalGeneration.get_image_features
* Adding a WIP draft of image_processing_gemma3p5.py
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py
  Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
* Modular conversion after github suggested change
* Text + image gives good results
* Fixing image size preset
* Updating configs for the 2B variant in the conversion script
* Using final generation config in conversion script

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
1 parent 7d14788 commit 84ab795

5 files changed: +453 additions, -148 deletions

gemma3n_forward_test.py

Lines changed: 42 additions & 7 deletions
@@ -5,17 +5,52 @@
     AutoModelForCausalLM,
     AutoModelForImageTextToText,
     AutoTokenizer,
-    model_addition_debugger_context
+    Gemma3ImageProcessorFast,
+    Gemma3Processor,
+    model_addition_debugger_context,
 )
 
-model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g251_safetensors"
+model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g348_safetensors"
 
+image_processor = Gemma3ImageProcessorFast(size={"height": 768, "width": 768})
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")
-print(type(model.config))
-print(type(model.config.audio_config))
-print(type(model.config.text_config))
-print(type(model.config.vision_config))
+processor = Gemma3Processor(
+    tokenizer=tokenizer,
+    image_processor=image_processor,
+    chat_template=tokenizer.chat_template,
+)
+
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image", "image": "/usr/local/google/home/ryanmullins/Downloads/cat.jpeg"},
+            {"type": "text", "text": "Describe this image in detail."}
+        ]
+    }
+]
+
+inputs = processor.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt",
+)
+input_len = inputs["input_ids"].shape[-1]
+
+print(inputs)
+
+model = AutoModelForImageTextToText.from_pretrained(model_id)
+inputs = inputs.to(model.device, dtype=torch.bfloat16)
+
+with torch.inference_mode():
+    generation = model.generate(**inputs, max_new_tokens=16, do_sample=False)
+    generation = generation[0][input_len:]
+
+decoded = processor.decode(generation, skip_special_tokens=True)
+print(decoded)
+
 # model.to(dtype=torch.bfloat16)
 # input_ids = tokenizer("The capitol of France is ", return_tensors="pt")
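
The updated script exercises the full text-plus-image path. To see exactly what the processor feeds the model, the same chat template can be rendered without tokenization; a minimal inspection sketch, assuming `processor` and `messages` from the script above:

```python
# Render the chat template as plain text instead of token IDs. The printed
# prompt shows the turn markers and image placeholder tokens the model sees;
# their exact form depends on the checkpoint's chat template.
prompt = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
print(prompt)
```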

src/transformers/models/gemma3p5/convert_gemma3p5_weights.py

Lines changed: 22 additions & 23 deletions
@@ -37,14 +37,11 @@
 import timm
 
 from transformers import (
-    AutoConfig,
     Gemma3p5Config,
-    Gemma3p5ForCausalLM,
     Gemma3p5ForConditionalGeneration,
-    Gemma3ImageProcessor,
+    Gemma3ImageProcessorFast,
     Gemma3Processor,
     Gemma3NanoAudioConfig,
-    Gemma3NanoAudioEncoder,
     Gemma3p5TextConfig,
     Gemma3p5VisionConfig,
     GemmaTokenizerFast,
@@ -153,6 +150,7 @@
         intermediate_size=2048 * 4,
         num_hidden_layers=30,
         activation_sparsity_pattern=(0.95,)*10 + (0.0,)*20,
+        num_kv_shared_layers=10,
     ),
     vision_config=Gemma3p5VisionConfig(),
     audio_config=Gemma3NanoAudioConfig(),
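
Reading the 2B text config literally: `activation_sparsity_pattern` assigns one sparsity value per hidden layer, and the new `num_kv_shared_layers=10` is, per its name, a count of layers that share KV state. A quick sanity-check sketch of what the tuple expands to:

```python
# One sparsity value per hidden layer: 95% activation sparsity on the first
# 10 layers, none on the remaining 20. The tuple length has to line up with
# num_hidden_layers for the config to be coherent.
num_hidden_layers = 30
activation_sparsity_pattern = (0.95,) * 10 + (0.0,) * 20
assert len(activation_sparsity_pattern) == num_hidden_layers
```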
@@ -182,7 +180,7 @@
 )
 
 _INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
-    name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer"
+    name="include_chat_template", default=True, help="If true, will save the default chat template with the tokenizer"
 )
 
 _OUTPUT_PATH = flags.DEFINE_string(
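
With the default flipped to `True`, conversions now save the chat template unless explicitly disabled. A standalone sketch of the absl-py boolean-flag pattern the script uses (an illustration of the flag API, not the converter's actual `main`):

```python
# absl boolean flags generate paired CLI switches: pass
# --noinclude_chat_template to override the new True default.
from absl import app, flags

_INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
    name="include_chat_template",
    default=True,
    help="If true, will save the default chat template with the tokenizer",
)


def main(argv):
    del argv  # unused
    print("include_chat_template =", _INCLUDE_CHAT_TEMPLATE.value)


if __name__ == "__main__":
    app.run(main)
```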
@@ -641,12 +639,14 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> None:
     elif param == "mm_input_embedding_extra":
         update_tree("embed_vision.embedding.weight", value, config.vision_config.torch_dtype)
     elif path.endswith("mm_hard_embedding_norm"):
-        update_tree("embed_vision.embedding_norm.weight", value, config.vision_config.torch_dtype)
+        update_tree("embed_vision.hard_embedding_norm.weight", value, config.vision_config.torch_dtype)
     elif path.endswith("mm_input_projection"):
         update_tree(
             "embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
         )
-    if path.startswith(_TRANSFORMER_PARAMETER):
+    elif path.endswith("mm_soft_embedding_norm"):
+        update_tree("embed_vision.soft_embedding_norm.weight", value, config.vision_config.torch_dtype)
+    elif path.startswith(_TRANSFORMER_PARAMETER):
         for path, weights in convert_transformer_weights(config.text_config, path, param, value):
             update_tree(f"language_model.{path}", weights, config.text_config.torch_dtype)
     elif _MOBILE_NET_PREFIX in path:
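
Besides routing the new `mm_soft_embedding_norm` weights, this hunk fixes a dispatch bug: the bare `if` broke the chain, so a path already handled by an earlier multimodal branch could fall through into the transformer branch as well. A toy sketch (names hypothetical) of the difference:

```python
# With a bare `if`, both branches can fire for one path; the `elif` chain in
# the converter makes the routing mutually exclusive.
def dispatch(path: str) -> list[str]:
    targets = []
    if path.endswith("mm_soft_embedding_norm"):
        targets.append("embed_vision")
    if path.startswith("transformer/"):  # pre-fix: a plain `if`, not `elif`
        targets.append("language_model")
    return targets


# Both branches match, so the weight would be handled twice:
print(dispatch("transformer/mm_soft_embedding_norm"))
# -> ['embed_vision', 'language_model']
```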
@@ -720,23 +720,22 @@ def main(*args):
     tokenizer.save_pretrained(output_path)
     logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
 
-    # # if variant != _VARIANT_GEMMA_3_2B:
-    # #     image_processor = Gemma3ImageProcessor(
-    # #         image_seq_length=256,
-    # #         image_mean=(0.5,) * 3,
-    # #         image_std=(0.5,) * 3,
-    # #         size={"height": 896, "width": 896},
-    # #         resample=PILImageResampling.BILINEAR,
-    # #     )
-    # #     processor = Gemma3Processor(
-    # #         image_processor=image_processor,
-    # #         tokenizer=tokenizer,
-    # #         chat_template=tokenizer.chat_template,
-    # #     )
-    # #     processor.save_pretrained(output_path)
-    # #     logging.info("Saved Gemma3Processor for %s to %s", variant, output_path)
-    # #     del processor
+    image_processor = Gemma3ImageProcessorFast(
+        image_seq_length=256,
+        image_mean=(0.5,) * 3,
+        image_std=(0.5,) * 3,
+        size={"height": 768, "width": 768},
+        resample=PILImageResampling.BILINEAR,
+    )
+    processor = Gemma3Processor(
+        image_processor=image_processor,
+        tokenizer=tokenizer,
+        chat_template=tokenizer.chat_template,
+    )
+    processor.save_pretrained(output_path)
+    logging.info("Saved Gemma3Processor for %s to %s", variant, output_path)
 
+    del processor
     del tokenizer
 
     generation_config = GenerationConfig(
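
The converter now always writes a `Gemma3Processor` alongside the tokenizer, so a converted checkpoint can be reloaded end to end. A minimal sketch, assuming a hypothetical output directory:

```python
# Reload the converted artifacts with the Auto classes used in the forward
# test above. The directory path is a placeholder for the converter's
# output_path value.
from transformers import AutoModelForImageTextToText, AutoProcessor

output_path = "/tmp/gemma3p5_checkpoint"  # hypothetical converter output
processor = AutoProcessor.from_pretrained(output_path)
model = AutoModelForImageTextToText.from_pretrained(output_path)
```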
