
Commit 8282db5

molbap, probicheaux and ArthurZucker committed
Paligemma causal attention mask (#30967)
* PaliGemma working causal attention
* Formatting
* Style
* Docstrings + remove commented code
* Update docstring for PaliGemma Config
* PaliGemma - add separator ind to model/labels
* Refactor + docstring paligemma processor method
* Style
* return token type ids when tokenizing labels
* use token type ids when building causal mask
* add token type ids to tester
* remove separator from config
* fix style
* don't ignore separator
* add processor documentation
* simplify tokenization
* fix causal mask
* style
* fix label propagation, revert suffix naming
* fix style
* fix labels tokenization
* [run-slow]paligemma
* add eos if suffixes are present
* [run-slow]paligemma
* [run-slow]paligemma
* add misssing tokens to fast version
* Apply suggestions from code review
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix style
* [run-slow]paligemma

---------

Co-authored-by: Peter Robicheaux <peter@roboflow.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent e5b788a commit 8282db5

File tree

3 files changed: +113 additions, -47 deletions

src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 41 additions & 12 deletions
@@ -282,9 +282,14 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+    def _merge_input_ids_with_image_features(
+        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+    ):
         _, _, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
+        dtype, device = inputs_embeds.dtype, inputs_embeds.device
+        min_dtype = torch.finfo(dtype).min
+
         scaled_image_features = image_features / (self.config.hidden_size**0.5)
         final_embedding = torch.zeros(
             batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@@ -305,24 +310,43 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
             image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
+        if attention_mask is not None:
+            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        else:
+            position_ids = None
 
-        final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
-        final_attention_mask_4d = final_attention_mask_4d.float().expand(
-            -1, self.config.text_config.num_key_value_heads, -1, -1
-        )
-
-        # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
-        # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
-        position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        if token_type_ids is not None and labels is not None:
+            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
+            target_length = cache_position[-1] + 1
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                # unmask the prefill
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :] == 0, 0
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
 
-        if labels is not None:
             final_labels = torch.full(
                 (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
             )
             final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
         else:
+            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
+            causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
             final_labels = None
-        return final_embedding, final_attention_mask_4d, final_labels, position_ids
+        return final_embedding, causal_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
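The training branch added in the hunk above implements a prefix-LM style attention pattern: image and prompt ("prefix") tokens attend to each other bidirectionally, suffix (label) tokens only attend causally, and padded keys stay blocked. The following standalone sketch is not part of the diff; the helper name is illustrative, and the 0 = prefix / 1 = suffix convention for token_type_ids is an assumption based on the processor behavior described in the commit message.

# Standalone sketch of the full-on-prefix / causal-on-suffix mask built above.
# Assumes token_type_ids uses 0 for image + prefix tokens and 1 for suffix tokens;
# the helper name is illustrative only.
import torch

def sketch_paligemma_train_mask(attention_mask, token_type_ids, dtype=torch.float32):
    batch_size, seq_len = attention_mask.shape
    min_dtype = torch.finfo(dtype).min
    # start from a plain causal mask: min_dtype above the diagonal, 0 elsewhere
    mask = torch.triu(torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    # open every key column that belongs to the image/prefix block (bidirectional prefix)
    mask = mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)
    # keep padded key columns blocked
    mask = mask.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)
    return mask  # (batch, 1, seq_len, seq_len): 0 = attend, min_dtype = blocked

# three prefix tokens followed by two suffix tokens, no padding
attn = torch.ones(1, 5, dtype=torch.long)
ttids = torch.tensor([[0, 0, 0, 1, 1]])
print(sketch_paligemma_train_mask(attn, ttids)[0, 0])

For this toy batch the first three key columns come out fully open and the last two are strictly causal, which mirrors what the in-place masked_fill calls in the hunk produce.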
@@ -333,6 +357,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -396,8 +421,10 @@ def forward(
                 selected_image_feature = image_outputs.last_hidden_state
                 image_features = self.multi_modal_projector(selected_image_feature)
 
+                if cache_position is None:
+                    cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
                 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                    image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
                 )
 
             else:
@@ -486,6 +513,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         pixel_values=None,
         attention_mask=None,
+        token_type_ids=None,
         **kwargs,
     ):
         past_length = 0
@@ -544,6 +572,7 @@ def prepare_inputs_for_generation(
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
                 "pixel_values": pixel_values,
+                "token_type_ids": token_type_ids,
             }
         )
         return model_inputs
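Taken together with the processor changes in this PR ("return token type ids when tokenizing labels"), the new token_type_ids plumbing is exercised roughly as in the sketch below. This is a hedged illustration, not part of the diff: the checkpoint id, the prompt and suffix strings, and the exact processor outputs are assumptions based on the commit message and the public PaliGemma release.

# Hedged sketch of a training-style forward pass through the new token_type_ids path.
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma-3b-pt-224"  # assumed public checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)

image = Image.new("RGB", (224, 224), color="white")  # placeholder image
inputs = processor(
    text="caption en",              # prefix: attended bidirectionally
    suffix="a blank white image",   # suffix: tokenized as labels, attended causally
    images=image,
    return_tensors="pt",
)
# The processor output is expected to carry input_ids, attention_mask, token_type_ids
# and labels; forwarding them all lets _merge_input_ids_with_image_features build the
# full-on-prefix / causal-on-suffix mask shown above.
outputs = model(**inputs)
print(outputs.loss)

At generation time, prepare_inputs_for_generation now forwards token_type_ids as well; with no labels present, the merge falls back to the attention-mask product in the else branch of the hunk above.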
