@@ -1226,6 +1226,10 @@ def forward(self, image_hidden_states, attention_mask):
             more detail.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+            the complete sequence length.
 """
 
 
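A side note on the new `cache_position` docstring entry: the difference from `position_ids` is easiest to see with a left-padded batch. Below is a minimal sketch (illustration only, not part of this diff; the `position_ids` construction mirrors the helper code removed further down):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
# position_ids is derived from the attention mask, so it differs per sample
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# position_ids -> [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]
# cache_position simply indexes slots in the KV cache and ignores padding
cache_position = torch.arange(attention_mask.shape[1])
# cache_position -> [0, 1, 2, 3, 4], identical for every sample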
@@ -1334,6 +1338,7 @@ def forward(
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1443,6 +1448,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
             return_dict=return_dict,
         )
 
@@ -1527,6 +1533,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
         r"""
@@ -1603,6 +1610,7 @@ def forward(
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            cache_position=cache_position,
             return_dict=return_dict,
         )
 
@@ -1659,49 +1667,28 @@ def prepare_inputs_for_generation(
         # Overwritten -- there are mutually exclusive inputs (if the logic to make `image_hidden_states` take
         # precedence is moved to the model, we can remove this fn)
 
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        if past_key_values is not None:
-            if inputs_embeds is not None:  # Exception 1
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:
-                input_ids = input_ids[:, cache_position]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            pixel_values=pixel_values,
+            pixel_attention_mask=pixel_attention_mask,
+            image_hidden_states=image_hidden_states,
+            logits_to_keep=logits_to_keep,
+            **kwargs,
+        )
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        # but IDEFICS requires noth ids and embeds to be present
+        # but IDEFICS requires both ids and embeds to be present
         if inputs_embeds is not None and cache_position[0] == 0:
-            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": input_ids}
-        else:
-            # The clone here is for the same reason as for `position_ids`.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
-
-        if logits_to_keep is not None:
-            model_inputs["logits_to_keep"] = logits_to_keep
+            model_inputs["input_ids"] = input_ids
 
         if image_hidden_states is not None:
-            pixel_values = None
-            pixel_attention_mask = None
-        else:
-            pixel_values = pixel_values
-            pixel_attention_mask = pixel_attention_mask
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "pixel_values": pixel_values,
-                "pixel_attention_mask": pixel_attention_mask,
-                "image_hidden_states": image_hidden_states,
-            }
-        )
+            model_inputs["pixel_values"] = None
+            model_inputs["pixel_attention_mask"] = None
+
         return model_inputs
 
     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
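With the override reduced to the IDEFICS-specific tweaks, the token slicing and `position_ids` bookkeeping deleted above is now delegated to the shared `GenerationMixin.prepare_inputs_for_generation`. Below is a rough sketch of that delegated behaviour, reconstructed from the removed lines; the real base implementation covers more cases (static caches, extra model kwargs, etc.), so treat the names and signature here as illustrative only:

def _base_prepare_inputs_sketch(input_ids, past_key_values, attention_mask, cache_position, **kwargs):
    # keep only the tokens that the cache has not processed yet
    if past_key_values is not None and input_ids.shape[1] != cache_position.shape[0]:
        input_ids = input_ids[:, cache_position]
    position_ids = None
    if attention_mask is not None:
        # create position_ids on the fly for (left-padded) batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values is not None:
            position_ids = position_ids[:, -input_ids.shape[1] :]
    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        "cache_position": cache_position,
        **kwargs,
    }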