
Commit a25f7d3

Authored by molbap, probicheaux, and ArthurZucker
Paligemma causal attention mask (#30967)
* PaliGemma working causal attention
* Formatting
* Style
* Docstrings + remove commented code
* Update docstring for PaliGemma Config
* PaliGemma - add separator ind to model/labels
* Refactor + docstring paligemma processor method
* Style
* return token type ids when tokenizing labels
* use token type ids when building causal mask
* add token type ids to tester
* remove separator from config
* fix style
* don't ignore separator
* add processor documentation
* simplify tokenization
* fix causal mask
* style
* fix label propagation, revert suffix naming
* fix style
* fix labels tokenization
* [run-slow]paligemma
* add eos if suffixes are present
* [run-slow]paligemma
* [run-slow]paligemma
* add missing tokens to fast version
* Apply suggestions from code review
  Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
* fix style
* [run-slow]paligemma

---------

Co-authored-by: Peter Robicheaux <peter@roboflow.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent d44e1ae commit a25f7d3

File tree

3 files changed: 113 additions & 47 deletions


src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 41 additions & 12 deletions
```diff
@@ -283,9 +283,14 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m
         self.vocab_size = model_embeds.num_embeddings
         return model_embeds
 
-    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+    def _merge_input_ids_with_image_features(
+        self, image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
+    ):
         _, _, embed_dim = image_features.shape
         batch_size, sequence_length = input_ids.shape
+        dtype, device = inputs_embeds.dtype, inputs_embeds.device
+        min_dtype = torch.finfo(dtype).min
+
         scaled_image_features = image_features / (self.config.hidden_size**0.5)
         final_embedding = torch.zeros(
             batch_size, sequence_length, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
```
```diff
@@ -306,24 +311,43 @@ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, in
             image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
+        if attention_mask is not None:
+            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        else:
+            position_ids = None
 
-        final_attention_mask_4d = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
-        final_attention_mask_4d = final_attention_mask_4d.float().expand(
-            -1, self.config.text_config.num_key_value_heads, -1, -1
-        )
-
-        # position_ids = torch.arange(0, sequence_length, device=input_ids.device).expand(batch_size, -1)
-        # position_ids = torch.where(input_ids == self.pad_token_id, torch.ones_like(position_ids), position_ids)
-        position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1)
+        if token_type_ids is not None and labels is not None:
+            # we are training thus we need to create a full mask on the image + prefix but causal on suffix
+            target_length = cache_position[-1] + 1
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                # unmask the prefill
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    token_type_ids[:, None, None, :] == 0, 0
+                )
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
 
-        if labels is not None:
             final_labels = torch.full(
                 (batch_size, sequence_length), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
             )
             final_labels = torch.where(input_ids != self.pad_token_id, labels, final_labels)
         else:
+            causal_mask = attention_mask.unsqueeze(1).unsqueeze(2) * attention_mask.unsqueeze(1).unsqueeze(-1)
+            causal_mask = causal_mask.to(dtype).expand(-1, self.config.text_config.num_key_value_heads, -1, -1)
             final_labels = None
-        return final_embedding, final_attention_mask_4d, final_labels, position_ids
+        return final_embedding, causal_mask, final_labels, position_ids
 
     @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
```
```diff
@@ -334,6 +358,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
```
```diff
@@ -397,8 +422,10 @@ def forward(
                 selected_image_feature = image_outputs.last_hidden_state
                 image_features = self.multi_modal_projector(selected_image_feature)
 
+                if cache_position is None:
+                    cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
                 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                    image_features, inputs_embeds, input_ids, attention_mask, labels
+                    image_features, inputs_embeds, input_ids, attention_mask, labels, token_type_ids, cache_position
                 )
 
             else:
```
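Since this branch runs at prefill time (no KV cache yet), the defaulted `cache_position` is simply `0..seq_len-1`, so `target_length` inside the merge equals the full sequence length. A tiny sketch under that no-cache assumption (`seq_len` here is an arbitrary example value):

```python
import torch

seq_len = 7  # hypothetical prefill length: image tokens + prefix + suffix
cache_position = torch.arange(seq_len)   # default used when no KV cache exists yet
target_length = cache_position[-1] + 1   # == seq_len, so the training mask is (seq_len, seq_len)
print(cache_position.tolist(), int(target_length))
# [0, 1, 2, 3, 4, 5, 6] 7
```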
```diff
@@ -487,6 +514,7 @@ def prepare_inputs_for_generation(
         cache_position=None,
         pixel_values=None,
         attention_mask=None,
+        token_type_ids=None,
         **kwargs,
     ):
         past_length = 0
```
```diff
@@ -545,6 +573,7 @@ def prepare_inputs_for_generation(
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
                 "pixel_values": pixel_values,
+                "token_type_ids": token_type_ids,
             }
         )
         return model_inputs
```
