6 changes: 4 additions & 2 deletions src/transformers/generation/utils.py
@@ -585,6 +585,7 @@ def prepare_inputs_for_generation(
attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
is_prefill: Optional[bool] = False,
**kwargs,
):
"""
@@ -620,7 +621,7 @@ def prepare_inputs_for_generation(
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
if not self.config.is_encoder_decoder:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
if inputs_embeds is not None and is_prefill:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
else:
@@ -700,6 +701,7 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
position_ids=position_ids,
token_type_ids=token_type_ids,
is_prefill=is_prefill,
)
else:
attention_mask = causal_mask_creation_function(
@@ -3838,7 +3840,7 @@ def _assisted_decoding(
def _prefill(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, model_kwargs):
if generation_config.prefill_chunk_size is None:
model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, is_prefill=True, **model_kwargs)
return self(**model_inputs, return_dict=True)
else: # Chunked prefill
# Even if we are not compiling the forward, flex is always compiled when used. With chunked prefill, we may
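Taken together, the utils.py hunks make the prefill step explicit: `_prefill` now passes `is_prefill=True` into `prepare_inputs_for_generation`, which branches on the flag instead of comparing `cache_position` against `inputs_embeds.shape[1]`. The following is a minimal, self-contained sketch of that flow; the class name and simplified inputs are invented for illustration and are not the actual transformers implementation.

# Illustrative sketch only -- not part of the diff.
from typing import Optional

class TinyGenerationMixin:
    """Made-up stand-in for GenerationMixin, showing how the new flag is threaded through."""

    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        cache_position=None,
        is_prefill: Optional[bool] = False,
        **kwargs,
    ):
        model_inputs = {"cache_position": cache_position, **kwargs}
        # Old heuristic: `inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]`.
        # New behaviour: trust the explicit flag that only the prefill step sets.
        if inputs_embeds is not None and is_prefill:
            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = inputs_embeds
        else:
            model_inputs["input_ids"] = input_ids
            model_inputs["inputs_embeds"] = None
        return model_inputs

    def _prefill(self, input_ids, **model_kwargs):
        # Mirrors the change in `_prefill`: this is the only caller that passes is_prefill=True;
        # every subsequent decoding step leaves it at the default False.
        return self.prepare_inputs_for_generation(input_ids, is_prefill=True, **model_kwargs)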
4 changes: 3 additions & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -1220,6 +1220,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
model_inputs = super().prepare_inputs_for_generation(
@@ -1229,10 +1230,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
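The same pattern repeats across most of the multimodal models below: the override forwards `is_prefill` to the parent call and then gates image inputs on it, so `pixel_values` are only re-attached on the prefill pass (during cached decoding the input ids no longer contain image tokens). A hedged sketch of that override, continuing the `TinyGenerationMixin` toy class above; `TinyVlm` and its arguments are invented for illustration.

# Illustrative sketch only -- not part of the diff.
class TinyVlm(TinyGenerationMixin):
    """Made-up multimodal model following the pattern applied to Aria, AyaVision, Gemma3, etc."""

    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        cache_position=None,
        is_prefill=False,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            cache_position=cache_position,
            is_prefill=is_prefill,
            **kwargs,
        )
        # Previously gated on `cache_position[0] == 0`; the explicit flag now decides whether
        # image inputs are forwarded (prefill) or left out (cached decoding).
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
        return model_inputs

In the real models the gated keys differ (pixel values, image patches, input features, and so on), but the branch has the same shape everywhere.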
4 changes: 3 additions & 1 deletion src/transformers/models/aria/modular_aria.py
@@ -1490,6 +1490,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
model_inputs = super().prepare_inputs_for_generation(
@@ -1499,10 +1500,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
4 changes: 3 additions & 1 deletion src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -494,6 +494,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -505,10 +506,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
8 changes: 5 additions & 3 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1122,6 +1122,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1135,12 +1136,13 @@ def prepare_inputs_for_generation(
cache_position=cache_position,
position_ids=position_ids,
use_cache=use_cache,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] != 0:
# If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if not is_prefill:
# If we're in cached decoding stage, pixel values should be `None` because input ids do not
# contain special image token anymore. Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = None

return model_inputs
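Chameleon (and, further down, Emu3 and Fuyu) use the mirror-image convention: the image inputs are forwarded to `super()` and therefore land in `model_inputs` by default, so the override clears them on every step except prefill. A sketch of that variant, again continuing the illustrative toy classes above; the class name is invented.

# Illustrative sketch only -- not part of the diff.
class TinyChameleonLike(TinyGenerationMixin):
    """Made-up model following the inverse convention used by Chameleon, Emu3 and Fuyu."""

    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        cache_position=None,
        is_prefill=False,
        **kwargs,
    ):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            pixel_values=pixel_values,  # forwarded, so it is present in model_inputs by default
            cache_position=cache_position,
            is_prefill=is_prefill,
            **kwargs,
        )
        # Previously `if cache_position[0] != 0:`; decode steps now clear the image inputs instead.
        if not is_prefill:
            model_inputs["pixel_values"] = None
        return model_inputs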
4 changes: 3 additions & 1 deletion src/transformers/models/clvp/modeling_clvp.py
@@ -1304,6 +1304,7 @@ def prepare_inputs_for_generation(
inputs_embeds=None,
conditioning_embeds=None,
cache_position=None,
is_prefill=False,
**kwargs,
):
# Overwritten: has `conditioning_embeds`-related logic
@@ -1315,9 +1316,10 @@ def prepare_inputs_for_generation(
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
is_prefill=is_prefill,
**kwargs,
)
if conditioning_embeds is not None and cache_position[0] != 0:
if conditioning_embeds is not None and not is_prefill:
model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device)

return model_inputs
@@ -401,6 +401,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -412,10 +413,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
4 changes: 3 additions & 1 deletion src/transformers/models/deepseek_vl/modeling_deepseek_vl.py
@@ -324,6 +324,7 @@ def prepare_inputs_for_generation(
inputs_embeds=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- extra custom processing
@@ -335,12 +336,13 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
if cache_position[0] == 0:
if is_prefill:
model_inputs["pixel_values"] = pixel_values

return model_inputs
@@ -472,6 +472,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
model_inputs = super().prepare_inputs_for_generation(
@@ -481,10 +482,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
@@ -408,6 +408,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
model_inputs = super().prepare_inputs_for_generation(
@@ -417,10 +418,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
4 changes: 3 additions & 1 deletion src/transformers/models/emu3/modeling_emu3.py
@@ -1636,6 +1636,7 @@ def prepare_inputs_for_generation(
position_ids=None,
use_cache=True,
pixel_values=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1649,10 +1650,11 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
pixel_values=pixel_values,
use_cache=use_cache,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] != 0:
if not is_prefill:
model_inputs["pixel_values"] = None

return model_inputs
4 changes: 3 additions & 1 deletion src/transformers/models/emu3/modular_emu3.py
@@ -1190,6 +1190,7 @@ def prepare_inputs_for_generation(
position_ids=None,
use_cache=True,
pixel_values=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1203,10 +1204,11 @@ def prepare_inputs_for_generation(
position_ids=position_ids,
pixel_values=pixel_values,
use_cache=use_cache,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] != 0:
if not is_prefill:
model_inputs["pixel_values"] = None

return model_inputs
4 changes: 3 additions & 1 deletion src/transformers/models/florence2/modeling_florence2.py
@@ -964,6 +964,7 @@ def prepare_inputs_for_generation(
attention_mask=None,
cache_position=None,
logits_to_keep=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -975,10 +976,11 @@ def prepare_inputs_for_generation(
attention_mask=attention_mask,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] == 0:
if is_prefill:
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model
model_inputs["pixel_values"] = pixel_values
4 changes: 3 additions & 1 deletion src/transformers/models/fuyu/modeling_fuyu.py
@@ -382,6 +382,7 @@ def prepare_inputs_for_generation(
image_patches=None,
image_patches_indices=None,
cache_position=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -394,10 +395,11 @@ def prepare_inputs_for_generation(
image_patches=image_patches,
image_patches_indices=image_patches_indices,
cache_position=cache_position,
is_prefill=is_prefill,
**kwargs,
)

if cache_position[0] != 0:
if not is_prefill:
# set image_patches and image_patches_indices to `None` for decoding stage
model_inputs["image_patches_indices"] = None
model_inputs["image_patches"] = None
4 changes: 3 additions & 1 deletion src/transformers/models/gemma3/modeling_gemma3.py
@@ -1223,6 +1223,7 @@ def prepare_inputs_for_generation(
use_cache=True,
logits_to_keep=None,
labels=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -1236,12 +1237,13 @@ def prepare_inputs_for_generation(
use_cache=use_cache,
logits_to_keep=logits_to_keep,
token_type_ids=token_type_ids,
is_prefill=is_prefill,
**kwargs,
)

# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
if is_prefill:
model_inputs["pixel_values"] = pixel_values

return model_inputs
4 changes: 3 additions & 1 deletion src/transformers/models/gemma3/modular_gemma3.py
@@ -1066,6 +1066,7 @@ def prepare_inputs_for_generation(
use_cache=True,
logits_to_keep=None,
labels=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -1079,12 +1080,13 @@ def prepare_inputs_for_generation(
use_cache=use_cache,
logits_to_keep=logits_to_keep,
token_type_ids=token_type_ids,
is_prefill=is_prefill,
**kwargs,
)

# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
# Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
if cache_position[0] == 0:
if is_prefill:
model_inputs["pixel_values"] = pixel_values

return model_inputs
4 changes: 3 additions & 1 deletion src/transformers/models/gemma3n/modeling_gemma3n.py
@@ -2530,6 +2530,7 @@ def prepare_inputs_for_generation(
use_cache=True,
logits_to_keep=None,
labels=None,
is_prefill=False,
**kwargs,
):
# Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -2543,13 +2544,14 @@ def prepare_inputs_for_generation(
use_cache=use_cache,
logits_to_keep=logits_to_keep,
token_type_ids=token_type_ids,
is_prefill=is_prefill,
**kwargs,
)

# If we're in cached decoding stage, multimodal inputs should be None because input ids do not contain special
# tokens anymore. Otherwise multimodal inputs should be passed to model.
# NOTE: use_cache=False always needs pixel_values, input_features, and input_features_mask
if cache_position[0] == 0:
if is_prefill:
model_inputs["pixel_values"] = pixel_values
model_inputs["input_features"] = input_features
model_inputs["input_features_mask"] = input_features_mask