|
37 | 37 | import timm
|
38 | 38 |
|
39 | 39 | from transformers import (
|
40 |
| - AutoConfig, |
41 | 40 | Gemma3p5Config,
|
42 |
| - Gemma3p5ForCausalLM, |
43 | 41 | Gemma3p5ForConditionalGeneration,
|
44 |
| - Gemma3ImageProcessor, |
| 42 | + Gemma3ImageProcessorFast, |
45 | 43 | Gemma3Processor,
|
46 | 44 | Gemma3NanoAudioConfig,
|
47 |
| - Gemma3NanoAudioEncoder, |
48 | 45 | Gemma3p5TextConfig,
|
49 | 46 | Gemma3p5VisionConfig,
|
50 | 47 | GemmaTokenizerFast,
|
|
153 | 150 | intermediate_size=2048 * 4,
|
154 | 151 | num_hidden_layers=30,
|
155 | 152 | activation_sparsity_pattern=(0.95,)*10 + (0.0,)*20,
|
| 153 | + num_kv_shared_layers=10, |
156 | 154 | ),
|
157 | 155 | vision_config=Gemma3p5VisionConfig(),
|
158 | 156 | audio_config=Gemma3NanoAudioConfig(),
|
|
182 | 180 | )
|
183 | 181 |
|
184 | 182 | _INCLUDE_CHAT_TEMPLATE = flags.DEFINE_bool(
|
185 |
| - name="include_chat_template", default=False, help="If true, will save the default chat template with the tokenizer" |
| 183 | + name="include_chat_template", default=True, help="If true, will save the default chat template with the tokenizer" |
186 | 184 | )
|
187 | 185 |
|
188 | 186 | _OUTPUT_PATH = flags.DEFINE_string(
|
@@ -641,12 +639,14 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
|
641 | 639 | elif param == "mm_input_embedding_extra":
|
642 | 640 | update_tree("embed_vision.embedding.weight", value, config.vision_config.torch_dtype)
|
643 | 641 | elif path.endswith("mm_hard_embedding_norm"):
|
644 |
| - update_tree("embed_vision.embedding_norm.weight", value, config.vision_config.torch_dtype) |
| 642 | + update_tree("embed_vision.hard_embedding_norm.weight", value, config.vision_config.torch_dtype) |
645 | 643 | elif path.endswith("mm_input_projection"):
|
646 | 644 | update_tree(
|
647 | 645 | "embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
|
648 | 646 | )
|
649 |
| - if path.startswith(_TRANSFORMER_PARAMETER): |
| 647 | + elif path.endswith("mm_soft_embedding_norm"): |
| 648 | + update_tree("embed_vision.soft_embedding_norm.weight", value, config.vision_config.torch_dtype) |
| 649 | + elif path.startswith(_TRANSFORMER_PARAMETER): |
650 | 650 | for path, weights in convert_transformer_weights(config.text_config, path, param, value):
|
651 | 651 | update_tree(f"language_model.{path}", weights, config.text_config.torch_dtype)
|
652 | 652 | elif _MOBILE_NET_PREFIX in path:
|
@@ -720,23 +720,22 @@ def main(*args):
|
720 | 720 | tokenizer.save_pretrained(output_path)
|
721 | 721 | logging.info("Saved GemmaTokenizer for %s to %s", variant, output_path)
|
722 | 722 |
|
723 |
| - # # if variant != _VARIANT_GEMMA_3_2B: |
724 |
| - # # image_processor = Gemma3ImageProcessor( |
725 |
| - # # image_seq_length=256, |
726 |
| - # # image_mean=(0.5,) * 3, |
727 |
| - # # image_std=(0.5,) * 3, |
728 |
| - # # size={"height": 896, "width": 896}, |
729 |
| - # # resample=PILImageResampling.BILINEAR, |
730 |
| - # # ) |
731 |
| - # # processor = Gemma3Processor( |
732 |
| - # # image_processor=image_processor, |
733 |
| - # # tokenizer=tokenizer, |
734 |
| - # # chat_template=tokenizer.chat_template, |
735 |
| - # # ) |
736 |
| - # # processor.save_pretrained(output_path) |
737 |
| - # # logging.info("Saved Gemma3Processor for %s to %s", variant, output_path) |
738 |
| - # # del processor |
| 723 | + image_processor = Gemma3ImageProcessorFast( |
| 724 | + image_seq_length=256, |
| 725 | + image_mean=(0.5,) * 3, |
| 726 | + image_std=(0.5,) * 3, |
| 727 | + size={"height": 768, "width": 768}, |
| 728 | + resample=PILImageResampling.BILINEAR, |
| 729 | + ) |
| 730 | + processor = Gemma3Processor( |
| 731 | + image_processor=image_processor, |
| 732 | + tokenizer=tokenizer, |
| 733 | + chat_template=tokenizer.chat_template, |
| 734 | + ) |
| 735 | + processor.save_pretrained(output_path) |
| 736 | + logging.info("Saved Gemma3Processor for %s to %s", variant, output_path) |
739 | 737 |
|
| 738 | + del processor |
740 | 739 | del tokenizer
|
741 | 740 |
|
742 | 741 | generation_config = GenerationConfig(
|
|
0 commit comments