Commit fb8e8de

RyanMullins, SindhuRaghuram97, and pculliton committed:
Audio Integration (#12)
* Initial commit of Gemma 3.5 scaffold
* Fixing param pass through on Gemm3p5RMSNorm
* Adds Einsum layer to Gemma 3.5
* Updating EinsumLayer API
* Undoing erroneous force push
* Reverting RMSNorm to with_scale by default
* Adds LAuReL to Gemma 3.5
* Adds AltUp to Gemma 3.5
* Adding Gemma3p5 overall and text config with vision and audio config placeholders (#3)
  * Adding gemma3p5 text configs
  * Adding audio config placeholders
  * Adding a placeholder for vision configs
  * Updating MobileNetVisionConfig, inheriting TimmWrapperConfig
  * Updating text configs
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Removing altup configs to accept the suggested configs
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Updating altup config
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Addressing review comments and updating text configs
  * Adding a config for activation sparsity
  * Updating configs to pass through options to super class init and adjust some name prefixes
  * Updating laurel and altup with corrected config values
  * Normalizing sub_config initializers
  Co-authored-by: Ryan Mullins <ryanmullins@google.com>
* Updating MLP with activation sparsity (#2); see the sketch after this list
* Updating DecoderBlock for Gemma 3.5 (#3)
* Initial Gemm3p5TextModel (#4). NOTE: This implementation WILL CHANGE in the coming weeks; however, changes will be strictly additive, and this will remain a suitable baseline for downstream implementations to reference.
  * Adding KV Cache Sharing
  * Adds Einsum layer to Gemma 3.5
  * Updating EinsumLayer API
  * Refactored KV cache sharing in attention
  * Adding KVStore for cache sharing
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Update src/transformers/models/gemma3p5/modular_gemma3p5.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Update src/transformers/cache_utils.py
    Co-authored-by: Ryan Mullins <ryanmullins@google.com>
  * Undoing erroneous force push
  * Reverting RMSNorm to with_scale by default
  * Adds LAuReL to Gemma 3.5
  * Updating KV Cache Sharing implementation
  * Updating the q and k norm definitions in the attention module
  * Fixing name error for q, k, v RMS norm to use the right 3p5 module
  * Updating MLP with activation sparsity
  * Updating DecoderBlock for Gemma 3.5
  * Updating KV cache sharing implementation with the use of a cache buffer and refactoring some lines of code
  * Isolating KV cache logic to relevant components
  * Fixing logic error in Gemma3p5Attention.forward
  * Refactoring caching contributions and fixing kv_store initialization
  * Simplifying configs
  * Remove errant self from super init call
  * Bug fix in the Attention module: changing self.head_dim to config.head_dim
  * Bug fixes in the LaurelBlock and RMSNorm super init call
  * Removing redundant code from a merge
  * Adding per_layer_inputs to TextModel
  * Adding preprocess embeddings with AltUp
  * Adds per-layer-to-single output and a host of TODOs
  * Integrating AltUp predict with the model workflow and other minor bug fixes
  * Using nn.Embedding temporarily for text model
  * It goes forward
  * Minor refactor of attention sparsity and RoPE initialization
  * Fixing duplicate rope_scaling param bug when loading from pretrained
  Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
  Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
* Normalizing on altup_num_inputs config option
* Adding audio encoder config
* Adds high-level components for Audio Encoder
* Implement uniform reducer for Audio Encoder
* Adding placeholders for Conformer components in Audio Encoder
* Adding placeholders for SubSampleConvProjection components in Audio Encoder
* Adding SequenceLayer component placeholders
* Implementing Gemma3p5AudioEncoder with nn.Sequential
* Implementing Gemma3p5AudioSubSampleConvProjection with nn.Sequential
* Implementing Conformer model with SequenceLayers
* Use OrderedDict in nn.Sequential initializers
* Implements sl.Residual in Torch with nn.Sequential and OrderedDict
* Adopting a base SequenceLayer class with default forward() method
* Implementing sl.GatedLinearUnit in Torch
* Implementing sl.Swish in Torch
* Implementing sl.ReLU in Torch
* Implementing sl.Scale in Torch
* Removing sl.Dropout after tree-shaking
* Implementing sl.RMSNorm in Torch with fake shape
* Implementing sl.GroupNorm in Torch
* Implementing sl.Conv2d in Torch
* Implementing sl.Dense in Torch
* Removing sl.Delay layers, which act as pass-throughs
* Connecting shapes to configs in initializers
* Removing sl.Emit
* Implementing sl.ExpandDims in Torch
* Adding sl.GradientClipping to Torch
* Implementing sl.DenseShaped in Torch
* Implementing sl.LDPA in Torch
* Removing unused sl.CombinedQKVProj class
* Fixing erroneous type hint
* Implementing sl.DepthwiseConv1D in Torch
* Implementing sl.MaskInvalid in Torch
* Fixes for initialization
* Fixes for saving weights
* Removing einsums per feedback from HF staff
* Removing Sequence Layers idioms from audio encoder
* Fixes for reviewer comments
* Converting sl.Frontend to FeatureExtractor
* Updates for ConditionalGeneration.get_image_features
* Adding a WIP draft of image_processing_gemma3p5.py
* Update src/transformers/models/gemma3p5/modular_gemma3p5.py
  Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
* Modular conversion after GitHub suggested change
* Text + image gives good results
* Fixing image size preset
* Draft of audio data in chat template
* Removing image processing. Using SigLIP instead.
* Audio input going end-to-end
* Fixing dtype issues in audio encoder
* x-lib formatting consistency
* Adding example data
* Save preprocessor_config.json from conversion script
* Instrumentation for debugging
* Additional instrumentation for preprocessing debugging
* Updates to preprocessor, padding; produces correct end-to-end results on sample
* Tackling configuration TODOs
* Start of feature extractor refactor
* Adds NumPy version of USM extractor, removes Torch version and dependencies
* Fixing AltUp.correct coef permute
* Supporting batches of single audio segment inputs
* Docstring updates for config
* In-lining audio feature extraction
* Adjustments to conversion script and smoke test script

Co-authored-by: SindhuRaghuram97 <114270661+SindhuRaghuram97@users.noreply.github.com>
Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
Co-authored-by: pculliton <phillipculliton@gmail.com>
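For orientation, here is a minimal sketch of the activation-sparsity MLP idea referenced in the list above ("Updating MLP with activation sparsity"). This is an illustration rather than the committed implementation: the class name, default sparsity value, and the per-token Gaussian-quantile cutoff are assumptions about how such a layer is typically built.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseGatedMlp(nn.Module):
    """Hypothetical gated MLP with activation sparsity (illustration only)."""

    def __init__(self, hidden_size: int, intermediate_size: int, activation_sparsity: float = 0.95):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # z-score below which ~activation_sparsity of a standard normal falls.
        normal = torch.distributions.normal.Normal(0.0, 1.0)
        self.std_multiplier = normal.icdf(torch.tensor(activation_sparsity))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        # Per-token cutoff: keep roughly the top (1 - activation_sparsity)
        # fraction of gate pre-activations and zero the rest.
        mean = gate.mean(dim=-1, keepdim=True)
        std = gate.std(dim=-1, keepdim=True)
        gate = F.relu(gate - (mean + std * self.std_multiplier.to(gate.dtype)))
        return self.down_proj(F.gelu(gate) * self.up_proj(x))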
1 parent 84ab795 commit fb8e8de

13 files changed (+2070, -1715 lines)

cat.jpeg (88.5 KB)

gemma3n_forward_test.py

Lines changed: 91 additions & 69 deletions
@@ -1,33 +1,94 @@
 import numpy as np
 import torch
-from transformers import (
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoModelForImageTextToText,
-    AutoTokenizer,
-    Gemma3ImageProcessorFast,
-    Gemma3Processor,
-    model_addition_debugger_context,
-)
+from transformers import AutoModelForImageTextToText, AutoProcessor
 
-model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g348_safetensors"
+model_id = "gg-hf-gm/gemma-3n-E4B-it"
 
-image_processor = Gemma3ImageProcessorFast(size={"height": 768, "width": 768})
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-processor = Gemma3Processor(
-    tokenizer=tokenizer,
-    image_processor=image_processor,
-    chat_template=tokenizer.chat_template,
-)
+processor = AutoProcessor.from_pretrained(model_id)
 
 messages = [
-    {
-        "role": "user",
-        "content": [
-            {"type": "image", "image": "/usr/local/google/home/ryanmullins/Downloads/cat.jpeg"},
-            {"type": "text", "text": "Describe this image in detail."}
-        ]
-    }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "text", "text": "What is the capital of France?"}
+    #     ]
+    # }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "image", "image": "cat.jpeg"},
+    #         {"type": "text", "text": "Describe this image in detail."}
+    #     ]
+    # }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "text", "text": "Transcribe the following speech segment in English:"},
+    #         {"type": "audio", "audio": "speech.wav"},
+    #         # Send a text to Mike. I'll be home late tomorrow.
+    #         {"type": "audio", "audio": "speech2.wav"},
+    #     ]
+    # }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "text", "text": "What is the capital of France?"}
+    #     ]
+    # }
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "text", "text": "What is the capital of France?"}
+    #         ]
+    #     }
+    # ],
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "text", "text": "What is the capital of France?"}
+    #         ]
+    #     }
+    # ],
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "image", "image": "cat.jpeg"},
+    #             {"type": "text", "text": "Describe this image in detail."}
+    #         ]
+    #     }
+    # ],
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "image", "image": "cat.jpeg"},
+    #             {"type": "text", "text": "Describe this image in detail."}
+    #         ]
+    #     }
+    # ],
+    [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Transcribe the following speech segment in English:"},
+                {"type": "audio", "audio": "speech.wav"},
+                # Send a text to Mike. I'll be home late tomorrow.
+            ]
+        },
+    ],
+    [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Transcribe the following speech segment in English:"},
+                {"type": "audio", "audio": "speech2.wav"},
+                # pious means to enter through. Their mouth are very tough and even a sharp
+            ]
+        },
+    ]
 ]
 
 inputs = processor.apply_chat_template(
@@ -39,54 +100,15 @@
 )
 input_len = inputs["input_ids"].shape[-1]
 
-print(inputs)
+print(f"{inputs.input_ids.shape=}")
 
-model = AutoModelForImageTextToText.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(model_id).to(dtype=torch.bfloat16)
 inputs = inputs.to(model.device, dtype=torch.bfloat16)
 
 with torch.inference_mode():
     generation = model.generate(**inputs, max_new_tokens=16, do_sample=False)
-    generation = generation[0][input_len:]
-
-decoded = processor.decode(generation, skip_special_tokens=True)
-print(decoded)
-
-# model.to(dtype=torch.bfloat16)
-# input_ids = tokenizer("The capitol of France is ", return_tensors="pt")
-
-# with model_addition_debugger_context(
-#     model=model,
-#     debug_path="/usr/local/google/home/ryanmullins/nano3/g251_debug",
-#     do_prune_layers=False,
-#     use_repr=False,
-# ):
-#     outputs = model.forward(**input_ids)
-
-
-# model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g251_vision_encoder"
-# vision_encoder = AutoModel.from_pretrained(model_id)
-# print(type(vision_encoder))
-# print(vision_encoder.config)
-
-
-# model_id = "/usr/local/google/home/ryanmullins/git/gemma-3p5-audio-encoder"
-# model = Gemma3p5AudioEncoder.from_pretrained(model_id)
-# audio_config = model.config
-
-# batch_size = 1
-# seq_len = 80 # Example input sequence length (make it odd to test padding)
-# pad_len = 40
-
-# rng = np.random.default_rng(seed=42)
-# audio_mel = rng.normal(size=(batch_size, audio_config.input_feat_size, seq_len)).astype(np.float32)
-# audio_mel_mask_np = np.zeros((batch_size, seq_len), dtype=bool)
-# if seq_len >= pad_len: # Ensure pad_len is not out of bounds
-#     audio_mel_mask_np[:, -pad_len:] = True # Pad the end
+    generation = generation[:, input_len:]
+    print(f"{generation=}")
 
-# with model_addition_debugger_context(
-#     model=model,
-#     debug_path="/usr/local/google/home/ryanmullins/nano3/gemma3n_audio_encoder_debug",
-#     do_prune_layers=False,
-#     use_repr=False,
-# ):
-#     outputs = model.forward(torch.from_numpy(audio_mel), torch.from_numpy(audio_mel_mask_np))
+decoded = processor.batch_decode(generation, skip_special_tokens=True)
+print(f"{decoded=}")

speech.wav (134 KB, binary file not shown)

speech2.wav (533 KB, binary file not shown)

src/transformers/models/auto/feature_extraction_auto.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@
         ("dpt", "DPTFeatureExtractor"),
         ("encodec", "EncodecFeatureExtractor"),
         ("flava", "FlavaFeatureExtractor"),
+        ("gemma3p5", "Gemma3p5AudioFeatureExtractor"),
         ("glpn", "GLPNFeatureExtractor"),
         ("granite_speech", "GraniteSpeechFeatureExtractor"),
         ("groupvit", "CLIPFeatureExtractor"),

src/transformers/models/gemma3p5/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,7 +19,9 @@
 
 if TYPE_CHECKING:
     from .configuration_gemma3p5 import *
+    from .feature_extraction_gemm3p5 import *
     from .modeling_gemma3p5 import *
+    from .processing_gemma3p5 import *
 else:
     import sys
 
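These star-imports only execute for type checkers; at runtime the module goes through the `else` branch, which is unchanged and therefore not shown in the diff. In transformers that branch conventionally installs a lazy module, roughly as sketched below (a sketch of the library's usual pattern, not the exact committed lines):

else:
    import sys

    from ...utils import _LazyModule
    from ...utils.import_utils import define_import_structure

    # Replace this package's module object with a lazy proxy that imports
    # submodules (configuration, feature extraction, modeling, processing)
    # only when their attributes are first accessed.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], define_import_structure(__file__), module_spec=__spec__
    )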
