@@ -283,9 +283,14 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+    def _merge_input_ids_with_image_features(
+        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+    ):
         _, _, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
+        dtype, device = inputs_embeds.dtype, inputs_embeds.device
+        min_dtype = torch.finfo(dtype).min
+
         scaled_image_features = image_features / (self.config.hidden_size**0.5)
         final_embedding = torch.zeros(
             batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
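Note on the scaling kept as context above: a minimal standalone sketch (illustrative hidden size and shapes, not read from a real config) of dividing the projected image tokens by sqrt(hidden_size), which appears to pre-compensate for the sqrt(hidden_size) multiplier the Gemma backbone applies to its input embeddings.

import torch

hidden_size = 2048                                  # illustrative value only
image_features = torch.randn(2, 256, hidden_size)   # (batch, num_image_tokens, embed_dim)
scaled_image_features = image_features / (hidden_size**0.5)
print(scaled_image_features.shape)                  # torch.Size([2, 256, 2048])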
@@ -306,24 +311,43 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
             image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
+        if attention_mask is not None:
+            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        else:
+            position_ids = None
 
-        final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
-        final_attention_mask_4d = final_attention_mask_4d.float().expand(
-            -1, self.config.text_config.num_key_value_heads, -1, -1
-        )
-
-        # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
-        # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
-        position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        if token_type_ids is not None and labels is not None:
+            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
+            target_length = cache_position[-1] + 1
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                # unmask the prefill
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :] == 0, 0
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
 
-        if labels is not None:
             final_labels = torch.full(
                 (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
             )
             final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
         else:
+            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
+            causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
             final_labels = None
-        return final_embedding, final_attention_mask_4d, final_labels, position_ids
+        return final_embedding, causal_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
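As a rough illustration of what the training branch above produces, here is a standalone sketch (not the model code: no padding, no KV cache offset, and it assumes the convention that token_type_ids == 0 marks image + prefix tokens and 1 marks suffix tokens): prefix positions attend bidirectionally while suffix positions stay causal.

import torch

def prefix_lm_mask(token_type_ids, dtype=torch.float32):
    # token_type_ids: (batch, seq_len); 0 = image/prefix token, 1 = suffix token (assumed convention)
    batch, seq_len = token_type_ids.shape
    min_dtype = torch.finfo(dtype).min
    # start from a standard causal mask: min_dtype above the diagonal, 0 elsewhere
    mask = torch.full((seq_len, seq_len), min_dtype, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    mask = mask[None, None, :, :].expand(batch, 1, -1, -1)
    # unmask every key column that belongs to the prefix, mirroring the masked_fill on token_type_ids == 0
    return mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)

token_type_ids = torch.tensor([[0, 0, 0, 1, 1]])  # 3 image/prefix tokens followed by 2 suffix tokens
print(prefix_lm_mask(token_type_ids)[0, 0])       # rows are queries, columns are keys; 0 means attendable

The real branch additionally offsets the mask by cache_position and re-masks padded key positions with min_dtype; the sketch only shows the prefix/suffix split.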
@@ -334,6 +358,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -397,8 +422,10 @@ def forward(
                 selected_image_feature = image_outputs.last_hidden_state
                 image_features = self.multi_modal_projector(selected_image_feature)
 
+                if cache_position is None:
+                    cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
                 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                    image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
                 )
 
             else:
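For reference, the cache_position fallback added above amounts to one position per embedded token during prefill; a toy check with an illustrative shape:

import torch

inputs_embeds = torch.zeros(1, 7, 2048)  # (batch, seq_len, hidden); illustrative shape only
cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
print(cache_position)                    # tensor([0, 1, 2, 3, 4, 5, 6])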
@@ -487,6 +514,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         pixel_values=None,
         attention_mask=None,
+        token_type_ids=None,
         **kwargs,
     ):
         past_length = 0
@@ -545,6 +573,7 @@ def prepare_inputs_for_generation(
545573 "use_cache" : kwargs .get ("use_cache" ),
546574 "attention_mask" : attention_mask ,
547575 "pixel_values" : pixel_values ,
576+ "token_type_ids" : token_type_ids ,
548577 }
549578 )
550579 return model_inputs
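A hedged end-to-end sketch of how the new token_type_ids plumbing could be exercised. The checkpoint name and image path are placeholders, and it assumes the PaliGemma processor emits token_type_ids and labels when a suffix (target text) is passed; none of that is guaranteed by this diff.

from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"      # illustrative checkpoint name
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

image = Image.open("example.jpg")            # placeholder image path
# With a suffix, the processor output is assumed to include input_ids, attention_mask,
# pixel_values, token_type_ids and labels, so the forward pass above can build the
# full-attention-on-prefix / causal-on-suffix mask and return a training loss.
inputs = processor(text="caption en", images=image, suffix="a cat on a mat", return_tensors="pt")
outputs = model(**inputs)
print(outputs.loss)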