System Info
- `transformers` version: 4.56.1
- Platform: Windows-11-10.0.26100-SP0
- Python version: 3.12.11
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.6.2
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu129 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes
- GPU type: NVIDIA GeForce RTX 3060 Laptop GPU
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Prefix caching with models that use `Gemma3ForConditionalGeneration` is impossible when the non-cached part of the prompt contains an image, because of this code in `src/transformers/models/gemma3/modeling_gemma3.py:1175`:
```python
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
    model_inputs["pixel_values"] = pixel_values
```

Steps to reproduce:
```python
import copy
from typing import Any, Dict, List

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "google/medgemma-4b-it"
model = AutoModelForImageTextToText.from_pretrained(
    model_id, dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

# Generate prefix cache
system_instruction = "You are a helpful assistant.\n\nDescribe the following image."
conversation = [
    {"role": "system", "content": [{"type": "text", "text": system_instruction}]},
    {"role": "user", "content": ""},
]
tokens = processor.apply_chat_template(
    conversation, add_generation_prompt=False, tokenize=True, return_tensors="pt"
)
eot_id = model.config.eos_token_id[-1]
eot_pos = torch.nonzero(tokens == eot_id).max()
initial_prompt = tokens[:, :eot_pos]
with torch.no_grad():
    prompt_cache = model(input_ids=initial_prompt.to(model.device), use_cache=True).past_key_values

# Inference with PREFIX + USER PROMPT with image
def build_conversation(system_instruction: str, url: str) -> List[Dict[str, Any]]:
    """Builds chat messages with a system instruction and a single image."""
    return [
        {"role": "system", "content": [{"type": "text", "text": system_instruction}]},
        {"role": "user", "content": [{"type": "image", "url": url}]},
    ]

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
conversation = build_conversation(system_instruction, image_url)
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
with torch.inference_mode():
    batch_inputs = inputs.to(model.device, dtype=model.dtype)
    past_key_values = copy.deepcopy(prompt_cache)
    generation = model.generate(
        **batch_inputs,
        past_key_values=past_key_values,
        max_new_tokens=100,
        do_sample=False,
    )
```

After caching a prefix, when we pass a user turn that includes image tokens at a non-zero `cache_position`, `prepare_inputs_for_generation` drops `pixel_values` (it only forwards them when `cache_position[0] == 0`).
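For reference, the processor does produce both the image placeholder tokens and `pixel_values` for the new turn; it is only the generation-time gating that drops them. A quick diagnostic sketch (assuming `config.image_token_index` is the attribute holding the placeholder token id, which may be named differently across versions):

```python
# Hypothetical check: the new user turn contains image placeholder tokens,
# and the processor returned pixel_values, yet they are dropped at decode time.
image_token_id = model.config.image_token_index  # assumption: attribute name may vary
print((inputs["input_ids"] == image_token_id).sum().item())  # > 0
print("pixel_values" in inputs)                              # True
```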
Expected behavior
- `prepare_inputs_for_generation` should forward `pixel_values` whenever the current `input_ids` slice contains the model's image special token(s), regardless of whether `cache_position[0]` is 0 or > 0. I.e. `cache_position[0] == 0` is not a reliable way to tell whether we are in the decoding stage.
- This enables a common workflow: cache a text system prompt once, then later process a user turn with an image without re-encoding the system text, while still exercising the vision path.
- A minimal change would be to replace the strict `cache_position[0] == 0` gate with a check on image token presence (see the sketch below):
  - If `(input_ids == config.image_token_index).any()` → pass `pixel_values`.
  - Else, omit `pixel_values` during decode.
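A rough sketch of what that gate could look like inside `prepare_inputs_for_generation` (a minimal sketch, not a tested patch; it assumes `self.config.image_token_index` holds the image placeholder id and that `input_ids` is the slice being fed on this step):

```python
# Forward pixel_values whenever the current input_ids slice contains image
# placeholder tokens, instead of only on the very first forward pass.
if pixel_values is not None and (input_ids == self.config.image_token_index).any():
    model_inputs["pixel_values"] = pixel_values
```

With a gate like this, the first step after a cached text prefix still sees the image placeholder tokens in `input_ids` and therefore receives `pixel_values`, while subsequent single-token decode steps omit them as before.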