
Commit 9ccdc84

molbap authored and ArthurZucker committed
Paligemma- fix devices and dtype assignments (#31008)
* fix devices and dtype assignments
* [run-slow]paligemma
1 parent 12aa316 commit 9ccdc84
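
The commit title is terse, so a bit of context: when a PaliGemma checkpoint is sharded across devices (for example with `device_map="auto"`), the boolean masks built from `input_ids` can end up on a different device than the embeddings they select into, and PyTorch refuses to combine tensors across devices. A minimal sketch of that failure mode and the `.to(...)` alignment the commit applies; `embeds` and `text_mask` are hypothetical stand-ins, not names from the model:

```python
# Illustrative only: the device-mismatch pattern this commit guards against.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

embeds = torch.randn(1, 4, 8, device=device, dtype=torch.float16)  # may live on a GPU
text_mask = torch.tensor([[True, True, False, False]])             # built on CPU from input_ids

# If `embeds` sits on a GPU while the mask stays on CPU, torch.where raises
# "Expected all tensors to be on the same device". Moving the mask first avoids it.
mask = text_mask.unsqueeze(-1).expand(-1, -1, embeds.shape[-1]).to(embeds.device)
merged = torch.where(mask, embeds, torch.zeros_like(embeds))
print(merged.device, merged.dtype)
```

The three hunks below apply the same kind of alignment wherever a mask or feature tensor meets `final_embedding`, the causal mask, or the shifted labels.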

File tree

1 file changed: +9, -6 lines


src/transformers/models/paligemma/modeling_paligemma.py

Lines changed: 9 additions & 6 deletions
@@ -300,14 +300,15 @@ def _merge_input_ids_with_image_features(
         pad_mask = input_ids == self.pad_token_id

         # expand masks to match embedding dimension
-        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
-        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
+        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
+        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
         # insert padding and text token embeddings
         final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         # insert image embeddings - the image mask is always less or equal to the sentence in length
         final_embedding = final_embedding.masked_scatter(
-            image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
+            image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
+            scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         if attention_mask is not None:
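
This first hunk aligns both the expanded masks and the projected image features with `final_embedding` before `torch.where` and `masked_scatter`; `masked_scatter` in particular expects its source to match the target's device and dtype. A small self-contained sketch of that pattern, with toy shapes and made-up values (names reused from the diff for readability):

```python
# Toy-shaped sketch of the masked_scatter alignment; not the model's actual tensors.
import torch

final_embedding = torch.zeros(1, 6, 8, dtype=torch.bfloat16)          # merged text/pad embeddings
image_mask = torch.tensor([[True, True, True, False, False, False]])  # image token positions
scaled_image_features = torch.randn(1, 3, 8, dtype=torch.float32)     # e.g. vision-tower output

# Both the mask and the source are moved to final_embedding's device (and, for the
# source, cast to its dtype) before scattering, mirroring the hunk above.
final_embedding = final_embedding.masked_scatter(
    image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
    scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
)
print(final_embedding.dtype)  # stays bfloat16; the float32 features were cast to match
```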
@@ -328,10 +329,12 @@ def _merge_input_ids_with_image_features(
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             # unmask the prefill
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                token_type_ids[:, None, None, :] == 0, 0
+                token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
             )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
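
The second hunk does the same for the attention-mask arithmetic: the 2D `attention_mask` and the `token_type_ids` are moved to `causal_mask.device` before being added to or compared against it. A rough, self-contained rehearsal of that logic with toy values; the assumption that `token_type_ids == 0` marks the prefix tokens follows the in-diff comment about unmasking the prefill:

```python
# Self-contained rehearsal of the causal-mask math above, with made-up values.
import torch

min_value = torch.finfo(torch.float32).min
causal_mask = torch.full((1, 1, 4, 4), min_value).triu(1)   # standard causal mask
attention_mask = torch.tensor([[1, 1, 1, 0]])               # last position is padding
token_type_ids = torch.tensor([[0, 0, 1, 1]])               # assumption: 0 = prefix tokens

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
    causal_mask.device
)
# unmask the prefill so prefix tokens attend bidirectionally
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask, min_value
)
print(causal_mask[0, 0])
```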
@@ -483,7 +486,7 @@ def forward(
                 # we use the input attention mask to shift the logits and labels, because it is 2D.
                 shift_attention_mask = input_attention_mask[..., 1:]
                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
-                shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
             else:
                 shift_logits = shift_logits.contiguous()
                 shift_labels = shift_labels.contiguous()
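
The last hunk touches the loss preparation in `forward`: `shift_labels` is now indexed with a mask moved to its own device rather than to `logits.device`, since boolean-mask indexing requires the mask and the indexed tensor to share a device (labels and logits can live on different shards). A tiny CPU-only illustration with made-up values:

```python
# Tiny illustration of the shift-and-mask step; logits/labels values are made up.
import torch

logits = torch.randn(1, 3, 5)
labels = torch.tensor([[2, 4, 1]])
input_attention_mask = torch.tensor([[1, 1, 1]])

shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
shift_attention_mask = input_attention_mask[..., 1:]

# Each tensor is indexed with a mask on its *own* device, as in the fixed line.
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
print(shift_logits.shape, shift_labels.shape)  # torch.Size([2, 5]) torch.Size([2])
```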
