🐛 Bug
I was trying to run a model from Hugging Face on a TPU.
In the default lazy (graph-tracing) mode it would hang forever without producing any output, while in eager mode it raised the error below.
To Reproduce
# uninstall tensorflow when on kaggle
!pip uninstall tensorflow -y
!pip install -U transformers accelerate

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
torch_xla.experimental.eager_mode(enable=True)

model_id = "IDEA-Research/grounding-dino-tiny"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
# Check for cats and remote controls
text = "a cat. a remote control."

inputs = processor(images=image, text=text, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.4,
    text_threshold=0.3,
    target_sizes=[image.size[::-1]],
)
print(results)
This returned the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 14
12 inputs = processor(images=image, text=text, return_tensors="pt").to(device)
13 with torch.no_grad():
---> 14 outputs = model(**inputs)
16 results = processor.post_process_grounded_object_detection(
17 outputs,
18 inputs.input_ids,
(...)
21 target_sizes=[image.size[::-1]]
22 )
23 print(results)
File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.10/site-packages/transformers/models/grounding_dino/modeling_grounding_dino.py:2582, in GroundingDinoForObjectDetection.forward(self, pixel_values, input_ids, token_type_ids, attention_mask, pixel_mask, encoder_outputs, output_attentions, output_hidden_states, return_dict, labels)
2579 attention_mask = torch.ones_like(input_ids)
2581 # First, sent images through Grounding DINO base model to obtain encoder + decoder outputs
-> 2582 outputs = self.model(
2583 pixel_values=pixel_values,
2584 input_ids=input_ids,
2585 token_type_ids=token_type_ids,
2586 attention_mask=attention_mask,
2587 pixel_mask=pixel_mask,
2588 encoder_outputs=encoder_outputs,
2589 output_attentions=output_attentions,
2590 output_hidden_states=output_hidden_states,
2591 return_dict=return_dict,
2592 )
2594 idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
2595 enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]
File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File /usr/local/lib/python3.10/site-packages/transformers/models/grounding_dino/modeling_grounding_dino.py:2260, in GroundingDinoModel.forward(self, pixel_values, input_ids, token_type_ids, attention_mask, pixel_mask, encoder_outputs, output_attentions, output_hidden_states, return_dict)
2255 output_hidden_states = (
2256 output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
2257 )
2258 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 2260 text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids)
2262 if attention_mask is None:
2263 attention_mask = torch.ones_like(input_ids)
File /usr/local/lib/python3.10/site-packages/transformers/models/grounding_dino/modeling_grounding_dino.py:2040, in generate_masks_with_special_tokens_and_transfer_map(input_ids)
2037 idxs = torch.nonzero(special_tokens_mask)
2039 # generate attention mask and positional ids
-> 2040 attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(batch_size, 1, 1)
2041 position_ids = torch.zeros((batch_size, num_token), device=input_ids.device)
2042 previous_col = 0
RuntimeError: torch_xla/csrc/data_ops.cpp:185 : Check failed: input_sizes.size() <= output_sizes.size() (2 vs. 1)
*** Begin stack trace ***
tsl::CurrentStackTrace()
torch_xla::BuildExpand(xla::XlaOp, absl::lts_20230802::Span<long const>)
torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)
torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
torch_xla::Expand::Expand(torch::lazy::Value const&, std::vector<long, std::allocator<long> >)
std::shared_ptr<torch::lazy::Node> torch::lazy::MakeNode<torch_xla::Expand, torch::lazy::Value const&, std::vector<long, std::allocator<long> > >(torch::lazy::Value const&, std::vector<long, std::allocator<long> >&&)
torch_xla::tensor_methods::copy_(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >&)
torch_xla::XLANativeFunctions::_copy_from(at::Tensor const&, at::Tensor const&, bool)
torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
at::native::to(at::Tensor const&, c10::ScalarType, bool, bool, std::optional<c10::MemoryFormat>)
at::_ops::to_dtype::call(at::Tensor const&, c10::ScalarType, bool, bool, std::optional<c10::MemoryFormat>)
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_Call_Prepend
_PyObject_Call
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyVectorcall_Call
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyEval_EvalFrameDefault
_PyEval_EvalFrameDefault
Py_RunMain
Py_BytesMain
*** End stack trace ***
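For what it's worth, the failing frame suggests the problem reduces to the `torch.eye(...).bool()` pattern in `generate_masks_with_special_tokens_and_transfer_map`: the stack trace shows the `.bool()` dtype copy dispatching into torch_xla's `Expand`, which is where the size check fails. A minimal sketch that I would expect to hit the same `BuildExpand` path (the shapes below are placeholders, not the real ones the model derives from `input_ids`):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(enable=True)
device = xm.xla_device()

# Placeholder shapes; in the model these come from input_ids.
batch_size, num_token = 1, 12

# Mirrors the line flagged in modeling_grounding_dino.py
# (generate_masks_with_special_tokens_and_transfer_map):
attention_mask = (
    torch.eye(num_token, device=device)
    .bool()
    .unsqueeze(0)
    .repeat(batch_size, 1, 1)
)
print(attention_mask.shape)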
Expected behavior
It shouldn't break, or at least it should give a more informative error.
Environment
Tested on Kaggle with a TPU v3-8
torch.__version__ = '2.4.0+cu121' (the preinstalled one)
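In case it helps, the exact versions I had can be dumped with a snippet like this (assuming torch_xla exposes `__version__`, as recent releases do):

import torch
import torch_xla
import transformers

# Print the library versions relevant to this report.
print("torch:", torch.__version__)
print("torch_xla:", torch_xla.__version__)
print("transformers:", transformers.__version__)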