@@ -300,14 +300,15 @@ def _merge_input_ids_with_image_features(
         pad_mask = input_ids == self.pad_token_id

         # expand masks to match embedding dimension
-        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
-        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
+        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
+        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
         # insert padding and text token embeddings
         final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         # insert image embeddings - the image mask is always less or equal to the sentence in length
         final_embedding = final_embedding.masked_scatter(
-            image_mask.unsqueeze(-1).expand_as(final_embedding), scaled_image_features
+            image_mask.unsqueeze(-1).expand_as(final_embedding).to(device=final_embedding.device),
+            scaled_image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
         )
         final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
         if attention_mask is not None:
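Note on this hunk: every mask derived from `input_ids`, and the projected image features themselves, are now moved onto the device (and dtype) of `inputs_embeds` / `final_embedding` before `torch.where` and `masked_scatter`, since both ops require their operands on a single device, e.g. when the model is sharded across GPUs. A minimal standalone sketch of that pattern, with a hypothetical helper name and toy shapes, not the actual PaliGemma method:

import torch

def merge_text_and_image_embeddings(input_ids, inputs_embeds, image_features,
                                    image_token_id, pad_token_id):
    # Boolean masks built from input_ids, which may live on a different device
    # than inputs_embeds under model sharding.
    embed_dim = inputs_embeds.shape[-1]
    text_mask = (input_ids != image_token_id) & (input_ids != pad_token_id)
    image_mask = input_ids == image_token_id
    pad_mask = input_ids == pad_token_id

    final_embedding = torch.zeros_like(inputs_embeds)
    # Expand to (batch, seq_len, embed_dim) and align with the embeddings' device.
    text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)
    pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim).to(inputs_embeds.device)

    final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
    # masked_scatter also needs its mask and source on final_embedding's device,
    # and the source in the same dtype.
    final_embedding = final_embedding.masked_scatter(
        image_mask.unsqueeze(-1).expand_as(final_embedding).to(final_embedding.device),
        image_features.to(device=final_embedding.device, dtype=final_embedding.dtype),
    )
    final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)
    return final_embedding

# Toy usage with made-up token ids: 5 = image token, 0 = pad token.
ids = torch.tensor([[5, 5, 7, 9, 0]])
embeds = torch.randn(1, 5, 4)
img_feats = torch.randn(1, 2, 4)  # one feature vector per image token
merged = merge_text_and_image_embeddings(ids, embeds, img_feats, image_token_id=5, pad_token_id=0)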
@@ -328,10 +329,12 @@ def _merge_input_ids_with_image_features(
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+                causal_mask.device
+            )
             # unmask the prefill
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                token_type_ids[:, None, None, :] == 0, 0
+                token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
             )
             padding_mask = padding_mask == 0
             causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
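The same device alignment is applied to the attention-mask arithmetic: the 2D `attention_mask` and `token_type_ids` are moved next to the 4D additive `causal_mask` before being combined and used in `masked_fill`. A rough sketch of that logic under assumed shapes and semantics (hypothetical helper, not the library code):

import torch

def apply_prefill_and_padding(causal_mask, attention_mask, token_type_ids):
    # causal_mask: additive float mask (batch, 1, q_len, kv_len), 0 where a position
    # may be attended and dtype-min where it is blocked.
    # attention_mask / token_type_ids: (batch, kv_len) 0/1 integer tensors that can
    # sit on another device when the inputs are prepared elsewhere.
    min_value = torch.finfo(causal_mask.dtype).min
    causal_mask = causal_mask.clone()  # contiguous copy for the in-place edits below
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
        causal_mask.device
    )
    # unmask the prefill: positions whose token_type_id is 0 stay fully visible
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
    )
    # re-mask padded key positions (causal entry was 0 and attention_mask was 0)
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_value
    )
    return causal_mask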
@@ -483,7 +486,7 @@ def forward(
                 # we use the input attention mask to shift the logits and labels, because it is 2D.
                 shift_attention_mask = input_attention_mask[..., 1:]
                 shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
-                shift_labels = shift_labels[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
             else:
                 shift_logits = shift_logits.contiguous()
                 shift_labels = shift_labels.contiguous()
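The last fix indexes `shift_labels` with the mask moved to `shift_labels.device` instead of `logits.device`: under sharding, labels and logits can end up on different devices, and boolean indexing should not mix devices. A small sketch of the shift-and-mask step (hypothetical helper name, assumed 2D attention mask):

import torch

def masked_shift_for_loss(logits, labels, attention_mask):
    # Shift so tokens < n predict token n, then drop positions the attention mask
    # marks as padding. The boolean index is moved onto each tensor's own device.
    shift_logits = logits[..., :-1, :]
    shift_labels = labels[..., 1:]
    shift_attention_mask = attention_mask[..., 1:]
    shift_logits = shift_logits[shift_attention_mask.to(shift_logits.device) != 0].contiguous()
    shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
    return shift_logits, shift_labels

From there the cross-entropy loss can be computed on the flattened pairs, with the labels moved onto the logits' device at that point.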