
Gemma-3: prepare_inputs_for_generation should forward pixel_values based on image token presence, not cache_position==0 #40910

@Simone999

Description

System Info

  • transformers version: 4.56.1
  • Platform: Windows-11-10.0.26100-SP0
  • Python version: 3.12.11
  • Huggingface_hub version: 0.34.4
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.8.0+cu129 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA GeForce RTX 3060 Laptop GPU

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Prefix caching with models that use Gemma3ForConditionalGeneration is impossible when the non-cached part of the prompt contains an image, because of this code in src/transformers/models/gemma3/modeling_gemma3.py:1175:

# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
    model_inputs["pixel_values"] = pixel_values

Steps to reproduce:

import copy
from typing import Any, Dict, List
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor


model_id = "google/medgemma-4b-it"

model = AutoModelForImageTextToText.from_pretrained(
    model_id, dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)


# Generate prefix cache
system_instruction = "You are a helpful assistant.\n\nDescribe the following image."
conversation = [
    {"role": "system", "content": [{"type": "text", "text": system_instruction}]},
    {"role": "user", "content": ""},
]
tokens = processor.apply_chat_template(
    conversation, add_generation_prompt=False, tokenize=True, return_tensors="pt"
)
# Keep only the tokens before the last end-of-turn token as the reusable text prefix
eot_id = model.config.eos_token_id[-1]
eot_pos = torch.nonzero(tokens == eot_id).max()
initial_prompt = tokens[:, :eot_pos]


with torch.no_grad():
    prompt_cache = model(input_ids=initial_prompt.to(model.device), use_cache=True).past_key_values


# Inference with PREFIX + USER PROMPT with image
def build_conversation(system_instruction: str, url: str) -> List[Dict[str, Any]]:
    """Builds chat messages with a system instruction and a single image."""
    return [
        {"role": "system", "content": [{"type": "text", "text": system_instruction}]},
        {"role": "user", "content": [{"type": "image", "url": url}]},
    ]
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
conversation = build_conversation(system_instruction, image_url)
inputs = processor.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

with torch.inference_mode():
    batch_inputs = inputs.to(model.device, dtype=model.dtype)
    past_key_values = copy.deepcopy(prompt_cache)
    generation = model.generate(
        **batch_inputs,
        past_key_values=past_key_values,
        max_new_tokens=100,
        do_sample=False,
    )

After caching a prefix, when we pass a user turn that includes image tokens at non-zero cache_position, prepare_inputs_for_generation drops pixel_values (it only forwards them when cache_position[0] == 0).
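
For reference, a quick check (a hedged sketch reusing model, initial_prompt, and inputs from the script above, and assuming config.image_token_index is the attribute holding the image placeholder id, as referenced later in this issue) shows that all image tokens of the new prompt land beyond the cached prefix, i.e. at cache positions greater than zero, which is exactly where the current gate stops forwarding pixel_values:

# Sketch: confirm the image placeholder tokens sit beyond the cached prefix.
image_token_id = model.config.image_token_index  # assumed attribute name, see below
prefix_len = initial_prompt.shape[1]             # tokens already stored in prompt_cache

image_positions = (inputs["input_ids"][0] == image_token_id).nonzero().flatten()
print("cached prefix length:", prefix_len)
print("first image token position:", image_positions.min().item())
# generate() resumes with cache_position[0] == prefix_len > 0, so the current gate
# never forwards pixel_values and the vision tower is skipped.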

Expected behavior

  • prepare_inputs_for_generation should forward pixel_values whenever the current input_ids slice contains the model’s image special token(s), regardless of whether cache_position[0] is 0 or > 0. In other words, cache_position[0] == 0 is not a reliable way to tell whether we are in the decoding stage.
  • This enables a common workflow: cache a text-only system prompt once, then process a later user turn that contains an image without re-encoding the system text, while still exercising the vision path.
  • A minimal change would be to replace the strict cache_position[0] == 0 gate with a check on image-token presence (see the sketch after this list):
    • If (input_ids == config.image_token_index).any() → pass pixel_values.
    • Otherwise, omit pixel_values during decode.
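
A minimal sketch of what that gate could look like inside Gemma3's prepare_inputs_for_generation is below. It only illustrates the idea, is not a merged fix, and assumes the attribute name config.image_token_index used above (recent configs also expose image_token_id) and that model_inputs["input_ids"] holds the slice prepared by the base class:

# Proposed gate (sketch): forward pixel_values whenever the current input_ids
# slice still contains image placeholder tokens, instead of cache_position[0] == 0.
current_ids = model_inputs.get("input_ids")
if pixel_values is not None and current_ids is not None and (
    current_ids == self.config.image_token_index
).any():
    # The slice still contains image placeholders, so the vision tower must see them.
    model_inputs["pixel_values"] = pixel_values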

Metadata
Labels

WIP, bug
