 import numpy as np
 import torch
-from transformers import (
-    AutoModel,
-    AutoModelForCausalLM,
-    AutoModelForImageTextToText,
-    AutoTokenizer,
-    Gemma3ImageProcessorFast,
-    Gemma3Processor,
-    model_addition_debugger_context,
-)
+from transformers import AutoModelForImageTextToText, AutoProcessor

-model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g348_safetensors"
+model_id = "gg-hf-gm/gemma-3n-E4B-it"

-image_processor = Gemma3ImageProcessorFast(size={"height": 768, "width": 768})
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-processor = Gemma3Processor(
-    tokenizer=tokenizer,
-    image_processor=image_processor,
-    chat_template=tokenizer.chat_template,
-)
+processor = AutoProcessor.from_pretrained(model_id)

 messages = [
-    {
-        "role": "user",
-        "content": [
-            {"type": "image", "image": "/usr/local/google/home/ryanmullins/Downloads/cat.jpeg"},
-            {"type": "text", "text": "Describe this image in detail."}
-        ]
-    }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "text", "text": "What is the capital of France?"}
+    #     ]
+    # }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "image", "image": "cat.jpeg"},
+    #         {"type": "text", "text": "Describe this image in detail."}
+    #     ]
+    # }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "text", "text": "Transcribe the following speech segment in English:"},
+    #         {"type": "audio", "audio": "speech.wav"},
+    #         # Send a text to Mike. I'll be home late tomorrow.
+    #         {"type": "audio", "audio": "speech2.wav"},
+    #     ]
+    # }
+    # {
+    #     "role": "user",
+    #     "content": [
+    #         {"type": "text", "text": "What is the capital of France?"}
+    #     ]
+    # }
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "text", "text": "What is the capital of France?"}
+    #         ]
+    #     }
+    # ],
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "text", "text": "What is the capital of France?"}
+    #         ]
+    #     }
+    # ],
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "image", "image": "cat.jpeg"},
+    #             {"type": "text", "text": "Describe this image in detail."}
+    #         ]
+    #     }
+    # ],
+    # [
+    #     {
+    #         "role": "user",
+    #         "content": [
+    #             {"type": "image", "image": "cat.jpeg"},
+    #             {"type": "text", "text": "Describe this image in detail."}
+    #         ]
+    #     }
+    # ],
+    [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Transcribe the following speech segment in English:"},
+                {"type": "audio", "audio": "speech.wav"},
+                # Send a text to Mike. I'll be home late tomorrow.
+            ]
+        },
+    ],
+    [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Transcribe the following speech segment in English:"},
+                {"type": "audio", "audio": "speech2.wav"},
+                # pious means to enter through. Their mouth are very tough and even a sharp
+            ]
+        },
+    ]
 ]

 inputs = processor.apply_chat_template(
@@ ... @@
 )
 input_len = inputs["input_ids"].shape[-1]

-print(inputs)
+print(f"{inputs.input_ids.shape=}")

-model = AutoModelForImageTextToText.from_pretrained(model_id)
+model = AutoModelForImageTextToText.from_pretrained(model_id).to(dtype=torch.bfloat16)
 inputs = inputs.to(model.device, dtype=torch.bfloat16)

 with torch.inference_mode():
     generation = model.generate(**inputs, max_new_tokens=16, do_sample=False)
-    generation = generation[0][input_len:]
-
-decoded = processor.decode(generation, skip_special_tokens=True)
-print(decoded)
-
-# model.to(dtype=torch.bfloat16)
-# input_ids = tokenizer("The capitol of France is ", return_tensors="pt")
-
-# with model_addition_debugger_context(
-#     model=model,
-#     debug_path="/usr/local/google/home/ryanmullins/nano3/g251_debug",
-#     do_prune_layers=False,
-#     use_repr=False,
-# ):
-#     outputs = model.forward(**input_ids)
-
-
-# model_id = "/usr/local/google/home/ryanmullins/nano3/checkpoints/g251_vision_encoder"
-# vision_encoder = AutoModel.from_pretrained(model_id)
-# print(type(vision_encoder))
-# print(vision_encoder.config)
-
-
-# model_id = "/usr/local/google/home/ryanmullins/git/gemma-3p5-audio-encoder"
-# model = Gemma3p5AudioEncoder.from_pretrained(model_id)
-# audio_config = model.config
-
-# batch_size = 1
-# seq_len = 80  # Example input sequence length (make it odd to test padding)
-# pad_len = 40
-
-# rng = np.random.default_rng(seed=42)
-# audio_mel = rng.normal(size=(batch_size, audio_config.input_feat_size, seq_len)).astype(np.float32)
-# audio_mel_mask_np = np.zeros((batch_size, seq_len), dtype=bool)
-# if seq_len >= pad_len:  # Ensure pad_len is not out of bounds
-#     audio_mel_mask_np[:, -pad_len:] = True  # Pad the end
+    generation = generation[:, input_len:]
+    print(f"{generation=}")

-# with model_addition_debugger_context(
-#     model=model,
-#     debug_path="/usr/local/google/home/ryanmullins/nano3/gemma3n_audio_encoder_debug",
-#     do_prune_layers=False,
-#     use_repr=False,
-# ):
-#     outputs = model.forward(torch.from_numpy(audio_mel), torch.from_numpy(audio_mel_mask_np))
+decoded = processor.batch_decode(generation, skip_special_tokens=True)
+print(f"{decoded=}")
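The keyword arguments passed to apply_chat_template sit in the collapsed region of the diff and are not shown above. For reference only, here is a minimal sketch of a typical call using the standard transformers processor API; the argument names (add_generation_prompt, tokenize, return_dict, return_tensors) are an assumption about the usual usage, not the exact call from this commit.

# Hedged sketch: typical ProcessorMixin.apply_chat_template arguments, not the exact
# call from this commit (those lines are collapsed in the diff above).
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant turn the model should complete
    tokenize=True,               # return token ids and media features instead of a string
    return_dict=True,            # return a dict-like batch so inputs["input_ids"] works below
    return_tensors="pt",         # PyTorch tensors, ready for model.generate
)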
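Because messages now holds multiple conversations, the post-processing keeps the batch dimension: generation[:, input_len:] strips the prompt from every row, and processor.batch_decode returns one string per conversation, whereas the old generation[0][input_len:] with processor.decode only handled a single sample. A minimal decoding sketch, assuming generation has shape [batch, input_len + new_tokens]:

# Minimal sketch (assumes both prompts were tokenized to the same input_len):
new_tokens = generation[:, input_len:]  # shape [batch, new_tokens]; one row per conversation
texts = processor.batch_decode(new_tokens, skip_special_tokens=True)  # one string per row
for i, text in enumerate(texts):
    print(f"sample {i}: {text!r}")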