Transformers PaliGemma evaluate and compute_loss fail with tensors/device errors #35990

@BlGene

Description

System Info

My versions are:

Python Version: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 16:05:46) [GCC 13.3.0]
Torch Version: 2.5.1+cu124
CUDA Available: True
CUDA Device Count: 2
GPU Name: NVIDIA GeForce RTX 3090
Transformers Version: 4.48.1
Tokenizers Version: 0.21.0
Accelerate Version: 1.3.0

Who can help?

@ArthurZucker, @amyeroberts, @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm loading the PaliGemma2 model google/paligemma2-3b-pt-224 and trying to fine-tune it with Trainer/Seq2SeqTrainer. As soon as I add evaluation, this fails. After some digging, I found that the failure only occurs when the model is in eval mode:

batch = [valid_dataset[i] for i in range(8)]
inputs = collate_fn(batch)
# generate_ids = model.generate(**inputs, max_length=286+30)

trainer.model.train()       # training mode: the loss is computed without problems
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("works")

trainer.model.train(False)  # eval mode: the same call raises the device error below
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
print("fails.")

I've worked around it by monkey-patching compute_loss_context_manager as follows:

orig_context_manager = trainer.compute_loss_context_manager

class TempTrainContext(object):
    """Wraps the original compute_loss context manager and forces train mode for
    the duration of compute_loss, restoring the previous mode on exit."""
    def __init__(self, trainer):
        self.trainer = trainer
        self.orig_context_manager = trainer.compute_loss_context_manager

    def __enter__(self):
        self.orig_context_inst = self.orig_context_manager()
        self.orig_context_inst.__enter__()
        self.training_enter = self.trainer.model.training  # remember the current mode
        self.trainer.model.train()

    def __exit__(self, type, value, traceback):
        self.trainer.model.train(self.training_enter)      # restore the original mode
        self.orig_context_inst.__exit__(type, value, traceback)

    def __call__(self):
        return self

trainer.compute_loss_context_manager = TempTrainContext(trainer)

(Bonus question: Is this safe to do, or will I train on the test set?)
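
A second workaround I considered but have not fully verified: the trace below ends in HybridCache._static_update, so the mismatch seems to come from the KV cache that Gemma2 allocates in eval mode. Skipping the cache while computing the eval loss might therefore also avoid the crash, e.g. with a small Trainer subclass (sketch, the class name is mine):

class NoCacheEvalTrainer(Seq2SeqTrainer):
    # Hypothetical alternative: drop the KV cache for eval-mode loss computation,
    # assuming the crash comes from the HybridCache being allocated on a single
    # device while the layers are sharded across GPUs.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if not model.training:
            inputs = {**inputs, "use_cache": False}
        return super().compute_loss(
            model, inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )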
Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[13], line 8
      6 print("works")
      7 trainer.model.train(False)
----> 8 trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
      9 print("fails.")
     12 orig_context_manager = trainer.compute_loss_context_manager

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3729         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3730     inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
   3732 # Save past state if it exists
   3733 # TODO: this needs to be fixed and made cleaner later.
   3734 if self.args.past_index >= 0:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
    525     labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
    527 causal_mask = self._update_causal_mask(
    528     attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
    529 )
--> 530 outputs = self.language_model(
    531     attention_mask=causal_mask,
    532     position_ids=position_ids,
    533     past_key_values=past_key_values,
    534     inputs_embeds=inputs_embeds,
    535     use_cache=use_cache,
    536     output_attentions=output_attentions,
    537     output_hidden_states=output_hidden_states,
    538     return_dict=return_dict,
    539     cache_position=cache_position,
    540     num_logits_to_keep=num_logits_to_keep,
    541 )
    543 logits = outputs.logits
    544 loss = None

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
    840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
    843     input_ids=input_ids,
    844     attention_mask=attention_mask,
    845     position_ids=position_ids,
    846     past_key_values=past_key_values,
    847     inputs_embeds=inputs_embeds,
    848     use_cache=use_cache,
    849     output_attentions=output_attentions,
    850     output_hidden_states=output_hidden_states,
    851     return_dict=return_dict,
    852     cache_position=cache_position,
    853 )
    855 hidden_states = outputs[0]
    856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
    617     layer_outputs = self._gradient_checkpointing_func(
    618         decoder_layer.__call__,
    619         hidden_states,
   (...)
    626         cache_position,
    627     )
    628 else:
--> 629     layer_outputs = decoder_layer(
    630         hidden_states,
    631         position_embeddings=position_embeddings,
    632         attention_mask=causal_mask,
    633         position_ids=position_ids,
    634         past_key_value=past_key_values,
    635         output_attentions=output_attentions,
    636         use_cache=use_cache,
    637         cache_position=cache_position,
    638         **flash_attn_kwargs,
    639     )
    641 hidden_states = layer_outputs[0]
    643 if output_attentions:

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    296 hidden_states = self.input_layernorm(hidden_states)
    298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
    300     hidden_states=hidden_states,
    301     position_embeddings=position_embeddings,
    302     attention_mask=attention_mask,
    303     position_ids=position_ids,
    304     past_key_value=past_key_value,
    305     output_attentions=output_attentions,
    306     use_cache=use_cache,
    307     cache_position=cache_position,
    308 )
    309 hidden_states = self.post_attention_layernorm(hidden_states)
    310 hidden_states = residual + hidden_states

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
    221 if past_key_value is not None:
    222     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    223     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    226 attention_interface: Callable = eager_attention_forward
    227 if self.config._attn_implementation != "eager":

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1714 else:
   1715     update_fn = self._static_update
-> 1717 return update_fn(
   1718     cache_position,
   1719     layer_idx,
   1720     key_states,
   1721     value_states,
   1722     k_out,
   1723     v_out,
   1724     k_out.shape[2],
   1725 )

File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
   1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694     k_out[:, :, cache_position] = key_states
   1695     v_out[:, :, cache_position] = value_states
   1697     self.key_cache[layer_idx] = k_out

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

The corresponding error from the evaluation loop is in the bottom half of this gist: https://gist.github.com/BlGene/607c7bee450e03835aa2bf0d2fd2959a
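
In case it helps triage, a quick way to confirm that the model really is split across both GPUs (assuming it was loaded with device_map="auto", in which case accelerate sets hf_device_map) is:

from collections import Counter
# Count how many submodules accelerate placed on each device; with two GPUs
# this is expected to show entries for both cuda:0 and cuda:1.
print(Counter(model.hf_device_map.values()))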

Expected behavior

Training runs with evaluation enabled, i.e. compute_loss also works while the model is in eval mode.
