Commit b914f6c

squash
1 parent e18f233 commit b914f6c

12 files changed: +277 -450 lines changed

src/transformers/models/chameleon/modeling_chameleon.py
Lines changed: 14 additions & 46 deletions

@@ -41,7 +41,6 @@
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
     is_flash_attn_greater_or_equal_2_10,
-    is_torchdynamo_compiling,
     logging,
     replace_return_docstrings,
 )
@@ -1651,54 +1650,23 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-        # (we can't check exception 3 while compiling)
-        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
-        # generate the first token for each sequence. Later use the generated Input ids for continuation.
-        if past_key_values is not None:
-            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
-                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-            elif (
-                inputs_embeds is not None  # Exception 1
-                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                if inputs_embeds is not None and input_ids.shape[1] == 0:
-                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
-                else:
-                    position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            pixel_values=pixel_values,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            **kwargs,
+        )

-        if cache_position[0] == 0:
+        if cache_position[0] != 0:
             # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
             # Otherwise we need pixel values to be passed to model
-            model_inputs["pixel_values"] = pixel_values
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "cache_position": cache_position,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-            }
-        )
+            model_inputs["pixel_values"] = None
+
         return model_inputs
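All of the overrides in this commit follow the same shape: defer to the generic GenerationMixin.prepare_inputs_for_generation for input slicing, position ids, and cache bookkeeping, then blank out the vision inputs once the prompt (and its image tokens) has already been written to the KV cache. A minimal, self-contained sketch of that pattern follows; the class names here are placeholders, not classes touched by this commit.

class _StubGenerationMixin:
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # Stand-in for transformers.GenerationMixin: the real method also slices
        # input_ids/inputs_embeds, builds position_ids, and packs cache kwargs.
        return {"input_ids": input_ids, **kwargs}


class MyVlmForConditionalGeneration(_StubGenerationMixin):
    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, cache_position=None, **kwargs):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, pixel_values=pixel_values, cache_position=cache_position, **kwargs
        )

        # After prefill (cache_position[0] != 0) the image tokens already live in the
        # KV cache, so the vision inputs must not be forwarded to the model again.
        if cache_position[0] != 0:
            model_inputs["pixel_values"] = None

        return model_inputs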

src/transformers/models/emu3/modeling_emu3.py
Lines changed: 31 additions & 0 deletions

@@ -1967,5 +1967,36 @@ def forward(

         return outputs

+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if cache_position[0] != 0:
+            model_inputs["pixel_values"] = None
+
+        return model_inputs
+

 __all__ = ["Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel", "Emu3PreTrainedModel", "Emu3VQVAE"]

src/transformers/models/emu3/modular_emu3.py
Lines changed: 31 additions & 0 deletions

@@ -1275,6 +1275,37 @@ def forward(

         return outputs

+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        cache_position=None,
+        position_ids=None,
+        use_cache=True,
+        pixel_values=None,
+        **kwargs,
+    ):
+        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if cache_position[0] != 0:
+            model_inputs["pixel_values"] = None
+
+        return model_inputs
+

 __all__ = [
     "Emu3ForConditionalGeneration",

src/transformers/models/fuyu/modeling_fuyu.py
Lines changed: 13 additions & 29 deletions

@@ -345,36 +345,20 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

-        if past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -1:]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        if image_patches_indices is not None:
-            model_inputs["image_patches_indices"] = image_patches_indices
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "image_patches_indices": image_patches_indices if past_key_values is None else None,
-                "image_patches": image_patches if past_key_values is None else None,
-            }
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            image_patches=image_patches,
+            image_patches_indices=image_patches_indices,
+            **kwargs,
         )
+
+        if past_key_values is not None:
+            model_inputs["image_patches_indices"] = None
+            model_inputs["image_patches"] = None
+
         return model_inputs

     @staticmethod

src/transformers/models/idefics/modeling_idefics.py
Lines changed: 20 additions & 50 deletions

@@ -1667,63 +1667,33 @@ def prepare_inputs_for_generation(
     ):
         # Overwritten -- custom processing based on `config.use_resampler`

-        model_inputs = {}
+        images_kwargs = {}
         if image_hidden_states is not None:
             if self.config.use_resampler:
-                model_inputs["perceiver_embeddings"] = image_hidden_states
+                images_kwargs["perceiver_embeddings"] = image_hidden_states
             else:
-                model_inputs["image_encoder_embeddings"] = image_hidden_states
+                images_kwargs["image_encoder_embeddings"] = image_hidden_states
         else:
-            model_inputs["pixel_values"] = pixel_values
-
-        # If we have cache: let's slice `input_ids` or `input embeds` through `cache_position`, to keep only the unprocessed tokens
-        if past_key_values is not None:
-            if inputs_embeds is not None:
-                if input_ids.shape[1] == 0:
-                    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-                else:
-                    input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:
-                input_ids = input_ids[:, cache_position]
-            if image_attention_mask is not None:
-                image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
+            images_kwargs["pixel_values"] = pixel_values
+            images_kwargs["interpolate_pos_encoding"] = kwargs.pop("interpolate_pos_encoding", False)

-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-
-            # If past_key_values are present then slice the postion ids for only only the unprocessed tokens.
-            if past_key_values:
-                if inputs_embeds is not None and input_ids.shape[1] == 0:
-                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
-                else:
-                    position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-        position_ids = position_ids.clone(memory_format=torch.contiguous_format)
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
-            model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
-        else:
-            # The clone here is for the same reason as for `position_ids`.
-            model_inputs.update(
-                {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-            )
-
-        model_inputs.update(
-            {
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "cache_position": cache_position,
-                "position_ids": position_ids,
-                "attention_mask": attention_mask,
-                "image_attention_mask": image_attention_mask,
-                "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
-            }
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            use_cache=use_cache,
+            image_attention_mask=image_attention_mask,
+            **images_kwargs,
+            **kwargs,
         )

+        if image_attention_mask is not None and inputs_embeds is None:
+            seq_length = model_inputs["input_ids"].shape[1]
+            model_inputs["image_attention_mask"] = image_attention_mask[:, -seq_length:]
+
         return model_inputs

     def _update_model_kwargs_for_generation(
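The one piece of model-specific logic IDEFICS keeps after delegating is the trailing slice of image_attention_mask, which keeps the per-token image mask aligned with whatever window of input_ids the base method retained. A toy illustration of that alignment; the shapes are invented for the example, not taken from the diff.

import torch

# Hypothetical shapes: (batch, seq_len, num_images)
image_attention_mask = torch.ones(1, 10, 2)

# During cached decoding the base method typically keeps only the newest token(s).
trimmed_input_ids = torch.tensor([[42]])
seq_length = trimmed_input_ids.shape[1]

# Keep the mask rows for exactly those trailing tokens.
image_attention_mask = image_attention_mask[:, -seq_length:]
print(image_attention_mask.shape)  # torch.Size([1, 1, 2])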

src/transformers/models/idefics2/modeling_idefics2.py
Lines changed: 25 additions & 38 deletions

@@ -1226,6 +1226,10 @@ def forward(self, image_hidden_states, attention_mask):
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
     """


@@ -1334,6 +1338,7 @@ def forward(
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1443,6 +1448,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
             return_dict=return_dict,
         )

@@ -1527,6 +1533,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
         r"""
@@ -1603,6 +1610,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
             return_dict=return_dict,
         )

@@ -1659,49 +1667,28 @@ def prepare_inputs_for_generation(
         # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
         # precedence is moved to the model, we can remove this fn)

-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:
-                input_ids = input_ids[:, cache_position]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_hidden_states=image_hidden_states,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        # but IDEFICS requires noth ids and embeds to be present
+        # but IDEFICS requires both ids and embeds to be present
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
-        else:
-            # The clone here is for the same reason as for `position_ids`.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
-        if logits_to_keep is not None:
-            model_inputs["logits_to_keep"] = logits_to_keep
+            model_inputs["input_ids"] = input_ids

         if image_hidden_states is not None:
-            pixel_values = None
-            pixel_attention_mask = None
-        else:
-            pixel_values = pixel_values
-            pixel_attention_mask = pixel_attention_mask
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_attention_mask": pixel_attention_mask,
-                "image_hidden_states": image_hidden_states,
-            }
-        )
+            model_inputs["pixel_values"] = None
+            model_inputs["pixel_attention_mask"] = None
+
         return model_inputs

     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
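Taken together, only the prefill call of a generate() run sees cache_position starting at 0, so only that call forwards the image tensors; every later decoding step finds the image features already in the KV cache and passes None instead. A rough illustration of how the condition flips across steps; the values are illustrative, not traced from this commit.

import torch

prompt_len = 5
cache_position = torch.arange(prompt_len)      # prefill step: tensor([0, 1, 2, 3, 4])
print(bool(cache_position[0] != 0))            # False -> image inputs are forwarded

cache_position = torch.tensor([prompt_len])    # first cached decoding step: tensor([5])
print(bool(cache_position[0] != 0))            # True  -> image inputs set to None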
