RuntimeError: torch_xla/csrc/data_ops.cpp:185 : Check failed: input_sizes.size() <= output_sizes.size() #8346

Open
samuele-bortolato1 opened this issue Oct 31, 2024 · 0 comments

🐛 Bug

I was trying to run a model from Hugging Face on a TPU.
In normal (lazy) mode the script hangs and never returns any output, while in eager mode it raises the error reported below.
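
For context, "normal mode" here means the default lazy tracing mode (eager mode left off), where ops are only traced and nothing executes until the graph is materialized. A minimal sketch of what I mean (illustrative only, not part of the repro):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# In lazy mode, operations are recorded into a graph and only executed when
# the graph is materialized, e.g. by xm.mark_step() or by moving a result to CPU.
x = torch.ones((2, 2), device=device)
y = x @ x            # traced, not executed yet
xm.mark_step()       # compiles and runs the traced graph
print(y.cpu())       # pulling the result to CPU also forces execution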

To Reproduce

# uninstall tensorflow when on kaggle
!pip uninstall tensorflow -y
!pip install -U transformers accelerate

import requests

import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection

import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

torch_xla.experimental.eager_mode(enable=True)

model_id = "IDEA-Research/grounding-dino-tiny"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
# Check for cats and remote controls
text = "a cat. a remote control."

inputs = processor(images=image, text=text, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=0.4,
    text_threshold=0.3,
    target_sizes=[image.size[::-1]]
)
print(results)

Running this code returned the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 14
     12 inputs = processor(images=image, text=text, return_tensors="pt").to(device)
     13 with torch.no_grad():
---> 14     outputs = model(**inputs)
     16 results = processor.post_process_grounded_object_detection(
     17     outputs,
     18     inputs.input_ids,
   (...)
     21     target_sizes=[image.size[::-1]]
     22 )
     23 print(results)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.10/site-packages/transformers/models/grounding_dino/modeling_grounding_dino.py:2582, in GroundingDinoForObjectDetection.forward(self, pixel_values, input_ids, token_type_ids, attention_mask, pixel_mask, encoder_outputs, output_attentions, output_hidden_states, return_dict, labels)
   2579     attention_mask = torch.ones_like(input_ids)
   2581 # First, sent images through Grounding DINO base model to obtain encoder + decoder outputs
-> 2582 outputs = self.model(
   2583     pixel_values=pixel_values,
   2584     input_ids=input_ids,
   2585     token_type_ids=token_type_ids,
   2586     attention_mask=attention_mask,
   2587     pixel_mask=pixel_mask,
   2588     encoder_outputs=encoder_outputs,
   2589     output_attentions=output_attentions,
   2590     output_hidden_states=output_hidden_states,
   2591     return_dict=return_dict,
   2592 )
   2594 idx = 5 + (1 if output_attentions else 0) + (1 if output_hidden_states else 0)
   2595 enc_text_hidden_state = outputs.encoder_last_hidden_state_text if return_dict else outputs[idx]

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.10/site-packages/transformers/models/grounding_dino/modeling_grounding_dino.py:2260, in GroundingDinoModel.forward(self, pixel_values, input_ids, token_type_ids, attention_mask, pixel_mask, encoder_outputs, output_attentions, output_hidden_states, return_dict)
   2255 output_hidden_states = (
   2256     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
   2257 )
   2258 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 2260 text_self_attention_masks, position_ids = generate_masks_with_special_tokens_and_transfer_map(input_ids)
   2262 if attention_mask is None:
   2263     attention_mask = torch.ones_like(input_ids)

File /usr/local/lib/python3.10/site-packages/transformers/models/grounding_dino/modeling_grounding_dino.py:2040, in generate_masks_with_special_tokens_and_transfer_map(input_ids)
   2037 idxs = torch.nonzero(special_tokens_mask)
   2039 # generate attention mask and positional ids
-> 2040 attention_mask = torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat(batch_size, 1, 1)
   2041 position_ids = torch.zeros((batch_size, num_token), device=input_ids.device)
   2042 previous_col = 0

RuntimeError: torch_xla/csrc/data_ops.cpp:185 : Check failed: input_sizes.size() <= output_sizes.size() (2 vs. 1)
*** Begin stack trace ***
	tsl::CurrentStackTrace()
	torch_xla::BuildExpand(xla::XlaOp, absl::lts_20230802::Span<long const>)
	
	torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)
	
	
	torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
	torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
	torch_xla::Expand::Expand(torch::lazy::Value const&, std::vector<long, std::allocator<long> >)
	std::shared_ptr<torch::lazy::Node> torch::lazy::MakeNode<torch_xla::Expand, torch::lazy::Value const&, std::vector<long, std::allocator<long> > >(torch::lazy::Value const&, std::vector<long, std::allocator<long> >&&)
	
	torch_xla::tensor_methods::copy_(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >&)
	torch_xla::XLANativeFunctions::_copy_from(at::Tensor const&, at::Tensor const&, bool)
	torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	
	at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	
	at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	
	at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	
	
	at::_ops::_to_copy::call(at::Tensor const&, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, bool, std::optional<c10::MemoryFormat>)
	at::native::to(at::Tensor const&, c10::ScalarType, bool, bool, std::optional<c10::MemoryFormat>)
	
	at::_ops::to_dtype::call(at::Tensor const&, c10::ScalarType, bool, bool, std::optional<c10::MemoryFormat>)
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	_PyObject_FastCallDictTstate
	_PyObject_Call_Prepend
	
	_PyObject_MakeTpCall
	
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	_PyObject_FastCallDictTstate
	_PyObject_Call_Prepend
	
	_PyObject_Call
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	PyVectorcall_Call
	_PyEval_EvalFrameDefault
	
	
	
	_PyEval_EvalFrameDefault
	
	
	
	
	
	
	
	
	
	
	
	
	
	
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	_PyEval_EvalFrameDefault
	
	PyEval_EvalCode
	
	
	_PyEval_EvalFrameDefault
	
	_PyEval_EvalFrameDefault
	
	
	Py_RunMain
	Py_BytesMain
	
*** End stack trace ***
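
From the stack trace, the failure seems to come from the torch.eye(...).bool() cast in generate_masks_with_special_tokens_and_transfer_map (the .bool() dispatches through XLANativeFunctions::_to_copy -> copy_ -> BuildExpand). A smaller sketch that should exercise the same op pattern on the XLA device (untested in isolation, so it may or may not hit the same check failure on its own):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla.experimental.eager_mode(enable=True)
device = xm.xla_device()

# Hypothetical sizes; in the real model they come from the tokenized prompt.
batch_size, num_token = 1, 8

# Mirrors modeling_grounding_dino.py line 2040 from the traceback above.
attention_mask = (
    torch.eye(num_token, device=device).bool().unsqueeze(0).repeat(batch_size, 1, 1)
)
print(attention_mask.shape)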

Expected behavior

The model shouldn't break, or at least it should fail with a more informative error message.

Environment

Tested on Kaggle with a TPU v3-8.
torch.__version__ == '2.4.0+cu121' (the version preinstalled on the Kaggle image).
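
The version strings can be read with something like the snippet below (only the torch version was recorded; the torch_xla build is whatever ships on the image):

import torch
import torch_xla

print(torch.__version__)      # '2.4.0+cu121' on the preinstalled Kaggle image
print(torch_xla.__version__)  # preinstalled torch_xla build (version not recorded here)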
