
Commit 5f0340b

Add Gemma3p5 Audio Encoder (#6)
* initial commit of Gemma 3.5 scaffold
* Fixing param pass through on Gemm3p5RMSNorm
* Adds Einsum layer to Gemma 3.5
* Updating EinsumLayer API
* Undoing erroneous force push
* Reverting RMSNorm to with_scale by default
* Adds LAuReL to Gemma 3.5
* Adds AltUp to Gemma 3.5
* Adding Gemma3p5 overall and text config with vision and audio config placeholders (#3)
* Adding gemma3p5 text configs
* Adding audio config placeholders
* Adding a placeholder for vision configs
* Updating MobileNetVisionConfig, inheriting TimmWrapperConfig
* Updating text configs
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Removing altup configs to accept the suggested configs
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Updating altup config
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Addressing review comments and updating text configs
* Adding a config for activation sparsity
* Updating configs to pass through options to super class init and adjust some name prefixes
* Updating laurel and altup with corrected config values
* Normalizing sub_config initializers
---------
Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Updating MLP with activation sparsity (#2)
* Updating DecoderBlock for Gemma 3.5 (#3)
* Initial Gemm3p5TextModel (#4)
  NOTE: This implementation WILL CHANGE in the coming weeks; however, changes will be strictly additive, and this will remain a suitable baseline for downstream implementations to reference.
* Adding KV Cache Sharing
* Adds Einsum layer to Gemma 3.5
* Updating EinsumLayer API
* Refactored kv cache sharing in attention
* Adding KVStore for cache sharing
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Update src/transformers/cache_utils.py Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Undoing erroneous force push
* Reverting RMSNorm to with_scale by default
* Adds LAuReL to Gemma 3.5
* Updating KV Cache Sharing implementation
* Updating the q and k norm definitions in the attention module
* Fixing name error for q,k,v RMS norm to use the right 3p5 module
* Updating MLP with activation sparsity
* Updating DecoderBlock for Gemma 3.5
* Updating kv cache sharing implementation with the use of a cache buffer and refactoring some lines of code
* Isolating KV Cache logic to relevant components
* Fixing logic error in Gemma3p5Attention.forward
* Refactoring caching contributions and fixing kv_store initialization
* Simplifying Configs
* Remove errant self from super init call
* Bug fix in the Attention module - changing self.head_dim to config.head_dim
* Bug fixes in the LaurelBlock and RMS Norm super init call
* removing redundant code from a merge
* Adding per_layer_inputs to TextModel
* Adding preprocess embeddings with altup
* Adds per-layer-to-single output and a host of TODOs
* Integrating altup predict with the model workflow and other minor bug fixes
* Using nn.Embedding temporarily for text model
* It goes forward
* Minor refactor of attention sparsity and RoPE initialization
* Fixing duplicate rope_scaling param bug when loading from pretrained
---------
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
* Normalizing on altup_num_inputs config option
* Adding audio encoder config
* Adds high-level components for Audio Encoder
* Implement uniform reducer for Audio Encoder
* Adding placeholders for Conformer components in Audio Encoder
* Adding placeholders for SubSampleConvProjection components in Audio Encoder
* Adding SequenceLayer component placeholders
* Implementing Gemma3p5AudioEncoder with nn.Sequential
* Implementing Gemma3p5AudioSubSampleConvProjection with nn.Sequential
* Implementing Conformer model with SequenceLayers
* Use OrderedDict in nn.Sequential initializers
* Implements sl.Residual in Torch with nn.Sequential and OrderedDict
* Adopting a base SequenceLayer class with default forward() method
* Implementing sl.GatedLinearUnit in Torch
* Implementing sl.Swish in Torch
* Implementing sl.ReLU in Torch
* Implementing sl.Scale in Torch
* Removing sl.Dropout after tree-shaking
* Implementing sl.RMSNorm in Torch with fake shape
* Implementing sl.GroupNorm in Torch
* Implementing sl.Conv2d in Torch
* Implementing sl.Dense in Torch
* Removing sl.Delay layers, which act as pass-throughs
* Connecting shapes to configs in initializers
* Removing sl.Emit
* Implementing sl.ExpandDims in Torch
* Adding sl.GradientClipping to Torch
* Implementing sl.DenseShaped in Torch
* Implementing sl.LDPA in Torch
* Removing unused sl.CombinedQKVProj class
* Fixing erroneous type hint
* Implementing sl.DepthwiseConv1D in Torch
* Implementing sl.MaskInvalid in Torch
* Fixes for initialization
* Fixes for saving weights
* Removing einsums per feedback from HF staff
* Removing Sequence Layers idioms from audio encoder
* Fixes for reviewer comments
* CausalLM conversion script for 4B model
* inv_timescales to non-persistent buffer
* Addressing audio encoder Attention feedback
* Addressing Gemma3p5AudioSSCPConvBlock feedback
* Addressing Gemma3p5AudioConformerAttention feedback
* Addressing padding feedback
* Weights conversion loads audio state dict
* Always use vision_config so saving works
* Token id updates for configs
* Stubs for interleaving audio embs
* Addressing reviewer feedback
---------
Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
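Several bullets above describe rebuilding the audio encoder's SequenceLayers idioms on top of nn.Sequential initialized from an OrderedDict (for example sl.Residual and the subsample projection). The following is a minimal editorial sketch of that idiom in PyTorch; the layer names and sizes are illustrative only and are not the actual Gemma3p5 modules:

from collections import OrderedDict

import torch
from torch import nn


class Residual(nn.Module):
    """Toy residual wrapper, y = x + f(x), mirroring the sl.Residual idea."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.inner(x)


# Naming submodules via OrderedDict keeps state_dict keys readable,
# e.g. "inner.ffw.weight" instead of "inner.0.weight".
block = Residual(
    nn.Sequential(OrderedDict([
        ("norm", nn.LayerNorm(16)),
        ("ffw", nn.Linear(16, 16)),
        ("act", nn.SiLU()),  # Swish activation
    ]))
)

out = block(torch.randn(2, 16))  # shape: (2, 16)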
1 parent 3bd3c50 commit 5f0340b

5 files changed: +1305 additions, -125 deletions

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 1 deletion
@@ -570,8 +570,9 @@
         ("gemma", "GemmaForCausalLM"),
         ("gemma2", "Gemma2ForCausalLM"),
         ("gemma3", "Gemma3ForConditionalGeneration"),
-        ("gemma3p5", "Gemma3p5ForCausalLM"),
         ("gemma3_text", "Gemma3ForCausalLM"),
+        ("gemma3p5", "Gemma3p5ForConditionalGeneration"),
+        ("gemma3p5_text", "Gemma3p5ForCausalLM"),
         ("git", "GitForCausalLM"),
         ("glm", "GlmForCausalLM"),
         ("glm4", "Glm4ForCausalLM"),

src/transformers/models/gemma3p5/configuration_gemma3p5.py

Lines changed: 50 additions & 9 deletions
@@ -268,20 +268,53 @@ def __init__(
 
 
 class Gemma3p5AudioConfig(PretrainedConfig):
-    model_type = "gemma3p5"
+    model_type = "gemma3p5_audio"
 
     def __init__(
         self,
-        *args,
+        input_feat_size: int = 80,
         hidden_size: int = 1536,
         embedding_norm_eps: float = 1e-6,
-        vocab_size: int = 256_128,
+        vocab_size: int = 128,
+        gradient_clipping: float = 10_000_000_000.0,
+        conf_attention_chunk_size: int = 12,
+        conf_attention_context_left: int = 13,
+        conf_attention_context_right: int = 0,
+        conf_attention_invalid_logits_value: float = -1.0e9,
+        conf_attention_logit_cap: float = 50.0,
+        conf_num_attention_heads: int = 8,
+        conf_num_hidden_layers: int = 12,
+        conf_conv_kernel_size: int = 5,
+        conf_positional_bias_size: int = 256,
+        conf_reduction_factor: int = 4,
+        conf_residual_weight: float = 0.5,
+        sscp_conv_channel_size: tuple[int, int] = (128, 32),
+        sscp_conv_group_norm_eps: float = 1e-3,
+        sscp_conv_kernel_size: tuple[tuple[int, int], tuple[int, int]] = ((3, 3), (3, 3)),
+        sscp_conv_stride_size: tuple[tuple[int, int], tuple[int, int]] = ((2, 2), (2, 2)),
         **kwargs,
     ):
-        super().__init__(*args, **kwargs)
+        super().__init__(**kwargs)
+        self.input_feat_size = input_feat_size
         self.hidden_size = hidden_size
         self.embedding_norm_eps = embedding_norm_eps
         self.vocab_size = vocab_size
+        self.gradient_clipping = gradient_clipping
+        self.conf_attention_chunk_size = conf_attention_chunk_size
+        self.conf_attention_context_left = conf_attention_context_left
+        self.conf_attention_context_right = conf_attention_context_right
+        self.conf_attention_invalid_logits_value = conf_attention_invalid_logits_value
+        self.conf_attention_logit_cap = conf_attention_logit_cap
+        self.conf_num_attention_heads = conf_num_attention_heads
+        self.conf_num_hidden_layers = conf_num_hidden_layers
+        self.conf_conv_kernel_size = conf_conv_kernel_size
+        self.conf_positional_bias_size = conf_positional_bias_size
+        self.conf_reduction_factor = conf_reduction_factor
+        self.conf_residual_weight = conf_residual_weight
+        self.sscp_conv_channel_size = sscp_conv_channel_size
+        self.sscp_conv_eps = sscp_conv_group_norm_eps
+        self.sscp_conv_kernel_size = sscp_conv_kernel_size
+        self.sscp_conv_stride_size = sscp_conv_stride_size
 
 
 class Gemma3p5VisionConfig(PretrainedConfig):
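The new keyword arguments in the hunk above map one-to-one onto attributes on the config object, with the exception of sscp_conv_group_norm_eps, which is stored as sscp_conv_eps. A small sketch of constructing the audio config with a couple of overrides, using only the defaults shown here and assuming the class is exported as listed in __all__ further down:

from transformers import Gemma3p5AudioConfig

audio_config = Gemma3p5AudioConfig(
    conf_num_hidden_layers=12,         # number of stacked Conformer blocks
    conf_num_attention_heads=8,        # heads per Conformer attention layer
    sscp_conv_channel_size=(128, 32),  # channels of the two subsampling convs
)

assert audio_config.model_type == "gemma3p5_audio"
assert audio_config.sscp_conv_eps == 1e-3  # stored under the shorter attribute name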
@@ -369,8 +402,11 @@ def __init__(
         audio_soft_tokens_per_image: int = 256,
         vision_soft_tokens_per_image: int = 256,
         boi_token_id: int = 255_999,
-        eoi_token_id: int = 256_000,
-        image_token_id: int = 262_144,
+        eoi_token_id: int = 262_144,
+        image_token_id: int = 262_145,
+        boa_token_id: int = 256_000,
+        eoa_token_id: int = 262_272,
+        audio_token_id: int = 262_273,
         initializer_range: float = 0.02,
         **kwargs,
     ):
@@ -385,12 +421,14 @@ def __init__(
         if isinstance(vision_config, dict):
             vision_config = Gemma3p5VisionConfig(**vision_config)
         elif vision_config is None:
-            logger.info("vision_config is None. Vision capabilities will not be used.")
+            vision_config = Gemma3p5VisionConfig()
+            logger.info("vision_config is None. Using default Gemma3p5VisionConfig.")
 
         if isinstance(audio_config, dict):
             audio_config = Gemma3p5AudioConfig(**audio_config)
         elif audio_config is None:
-            logger.info("audio_config is None. Audio capabilities will not be used.")
+            audio_config = Gemma3p5AudioConfig()
+            logger.info("audio_config is None. Using default Gemma3p5AudioConfig.")
 
         self.text_config = text_config
         self.vision_config = vision_config
@@ -401,7 +439,10 @@ def __init__(
         self.boi_token_id = boi_token_id
         self.eoi_token_id = eoi_token_id
         self.image_token_id = image_token_id
+        self.boa_token_id = boa_token_id
+        self.eoa_token_id = eoa_token_id
+        self.audio_token_id = audio_token_id
         self.initializer_range = initializer_range
 
 
-__all__ = ["Gemma3p5Config", "Gemma3p5TextConfig"]
+__all__ = ["Gemma3p5Config", "Gemma3p5AudioConfig", "Gemma3p5TextConfig", "Gemma3p5VisionConfig"]
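With the changes above, omitting a sub-config no longer disables the modality; the parent config falls back to default sub-configs instead of None. A minimal sketch of that behavior, assuming the classes are exported as in __all__ above (the rest of the Gemma3p5Config signature is not shown in this diff and is assumed to accept these keywords):

from transformers import Gemma3p5AudioConfig, Gemma3p5Config

# No audio_config / vision_config passed: defaults are constructed instead of None.
config = Gemma3p5Config()
assert isinstance(config.audio_config, Gemma3p5AudioConfig)

# A plain dict is also accepted and converted into the typed sub-config.
config = Gemma3p5Config(audio_config={"conf_num_hidden_layers": 12})
assert config.audio_config.conf_num_hidden_layers == 12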

src/transformers/models/gemma3p5/convert_gemma3p5_weights.py

Lines changed: 164 additions & 16 deletions
@@ -18,9 +18,9 @@
 
 python src/transformers/models/gemma3p5/convert_gemma3p5_weights.py \
     --variant='gemma3p5_4b' \
-    --tokenizer_path="$HOME/nano3/checkpoints/tokenizer/gemma3n_cleaned_262144.spiece" \
-    --checkpoint_path="$HOME/nano3/checkpoints/4b_pt_orbax/" \
-    --output_path="$HOME/nano3/checkpoints/4b_pt_safetensors/"
+    --tokenizer_path="$HOME/gemma3p5/checkpoints/tokenizer/gemma3p5-tokenizer.model" \
+    --checkpoint_path="$HOME/gemma3p5/checkpoints/4b_pt_orbax/" \
+    --output_path="$HOME/gemma3p5/checkpoints/4b_pt_safetensors/"
 """
 
 from collections.abc import Iterator, Sequence
@@ -39,10 +39,11 @@
     Gemma3p5ForConditionalGeneration,
     Gemma3ImageProcessor,
     Gemma3Processor,
+    Gemma3p5AudioConfig,
     Gemma3p5TextConfig,
+    Gemma3p5VisionConfig,
     GemmaTokenizerFast,
     GenerationConfig,
-    SiglipVisionConfig,
 )
 from transformers.image_utils import PILImageResampling
 
@@ -94,6 +95,10 @@
 
 _DTYPES = {"float32", "bfloat16", "float16"}
 
+_AUDIO_ENCODER_PARAMETER = "AudioEncoder/encoder"
+_AUDIO_ENCODER_CONFORMER = f"{_AUDIO_ENCODER_PARAMETER}/conformer/stacked_layers"
+_AUDIO_ENCODER_SSCP = f"{_AUDIO_ENCODER_PARAMETER}/feature"
+
 _TRANSFORMER_PARAMETER = "transformer"
 _TRANSFORMER_ALTUP_PROJ = f"{_TRANSFORMER_PARAMETER}/altup_projection_"
 _TRANSFORMER_ALTUP_UNEMB = f"{_TRANSFORMER_PARAMETER}/altup_unembed_projection_"
@@ -104,10 +109,6 @@
 _TRANSFORMER_POST_TRAINING_PREFIX = "rlx_networks/policy_network/"
 _TRANSFORMER_POST_TRAINING_PREFIX_LEN = len(_TRANSFORMER_POST_TRAINING_PREFIX)
 
-# TODO: ryanmullins - Figure out the vision config
-_VISION_CONFIG = {}
-
-
 _VARIANT_GEMMA_3_2B = "gemma3p5_2b"
 _VARIANT_GEMMA_3_4B = "gemma3p5_4b"
 _VARIANTS = {
@@ -127,16 +128,25 @@
             query_pre_attn_scalar=256,
             max_position_embeddings=32_768,
         ),
-        vision_config=_VISION_CONFIG,
+        vision_config=Gemma3p5VisionConfig(),
+        audio_config=Gemma3p5AudioConfig(),
     ),
     _VARIANT_GEMMA_3_4B: Gemma3p5Config(
         text_config=Gemma3p5TextConfig(),
-        vision_config=_VISION_CONFIG,
+        vision_config=Gemma3p5VisionConfig(),
+        audio_config=Gemma3p5AudioConfig(),
     ),
 }
 
 # ==== Flags ====
 
+_AUDIO_DTYPE = flags.DEFINE_enum(
+    name="audio_dtype",
+    default="bfloat16",
+    help="The floating point precision (aka dtype) of the model.",
+    enum_values=_DTYPES,
+)
+
 _CHECKPOINT_PATH = flags.DEFINE_string(
     name="checkpoint_path",
     default=None,
@@ -190,6 +200,125 @@
 )
 
 
+def convert_audio_encoder_weights(
+    config: Gemma3p5AudioConfig,
+    path: str,
+    param: str,
+    weights: np.ndarray,
+) -> Iterator[tuple[str, np.ndarray]]:
+
+    converted_paths: list[str] = []
+    converted_weights: list[Any] = []
+
+    if path.startswith(_AUDIO_ENCODER_CONFORMER):
+        assert weights.shape[0] == config.conf_num_hidden_layers
+
+        for i, matrix in enumerate(weights):
+            if "fflayer_end" in path:
+                base = f"audio_tower.conformer.{i}.ffw_layer_end"
+
+                if path.endswith("ffn_layer1"):
+                    converted_paths.append(f"{base}.ffw_layer_1.weight")
+                    converted_weights.append(matrix.transpose())
+                elif path.endswith("ffn_layer2"):
+                    converted_paths.append(f"{base}.ffw_layer_2.weight")
+                    converted_weights.append(matrix.transpose())
+                elif path.endswith("post_layer_norm"):
+                    converted_paths.append(f"{base}.post_layer_norm.weight")
+                    converted_weights.append(matrix)
+                elif path.endswith("pre_layer_norm"):
+                    converted_paths.append(f"{base}.pre_layer_norm.weight")
+                    converted_weights.append(matrix)
+            elif "fflayer_start" in path:
+                base = f"audio_tower.conformer.{i}.ffw_layer_start"
+
+                if path.endswith("ffn_layer1"):
+                    converted_paths.append(f"{base}.ffw_layer_1.weight")
+                    converted_weights.append(matrix.transpose())
+                elif path.endswith("ffn_layer2"):
+                    converted_paths.append(f"{base}.ffw_layer_2.weight")
+                    converted_weights.append(matrix.transpose())
+                elif path.endswith("post_layer_norm"):
+                    converted_paths.append(f"{base}.post_layer_norm.weight")
+                    converted_weights.append(matrix)
+                elif path.endswith("pre_layer_norm"):
+                    converted_paths.append(f"{base}.pre_layer_norm.weight")
+                    converted_weights.append(matrix)
+            elif path.endswith("final_ln"):
+                converted_paths.append(f"audio_tower.conformer.{i}.norm.weight")
+                converted_weights.append(matrix)
+            elif "lconv" in path:
+                base = f"audio_tower.conformer.{i}.lconv1d"
+
+                if path.endswith("conv_norm"):
+                    converted_paths.append(f"{base}.conv_norm.weight")
+                    converted_weights.append(matrix)
+                elif path.endswith("depthwise_conv1d"):
+                    converted_paths.append(f"{base}.depthwise_conv1d.weight")
+                    converted_weights.append(matrix.transpose())
+                elif path.endswith("linear_end"):
+                    converted_paths.append(f"{base}.linear_end.weight")
+                    converted_weights.append(matrix)
+                elif path.endswith("linear_start"):
+                    converted_paths.append(f"{base}.linear_start.weight")
+                    converted_weights.append(matrix.transpose())
+                elif path.endswith("ln"):
+                    converted_paths.append(f"{base}.pre_layer_norm.weight")
+                    converted_weights.append(matrix)
+            elif "trans_atten" in path:
+                base = f"audio_tower.conformer.{i}.attention"
+
+                if param == "per_dim_scale":
+                    converted_paths.append(f"{base}.attn.per_dim_scale")
+                    converted_weights.append(matrix)
+
+                if path.endswith("query_key_value_projection"):
+                    converted_paths.extend([
+                        f"{base}.attn.q_proj.weight", f"{base}.attn.k_proj.weight", f"{base}.attn.v_proj.weight"
+                    ])
+                    converted_weights.extend([
+                        m.squeeze().reshape(config.hidden_size, config.hidden_size).transpose()
+                        for m in np.split(matrix, 3, axis=1)
+                    ])
+                elif path.endswith("pos_proj"):
+                    converted_paths.append(f"{base}.attn.relative_position_embedding.pos_proj.weight")
+                    converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose())
+                elif path.endswith("post"):
+                    converted_paths.append(f"{base}.post.weight")
+                    converted_weights.append(matrix.reshape(config.hidden_size, config.hidden_size).transpose())
+                elif path.endswith("post_norm"):
+                    converted_paths.append(f"{base}.post_norm.weight")
+                    converted_weights.append(matrix)
+                elif path.endswith("pre_norm"):
+                    converted_paths.append(f"{base}.pre_attn_norm.weight")
+                    converted_weights.append(matrix)
+    elif path.startswith(_AUDIO_ENCODER_SSCP):
+        if path.endswith("input_proj"):
+            converted_paths.append(f"audio_tower.subsample_conv_projection.input_proj_linear.weight")
+            converted_weights.append(
+                weights.reshape(config.sscp_conv_channel_size[1] ** 2, config.hidden_size).transpose()
+            )
+        elif "norm_" in path:
+            index = int(path[-1])
+            converted_paths.extend([
+                f"audio_tower.subsample_conv_projection.conv_{index}.norm.bias",
+                f"audio_tower.subsample_conv_projection.conv_{index}.norm.weight",
+            ])
+            converted_weights.extend([np.zeros_like(weights), weights])
+        elif "subsampling_" in path:
+            index = int(path[-1])
+            converted_paths.append(f"audio_tower.subsample_conv_projection.conv_{index}.conv.weight")
+            converted_weights.append(weights.transpose())
+
+    if (cpl := len(converted_paths)) != (cwl := len(converted_weights)):
+        raise ValueError(
+            "The `converted_paths` and `converted_weights` should be the same "
+            f"length. Got {cpl} and {cwl}, respectively, for {path}."
+        )
+
+    return zip(converted_paths, converted_weights)
+
+
 def convert_transformer_weights(
     config: Gemma3p5TextConfig,
     path: str,
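The least obvious step in convert_audio_encoder_weights above is splitting the fused query_key_value_projection tensor into separate q/k/v projection matrices. A standalone numpy sketch of that reshaping follows; the hidden size is deliberately tiny, and the fused tensor layout is inferred from the reshape logic rather than taken from the checkpoint format itself:

import numpy as np

hidden_size = 8               # illustrative only; the real audio config uses 1536
num_heads, head_dim = 2, 4    # chosen so num_heads * head_dim == hidden_size

# Assumed fused layout: (hidden, 3, num_heads, head_dim) with q/k/v stacked on axis 1.
fused = np.random.randn(hidden_size, 3, num_heads, head_dim)

q_proj, k_proj, v_proj = [
    m.squeeze().reshape(hidden_size, hidden_size).transpose()
    for m in np.split(fused, 3, axis=1)
]

# Each projection is now a (hidden, hidden) matrix in the layout nn.Linear expects.
assert q_proj.shape == (hidden_size, hidden_size)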
@@ -215,7 +344,6 @@ def convert_transformer_weights(
         attention_type_index = int(path[_TRANSFORMER_DECODER_BLOCK_LEN])
         assert weights.shape[0] == config.num_hidden_layers / config.sliding_window_pattern
 
-
         for i, matrix in enumerate(weights):
             layer_idx = config.sliding_window_pattern * i + attention_type_index
             base_path = f"model.layers.{layer_idx}"
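The stacked Orbax arrays group decoder layers by attention type, so layer_idx = sliding_window_pattern * i + attention_type_index interleaves them back into sequential layer indices. A tiny worked example with assumed values (sliding_window_pattern = 5 and 10 layers are illustrative, not taken from this diff):

sliding_window_pattern = 5   # assumed for illustration
num_hidden_layers = 10       # two stacked entries per attention type

# attention_type_index selects the position within each repeating block of layers.
for attention_type_index in range(sliding_window_pattern):
    layer_indices = [
        sliding_window_pattern * i + attention_type_index
        for i in range(num_hidden_layers // sliding_window_pattern)
    ]
    print(attention_type_index, layer_indices)
# attention_type_index 0 -> layers [0, 5], 1 -> [1, 6], ..., 4 -> [4, 9]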
@@ -302,7 +430,6 @@
         if param == "input_embedding":
             converted_paths.append("model.embed_tokens.weight")
             converted_weights.append(weights)
-            # TODO: ryanmullins - support multimodal embedding matrices
         elif param == "per_layer_embeddings":
             converted_paths.append("model.embed_tokens_per_layer.weight")
             converted_weights.append(weights.reshape(
@@ -348,11 +475,31 @@ def update_tree(path: str, weights: np.ndarray, target_dtype: torch.dtype) -> No
         )
 
     for (path, param), value in tree.flatten_with_path(ckpt):
-        if path.startswith(_TRANSFORMER_PARAMETER):
+        if param == "audio_input_embedding_extra":
+            update_tree("embed_audio.embedding.weight", value, config.audio_config.torch_dtype)
+        elif path.endswith("audio_embedding_norm"):
+            update_tree("embed_audio.hard_embedding_norm.weight", value, config.audio_config.torch_dtype)
+        elif path.endswith("audio_input_projection"):
+            update_tree("embed_audio.embedding_projection.weight", value.transpose(), config.audio_config.torch_dtype)
+        elif path.endswith("audio_soft_embedding_norm"):
+            update_tree("embed_audio.soft_embedding_norm.weight", value, config.audio_config.torch_dtype)
+        elif param == "mm_input_embedding_extra":
+            update_tree("embed_vision.embedding.weight", value, config.vision_config.torch_dtype)
+        elif path.endswith("mm_hard_embedding_norm"):
+            update_tree("embed_vision.embedding_norm.weight", value, config.vision_config.torch_dtype)
+        elif path.endswith("mm_input_projection"):
+            update_tree(
+                "embed_vision.embedding_projection.weight", value.transpose(), config.vision_config.torch_dtype
+            )
+        elif path.startswith(_TRANSFORMER_PARAMETER):
             for path, weights in convert_transformer_weights(config.text_config, path, param, value):
-                update_tree(path, weights, config.text_config.torch_dtype)
+                update_tree(f"language_model.{path}", weights, config.text_config.torch_dtype)
+        elif path.startswith(_AUDIO_ENCODER_PARAMETER):
+            for path, weights in convert_audio_encoder_weights(config.audio_config, path, param, value):
+                update_tree(path, weights, config.audio_config.torch_dtype)
+
 
-    hf_tree["lm_head.weight"] = hf_tree["model.embed_tokens.weight"]
+    hf_tree["language_model.lm_head.weight"] = hf_tree["language_model.model.embed_tokens.weight"]
 
     return hf_tree
 
@@ -364,6 +511,7 @@ def main(*args):
     variant = _VARIANT.value
 
     config = _VARIANTS[variant]
+    config.audio_config.torch_dtype = getattr(torch, _AUDIO_DTYPE.value)
     config.text_config.torch_dtype = getattr(torch, _TRANSFORMER_DTYPE.value)
     config.vision_config.torch_dtype = getattr(torch, _VISION_DTYPE.value)
     if _INCLUDE_CHAT_TEMPLATE.value:
@@ -381,7 +529,7 @@
     logging.info("Converted Gemma 3 (%s) state tree from Orbax to Hugging Face.", variant)
 
     with accelerate.init_empty_weights():
-        model = Gemma3p5ForCausalLM(config=config.text_config)
+        model = Gemma3p5ForConditionalGeneration(config=config)
 
     model.load_state_dict(state_tree, assign=True, strict=True)
     logging.info(
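The pattern at the end of main, accelerate.init_empty_weights() followed by load_state_dict(..., assign=True), builds the model on the meta device and then adopts the converted tensors directly instead of copying into pre-allocated parameters. A self-contained toy version of the same pattern, using a plain linear layer rather than the Gemma3p5 model:

import accelerate
import torch
from torch import nn

with accelerate.init_empty_weights():
    # Parameters are created on the "meta" device: no memory is allocated yet.
    model = nn.Linear(4, 4)

# Stand-in for the converted state tree this script produces.
state_tree = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# assign=True swaps the meta tensors for the provided ones rather than copying
# into them (an in-place copy would fail on meta parameters).
model.load_state_dict(state_tree, assign=True, strict=True)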
