@@ -282,9 +282,14 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+    def _merge_input_ids_with_image_features(
+        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+    ):
         _, _, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
+        dtype, device = inputs_embeds.dtype, inputs_embeds.device
+        min_dtype = torch.finfo(dtype).min
+
         scaled_image_features = image_features / (self.config.hidden_size**0.5)
         final_embedding = torch.zeros(
             batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@@ -305,24 +310,43 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
             image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
+        if attention_mask is not None:
+            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        else:
+            position_ids = None
 
-        final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
-        final_attention_mask_4d = final_attention_mask_4d.float().expand(
-            -1, self.config.text_config.num_key_value_heads, -1, -1
-        )
-
-        # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
-        # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
-        position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        if token_type_ids is not None and labels is not None:
+            # we are training, so we need a full mask on the image + prefix but a causal mask on the suffix
+            target_length = cache_position[-1] + 1
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                # unmask the prefill
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :] == 0, 0
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
 
-        if labels is not None:
             final_labels = torch.full(
                 (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
             )
             final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
         else:
+            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
+            causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
             final_labels = None
-        return final_embedding, final_attention_mask_4d, final_labels, position_ids
+        return final_embedding, causal_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
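
The training branch above can be hard to picture from the diff alone. Below is a minimal standalone sketch of the mask it produces; the toy sequence length, the split into 4 image/prefix tokens plus 2 suffix tokens, and the all-ones attention mask (so the padding handling is skipped) are assumptions chosen for illustration, not part of this change.

# Illustrative sketch (not part of the diff): the prefix-LM mask built during training.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
seq_len = 6
cache_position = torch.arange(seq_len)
# token_type_ids as the processor is expected to produce them: 0 = image/prefix, 1 = suffix
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1]])

# Same recipe as the new branch: start from a fully masked square, make it causal,
# then unmask every column that belongs to the image/prefix block.
causal_mask = torch.full((seq_len, seq_len), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(seq_len) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].clone()
causal_mask.masked_fill_(token_type_ids[:, None, None, :] == 0, 0)

print((causal_mask[0, 0] == 0).int())
# Every row can see columns 0-3 (image + prefix); columns 4-5 (suffix) are only
# visible causally, i.e. from rows 4 and 5 onward.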
@@ -333,6 +357,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -396,8 +421,10 @@ def forward(
             selected_image_feature = image_outputs.last_hidden_state
             image_features = self.multi_modal_projector(selected_image_feature)
 
+            if cache_position is None:
+                cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
             inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                image_features, inputs_embeds, input_ids, attention_mask, labels
+                image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
             )
 
         else:
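
For context on the new cache_position fallback: on a regular prefill (no cache yet) it simply enumerates the merged sequence, so target_length inside the merge equals the full sequence length. A tiny sketch with assumed toy shapes:

# Illustrative sketch (assumed shapes, not part of the diff)
import torch

inputs_embeds = torch.zeros(2, 10, 2048)  # (batch, seq_len, hidden)
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
print(cache_position)                # 0..9, one position per merged token
print(int(cache_position[-1]) + 1)   # 10 == seq_len, i.e. target_length on prefill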
@@ -486,6 +513,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         pixel_values=None,
         attention_mask=None,
+        token_type_ids=None,
         **kwargs,
     ):
         past_length = 0
@@ -544,6 +572,7 @@ def prepare_inputs_for_generation(
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
                 "pixel_values": pixel_values,
+                "token_type_ids": token_type_ids,
             }
         )
         return model_inputs
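
With token_type_ids threaded through forward and prepare_inputs_for_generation, a training step can exercise the new prefix-LM mask end to end. The sketch below is a hedged usage example: the checkpoint name, the processor's suffix kwarg, and the exact fields it returns are assumptions about the surrounding PaliGemma integration rather than something this diff defines.

# Hedged usage sketch (assumed checkpoint name and processor behaviour)
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"  # assumed checkpoint name
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = processor(
    text="answer en where is the cow standing?",
    images=image,
    suffix="on the beach",  # assumed to make the processor emit token_type_ids (0 = prefix, 1 = suffix)
    return_tensors="pt",
)

labels = inputs.pop("labels", None)
if labels is None:
    labels = inputs["input_ids"].clone()  # crude stand-in; real training would mask the prefix with -100
outputs = model(**inputs, labels=labels)  # labels + token_type_ids take the new training-mask branch
print(float(outputs.loss))

During generation the new prepare_inputs_for_generation entry simply forwards token_type_ids to forward alongside pixel_values and attention_mask.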