Not found error for torch_sparse::ptr2ind in torchscript #1718
Comments
To run the model in a C++ environment, please make sure that you also have the C++ extensions installed; see here.
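The error in this thread comes from how custom ops are found at load time. Extension libraries such as torch_sparse typically register their ops (e.g. torch_sparse::ptr2ind) from static initializers. If the library is never linked or loaded into your process, TorchScript cannot resolve the op name. The self-contained sketch below imitates that mechanism with a hypothetical registry. op_registry, OpRegistrar, find_op and the my_ext namespace are illustrative names, not the real TorchScript API.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for an operator registry (NOT the real TorchScript API).
using OpFn = std::function<int(int)>;

std::map<std::string, OpFn>& op_registry() {
  // Function-local static avoids the static-initialization-order problem.
  static std::map<std::string, OpFn> registry;
  return registry;
}

// Its constructor registers an op, mimicking what an extension library does
// from a static initializer when it is loaded into the process.
struct OpRegistrar {
  OpRegistrar(const std::string& qualified_name, OpFn fn) {
    // Ops are keyed by "namespace::name", e.g. "torch_sparse::ptr2ind".
    assert(qualified_name.find("::") != std::string::npos);
    op_registry()[qualified_name] = std::move(fn);
  }
};

// Looks up a qualified op name and fails the way a missing extension would.
const OpFn& find_op(const std::string& qualified_name) {
  auto it = op_registry().find(qualified_name);
  if (it == op_registry().end()) {
    throw std::runtime_error("Unknown builtin op: " + qualified_name + ".");
  }
  return it->second;
}

// This registration only runs if this translation unit ends up in the binary.
static OpRegistrar register_double("my_ext::double_it",
                                   [](int x) { return 2 * x; });
```

If the translation unit holding register_double were left out of the final binary, find_op("my_ext::double_it") would throw the same kind of "Unknown builtin op" error reported in this thread.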
Thank you for the comment, @rusty1s.
Since I'm not familiar with
That looks good to me. Does that work for you?
No, my program cannot find … In …
I think … In addition, I cannot find ptr2ind in the header files in ${PACKAGE_PREFIX_DIR}/include.
I added a header file to
Thank you, @rusty1s. I tried to run my C++ code with the new header file. Including the header in my C++ code succeeded after changing the part below to the following lines:

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> // the line I added
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
           torch::Tensor idx, int64_t num_neighbors, bool replace);

However, my code still cannot use the ptr2ind function.
Thanks, that was indeed a typo; it is fixed now :) I also added this line to the CMake file and hope that it solves your issues.
Unfortunately, my C++ code cannot find
I encountered the same problem. Is there any update on a fix?
I have the same issue, but when loading a TorchScript model from Python code.
I still have the same problem. I built TorchSparse and linked it to my C++ code, but I still get the same error.
I will look into it. Any advice is welcome :)
I tried to link it this way to prevent the default library from being linked:
I re-checked that, and it works fine for me following the setup described in
Inside
@rusty1s I encountered the same issue after following the steps you suggested. I used the following example as a test:
I tried to play around with the CMakeLists.txt, but I always get the same error: "Unknown builtin op: torch_sparse::ptr2ind".
This might be related to missing namespaces in
Thank you for the prompt response. I was not able to resolve it by adding namespaces. However, I found a solution by compiling the C++ code together with the source files of torch-sparse.
@zehuag, can you share a bit more about what you did? I tried copying and modifying pytorch_scatter/CMakeLists.txt into my main CMakeLists.txt, but then I get errors when building in Visual Studio:
@zehuag or @rusty1s, could you show how you compile the C++ code together with the source files of TorchSparse? I am running into the same issue, I am quite inexperienced in C++, and I would not know how to add a namespace.

My C++ code:

#include <torch/script.h>
#include <torch/torch.h>
#include <torchscatter/scatter.h>
#include <torchsparse/sparse.h>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  } catch (const c10::Error& e) {
    // Print the underlying message, which names the missing op (if any).
    std::cerr << "error loading the model\n" << e.what() << "\n";
    return -1;
  }

  std::cout << "ok\n";
}

My CMakeLists.txt:

Hopefully, one of you can help me with it.
Can you share some more information about the error you are experiencing?
Of course, @rusty1s. It is the same issue: it cannot find torch_sparse::ptr2ind.
I used the following versions for each library:
I had to downgrade some libraries because of problems with half-precision float support (Link to issue in Torch Scatter). The CMake output also shows no problems finding the packages:
CMake output
Hopefully, this is enough information; otherwise, please let me know what information you need so that I can add it to the post.
Same issue here, with a Makefile project instead of CMake.
Thanks for the report, and sorry for the delay in fixing this. @mananshah99 and I will look into this ASAP.
Do you have an ETA for a possible fix? I'm experiencing the exact same issue, and my work is blocked until I can load models made with PyG.
I will try to look into this during the remainder of this week, but any help here is highly appreciated.
After fiddling with my Makefile for a while, I tried a CMake project and started off with this:

cmake_minimum_required(VERSION 3.20)
project(ptr_test CXX)
set(CMAKE_CXX_STANDARD 14)
find_package(TorchSparse REQUIRED)
add_executable(ptr_test test_ptr.cpp)
target_link_libraries(ptr_test TorchSparse::TorchSparse)

Trying with that file, I realized that the torchsparse …

cmake_minimum_required(VERSION 3.20)
project(ptr_test CXX)
set(CMAKE_CXX_STANDARD 14)
find_package(TorchSparse REQUIRED)
# The following two lines were required for me
include_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/include)
link_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/lib)
add_executable(ptr_test test_ptr.cpp)
target_link_libraries(ptr_test TorchSparse::TorchSparse)

This left me with a second error:

libc++abi: terminating with uncaught exception of type torch::jit::ErrorReport:
Unknown builtin op: torch_scatter::segment_sum_csr.
Could not find any similar ops to torch_scatter::segment_sum_csr. This op may not exist or may not be currently supported in TorchScript.
:
File "/opt/homebrew/Caskroom/miniforge/base/envs/work_env/lib/python3.8/site-packages/torch_scatter/segment_csr.py", line 8
def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Serialized File "code/__torch__/torch_scatter/segment_csr.py", line 40
indptr: Tensor,
out: Optional[Tensor]=None) -> Tensor:
_11 = ops.torch_scatter.segment_sum_csr(src, indptr, out)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
return _11
def segment_mean_csr(src: Tensor,
'segment_sum_csr' is being compiled since it was called from 'segment_csr'
Serialized File "code/__torch__/torch_scatter/segment_csr.py", line 10
out: Optional[Tensor]=None,
reduce: str="sum") -> Tensor:
_1 = __torch__.torch_scatter.segment_csr.segment_sum_csr
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_2 = __torch__.torch_scatter.segment_csr.segment_mean_csr
_3 = __torch__.torch_scatter.segment_csr.segment_min_csr
'segment_csr' is being compiled since it was called from 'SparseStorage.coalesce'
Serialized File "code/__torch__/torch_sparse/storage.py", line 463
def coalesce(self: __torch__.torch_sparse.storage.SparseStorage,
reduce: str="add") -> __torch__.torch_sparse.storage.SparseStorage:
_58 = __torch__.torch_scatter.segment_csr.segment_csr
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
_col = self._col
_col3 = self._col

So I did the same thing with TorchScatter:

cmake_minimum_required(VERSION 3.20)
project(ptr_test CXX)
set(CMAKE_CXX_STANDARD 14)
find_package(TorchSparse REQUIRED)
find_package(TorchScatter REQUIRED)
# The following two lines were required for me
include_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/include)
link_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/lib)
add_executable(ptr_test test_ptr.cpp)
target_link_libraries(ptr_test TorchSparse::TorchSparse)
target_link_libraries(ptr_test TorchScatter::TorchScatter)

Now both errors are gone.
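The trace above shows why linking TorchSparse alone was not enough. SparseStorage.coalesce calls segment_csr, which calls torch_scatter::segment_sum_csr, so a model built on torch_sparse also needs the torch_scatter ops. As a rough debugging aid, the hypothetical helpers below compare a list of required op names (for example, copied from the error messages) against the names you know are registered. They then report which library namespace each missing op belongs to. missing_ops and op_namespace are illustrative names, not part of any library.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Returns the ops a model needs that are not registered, preserving the
// order in which they were requested.
std::vector<std::string> missing_ops(const std::vector<std::string>& required,
                                     const std::set<std::string>& registered) {
  std::vector<std::string> missing;
  for (const auto& op : required) {
    if (registered.count(op) == 0) {
      missing.push_back(op);
    }
  }
  return missing;
}

// Extracts the namespace of a qualified op name such as
// "torch_scatter::segment_sum_csr", which hints at the library to link.
// Returns an empty string if the name has no "::" separator.
std::string op_namespace(const std::string& qualified_name) {
  auto pos = qualified_name.find("::");
  return pos == std::string::npos ? "" : qualified_name.substr(0, pos);
}
```

With only torch_sparse linked, a check like this would flag torch_scatter::segment_sum_csr, matching the second error in the trace.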
I can indeed reproduce the issue. This …
Related as well: https://stackoverflow.com/questions/65705160/torch-vision-c-interface-error-unknown-builtin-op-torchvisionnms
I finally fixed this issue, both for
Finally, I added a fully working "PyG in C++" example to … Thanks for all the help, and sorry that it took me so long to fix :) I hope all issues are now resolved!
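The comment does not spell out how the fix works. One common technique for this class of problem is to have the public header reference a symbol exported by the extension library. A linker that drops unreferenced shared libraries (e.g. GNU ld with --as-needed) then has to keep the library, so its static op registrations still run when the program starts. The sketch below shows the shape of that pattern in one file. my_ext, library_version and forced_version are hypothetical names, and this is not necessarily how torch_sparse implemented its fix. The link-time effect only appears when the two marked parts are split between a header and a shared library. It also assumes C++17 inline variables; a header targeting C++14, as in the CMake files above, would need a compiler-specific workaround.

```cpp
#include <cassert>
#include <cstdint>

// ---- Part that would live in the extension's public header ----
namespace my_ext {
// Exported from the shared library; declared here so the header can call it.
std::int64_t library_version() noexcept;

namespace detail {
// C++17 inline variable: every translation unit including the header shares
// this one variable, and its initializer calls into the library, creating
// the symbol reference that keeps the library linked.
inline std::int64_t forced_version = library_version();
}  // namespace detail
}  // namespace my_ext

// ---- Part that would be compiled into the extension's shared library ----
namespace my_ext {
std::int64_t library_version() noexcept { return 1; }
}  // namespace my_ext
```

Client code only needs to include the header. It never has to call library_version itself for the reference to exist.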
Can confirm it's fixed for me. I was able to remove the
Sorry, I know this issue is closed, but for some reason I can't get it to work with a Makefile. CMake works fine, but I still get the same error when I compile using a Makefile. Do you have a simple working example with a Makefile?
I don't have a
If you get the same error, you likely forgot to link against the folder where your torch_scatter and torch_sparse libraries reside.
When I ran the C++ examples, I got the following errors:
Please see my reply in #8882.
❓ Questions & Help
I tried to use a PyTorch model with a MessagePassing layer in C++ code. As described in the pytorch_geometric documentation, I generated a torch model with my own MP layer and successfully converted the model.
However, when executing the C++ code, I get the error below:
Since I have no experience with the PyTorch JIT, I cannot find any clue about how to solve this.
How can I handle this error?