libtorch C++, fasterrcnn_resnet50_fpn module.forward() Assert #3349
Comments
Same issue with maskrcnn_resnet50_fpn. Any ideas?
Are you using the debug versions of libtorch and torchvision?
I've tried with both debug and release, but neither works.
Try passing in a list of tensor images (c x h x w) instead of a single tensor that contains a batch of images:
auto imageList = c10::List<torch::Tensor>({imageTensors...});
std::vector<torch::jit::IValue> inputs;
inputs.emplace_back(imageList);
torch::jit::IValue output = module.forward(inputs);
For reference, this is what I use to convert a cv::Mat to a torch tensor:
torch::Tensor createImageTensor(const cv::Mat &image)
{
    // Convert from OpenCV's default BGR channel order to RGB.
    cv::Mat rgbImage;
    cv::cvtColor(image, rgbImage, cv::COLOR_BGR2RGB);
    // Wrap the pixel data (h x w x 3, uint8) without copying it.
    torch::Tensor tensorImage = torch::from_blob(
        rgbImage.data, {rgbImage.rows, rgbImage.cols, 3},
        torch::TensorOptions().dtype(torch::kByte).requires_grad(false));
    // The dtype conversion copies the data, so the result no longer aliases rgbImage.
    tensorImage = tensorImage.to(torch::kFloat);
    tensorImage /= 255.0;
    // h x w x c -> c x h x w
    tensorImage = tensorImage.transpose(0, 1).transpose(0, 2).contiguous();
    return tensorImage;
}
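To make the layout change in the conversion above concrete, here is a minimal sketch in plain C++ with no OpenCV or libtorch dependency. The helper name `hwcToChw` is hypothetical and not part of either library. It only illustrates what scaling by 255 and reordering h x w x c into c x h x w does to the underlying buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an OpenCV or libtorch API): converts an
// interleaved 8-bit h x w x c image into a planar c x h x w float
// buffer with values scaled to [0, 1].
std::vector<float> hwcToChw(const std::vector<std::uint8_t>& hwc,
                            std::size_t height, std::size_t width,
                            std::size_t channels) {
    assert(hwc.size() == height * width * channels);
    std::vector<float> chw(hwc.size());
    for (std::size_t y = 0; y < height; ++y) {
        for (std::size_t x = 0; x < width; ++x) {
            for (std::size_t c = 0; c < channels; ++c) {
                // Interleaved source index: channels are adjacent per pixel.
                const std::size_t src = (y * width + x) * channels + c;
                // Planar destination index: each channel is a full h x w plane.
                const std::size_t dst = (c * height + y) * width + x;
                chw[dst] = static_cast<float>(hwc[src]) / 255.0f;
            }
        }
    }
    return chw;
}
```

With a 1 x 2 RGB image, the two red values end up next to each other at the start of the output, followed by the two green values and then the two blue values.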
Thanks for your answer, this is now my code:
Unfortunately it gives me an assert again. Looking at the call stack where the exception is thrown:
In file "libtorch-win-shared-with-deps-debug-1.7.1+cpu\libtorch\include\torch\csrc\jit\api\module.h" Line 112
input.size() = 0
Can you verify that you can correctly run the tracing test?
Did you modify the source code of the test? It seems like it's trying to load a file called
You shouldn't have to modify the source code. The pt file is generated by the Python file in the tracing directory, so make sure you run that one first.
I've compiled the test using CMake. It runs, the model is loaded correctly, and the forward pass gives no problem. When I use the model traced with the test in Visual Studio, I am back to the original issue: inference does not work.
Can you share the Python code you use to generate the TorchScript file?
This is my code
That looks fine. If you can't correctly run your
I may have accidentally found a solution. I came across a problem similar to yours. The working code:
🐛 Bug
module.forward() triggers a debug assertion:
File: minkernel\crts\ucrt\src\appcrt\heap\debug_heap.cpp
Line: 966
Expression: __acrt_first_block == header
To Reproduce
Loaded scripted model with
Loaded image into tensor with
Both model and tensor seem to be loaded correctly anyway
does not work.
call stack is:
Environment
OS: Microsoft Windows 7 Professional
Language: C++
CMake version: version 3.17.1
Python version: 3.7 (64-bit runtime)
Is CUDA available: N/A
numpy==1.18.5
torch==1.7.1+cpu
torchaudio==0.7.2
torchvision==nightly
Additional context
cc @vfdev-5