
Not found error for torch_sparse::ptr2ind in torchscript #1718

Closed
Nanco-L opened this issue Oct 13, 2020 · 34 comments

@Nanco-L

Nanco-L commented Oct 13, 2020

❓ Questions & Help

I tried to use a PyTorch model with a MessagePassing layer in C++ code.
As described in the pytorch_geometric documentation,
I generated a model with my own MP layer and successfully converted it to TorchScript.

But when executing the C++ code, I get the error below:

Unknown builtin op: torch_sparse::ptr2ind.
Could not find any similar ops to torch_sparse::ptr2ind. This op may not exist or may not be currently supported in TorchScript.
:
  File "/home/sr6/kyuhyun9.lee/env_ML/lib/python3.6/site-packages/torch_sparse/storage.py", line 166
        rowptr = self._rowptr
        if rowptr is not None:
            row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            self._row = row
            return row
Serialized   File "code/__torch__/torch_sparse/storage.py", line 825
      if torch.__isnot__(rowptr, None):
        rowptr13 = unchecked_cast(Tensor, rowptr)
        row15 = ops.torch_sparse.ptr2ind(rowptr13, torch.numel(self._col))
                ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        self._row = row15
        _150, _151 = True, row15
'SparseStorage.row' is being compiled since it was called from 'SparseStorage.__init__'
  File "/home/sr6/kyuhyun9.lee/env_ML/lib/python3.6/site-packages/torch_sparse/storage.py", line 133
        if not is_sorted:
            idx = self._col.new_zeros(self._col.numel() + 1)
            idx[1:] = self._sparse_sizes[1] * self.row() + self._col
                                              ~~~~~~~~ <--- HERE
            if (idx[1:] < idx[:-1]).any():
                perm = idx[1:].argsort()
Serialized   File "code/__torch__/torch_sparse/storage.py", line 267
      idx = torch.new_zeros(self._col, [_29], dtype=None, layout=None, device=None, pin_memory=None)
      _30 = (self._sparse_sizes)[1]
      _31 = torch.add(torch.mul((self).row(), _30), self._col, alpha=1)
                                 ~~~~~~~~~~ <--- HERE
      _32 = torch.slice(idx, 0, 1, 9223372036854775807, 1)
      _33 = torch.copy_(_32, _31, False)
'SparseStorage.__init__' is being compiled since it was called from 'GINLayerJittable_d54f76.__check_input____1'
Serialized   File "code/__torch__/GINLayerJittable_d54f76.py", line 40
      pass
    return the_size
  def __check_input____1(self: __torch__.GINLayerJittable_d54f76.GINLayerJittable_d54f76,
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~...  <--- HERE
    edge_index: __torch__.torch_sparse.tensor.SparseTensor,
    size: Optional[Tuple[int, int]]) -> List[Optional[int]]:

Aborted (core dumped)

Since I have no experience with PyTorch JIT, I cannot find any clue to solve this.
How can I handle this error?

@rusty1s
Member

rusty1s commented Oct 13, 2020

For executing in a C++ environment, please ensure that you also have the C++ extensions installed, see here.

@Nanco-L
Author

Nanco-L commented Oct 14, 2020

Thank you for the comment, @rusty1s.
I built pytorch_sparse in my workspace (the link in your comment is about pytorch_scatter, but I found the ptr2ind function in pytorch_sparse, so I built pytorch_sparse using CMake)
and added the following lines to my CMakeLists.txt file:

find_package(TorchSparse REQUIRED)
target_link_libraries(lmp "${TorchSparse_LIBRARIES}")

Since I'm not familiar with CMake, I'm not sure this is correct.
Is there anything to fix?

@rusty1s
Member

rusty1s commented Oct 14, 2020

That looks good to me. Does that work for you?

@Nanco-L
Author

Nanco-L commented Oct 14, 2020

No, my program cannot find the ptr2ind function even though TorchSparse is found during the CMake configure step.
Is it correct to use target_link_libraries the way I wrote it?

In TorchSparseConfig.cmake, I found the following lines:

set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/include")
set(${PN}_LIBRARY "")

I think "${TorchSparse_LIBRARIES}" variable written in my previous comment need to be here, but it is missing.
So, I tried to change the second line like set(${PN}_LIBRARIES "${PACKAGE_PREFIX_DIR}/lib64")
(the directory that libtorchsparse.so is stored)
But it is still not work.

In addition, I cannot found prt2ind in header files in ${PACKAGE_PREFIX_DIR}/include.
I found the function in csrc/convert.cpp.
But there are no headers to indicate the function (Unlike torchsparse, torchscatter has scatter.h..).
Is there anything wrong?

@rusty1s
Member

rusty1s commented Oct 17, 2020

I added a header file to torch-sparse. Please try again :)

@Nanco-L
Author

Nanco-L commented Oct 19, 2020

Thank you, @rusty1s. I tried to run the C++ code with the new header file,
but I still get the same error. :(

Importing the header file in my C++ code was successful after changing the part below
https://github.com/rusty1s/pytorch_sparse/blob/2bcc13ed3dd2cac1a023601a622307b74efa4db8/csrc/sparse.h#L24

to the following lines:

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> // the line I added
sample_adj(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
           torch::Tensor idx, int64_t num_neighbors, bool replace);

However, my code still cannot use the ptr2ind function.
And according to the error message, there is a possibility that ptr2ind is not supported in TorchScript.
So I want to ask a question:
did you try running C++ code with a torch_geometric message passing network in your environment?
If so, sharing the sample code and build files would be really helpful.

@rusty1s
Member

rusty1s commented Oct 19, 2020

Thanks, that was indeed a typo, it is fixed now :) I also added this line to the CMake file, and hope that it solves your issues.

@Nanco-L
Author

Nanco-L commented Oct 21, 2020

Unfortunately, my C++ code still cannot find ptr2ind after the last update.
Since there are no problems in creating the TorchScript file from the pytorch_geometric model,
I suspect that the torch_sparse code does not currently support C++.
Could you please clarify this? Thank you so much!

@lianxh

lianxh commented Jan 14, 2021

I encountered the same problem. Is there any update to fix it?

@mathieuorhan

I have the same issue, but when loading a TorchScript model from Python code.
I found a workaround: install torch_sparse / torch_scatter and import torch_sparse before loading the model.
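A minimal sketch of that Python-side workaround (the path "model.pt" is just a placeholder): importing the extension packages before torch.jit.load registers their custom ops with TorchScript, so ops such as torch_sparse::ptr2ind can be resolved.

import torch
import torch_scatter  # noqa: F401 -- imported only for its op-registration side effect
import torch_sparse   # noqa: F401

model = torch.jit.load("model.pt")  # placeholder path to the scripted PyG model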

@dariush-salami

I still have the same problem. I built TorchSparse and linked it to my C++ code, but it still produces the same error.

@rusty1s
Member

rusty1s commented Feb 4, 2021

I will look into it. Any advice welcome :)

@dariush-salami

I tried to link it this way to prevent the default library from being linked:
find_package(TorchSparse REQUIRED PATHS pytorch_sparse/build NO_DEFAULT_PATH)
I got this error.
CMake Error at pytorch_sparse/build/TorchSparseConfig.cmake:47 (include):
  include could not find load file:
    /home/researcher/PycharmProjects/RI4DPC/hololens/example-app/pytorch_sparse/build/TorchSparseTargets.cmake
Call Stack (most recent call first):
  CMakeLists.txt:6 (find_package)
-- Found Python3: /usr/include/python3.6m (found version "3.6.9") found components: Development Development.Module Development.Embed
CMake Error at pytorch_sparse/build/TorchSparseConfig.cmake:55 (target_link_libraries):
  Cannot specify link libraries for target "TorchSparse::TorchSparse" which is
  not built by this project.
Call Stack (most recent call first):
  CMakeLists.txt:6 (find_package)
-- Configuring incomplete, errors occurred!
It seems the file below does not exist:
include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake")

@rusty1s
Member

rusty1s commented Feb 16, 2021

I re-checked that, and it works fine for me following the setup described in torchvision. After installing torch-sparse, do the following:

  1. Create main.cpp:
#include <iostream>
#include <torch/torch.h>
#include <torchsparse/sparse.h>

int main() {
  torch::Tensor tensor = torch::tensor({0, 0, 0, 0, 1, 1, 1});
  std::cout << tensor << std::endl;
  std::cout << ind2ptr(tensor, 2) << std::endl;
}
  2. Create CMakeLists.txt:
cmake_minimum_required(VERSION 3.10)
project(hello-world)

find_package(TorchSparse REQUIRED)

add_executable(hello-world main.cpp)

target_compile_features(hello-world PUBLIC cxx_range_for)
target_link_libraries(hello-world TorchSparse::TorchSparse)
set_property(TARGET hello-world PROPERTY CXX_STANDARD 14)

Inside build/, run cmake -DCMAKE_PREFIX_PATH="<PATH_TO_LIBTORCH>" .. followed by cmake --build . Executing ./hello-world should then yield the following:

 0
 0
 0
 0
 1
 1
 1
[ CPULongType{7} ]
 0
 4
 7
[ CPULongType{3} ]

@zehuag

zehuag commented Sep 14, 2021

@rusty1s I encountered the same issue after following the steps you suggested. I used the following example as a test:

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import GATConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GATConv(16, 8, heads=1, dropout=0.6).jittable()

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        return F.log_softmax(x, dim=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
model = torch.jit.script(model)

print(model.graph)
model.save("example_gat.pt")

I tried to play around with the CMakeLists.txt but always get the same error about "Unknown builtin op: torch_sparse::ptr2ind."

@rusty1s
Member

rusty1s commented Sep 15, 2021

This might be related to missing namespaces in torch_sparse/csrc/sparse.h. Can you check if adding them resolves these issues?
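For illustration only, a rough sketch of what "adding namespaces" to csrc/sparse.h could look like; the namespace name and the exact signatures here are assumptions inferred from the usage shown in this thread, not the actual torch-sparse header.

#pragma once

#include <torch/torch.h>

// Hypothetical: wrap the free-function declarations in a namespace so they
// match namespaced definitions; "torch_sparse" is an assumed name.
namespace torch_sparse {
torch::Tensor ind2ptr(torch::Tensor ind, int64_t M);
torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E);
} // namespace torch_sparse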

@zehuag

zehuag commented Sep 16, 2021

Thank you for the prompt response. I was not able to resolve it by adding namespaces. However, I found a solution: compiling the C++ code together with the source files of torch-sparse.
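A rough, untested sketch of that approach (compiling the application together with the torch-sparse sources instead of linking a prebuilt library); the source layout assumed here (a pytorch_sparse checkout with csrc/ and csrc/cpu/, CPU-only build) may need adjusting:

cmake_minimum_required(VERSION 3.10)
project(example-app)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# Assumed layout: the pytorch_sparse repository cloned next to this file.
file(GLOB SPARSE_SOURCES pytorch_sparse/csrc/*.cpp pytorch_sparse/csrc/cpu/*.cpp)

# Compiling the extension sources into the executable means their op
# registrations run at startup, so the JIT can resolve torch_sparse::* ops.
add_executable(example-app example-app.cpp ${SPARSE_SOURCES})
target_include_directories(example-app PRIVATE pytorch_sparse/csrc)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)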

@DBraun

DBraun commented Nov 13, 2021

@zehuag can you share a bit more of what you did? I tried copying and modifying the pytorch_scatter/CMakeLists.txt into my main CMakeLists.txt but then I get errors building in Visual Studio:
Severity Code Description Project File Line Suppression State Error MSB3721 The command ""C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3\bin\nvcc.exe" -gencode=arch=compute_75,code=\"sm_75,compute_75\" --use-local-env -ccbin "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30133\bin\HostX64\x64" -x cu -rdc=true -I"C:\repos\graph-project\src" -IE:\tools\mylibtorch\torch\include -IE:\tools\mylibtorch\torch\include\torch\csrc\api\include -I"C:\Program Files\NVIDIA Corporation\NvToolsExt\\include" -I"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3\include" -I"C:\Program Files\NVIDIA Corporation\NvToolsExt\include" -IC:\Python39\include -I"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3\include" --keep-dir x64\Release -maxrregcount=0 --machine 64 --compile -cudart static -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --Werror cross-execution-space-call --no-host-device-move-forward --expt-relaxed-constexpr --expt-extended-lambda --expt-relaxed-constexpr /MD /W3 /Y- -std=c++17 -Xcompiler="/EHsc -Ob2" -D_WINDOWS -DONNX_NAMESPACE=onnx_c2 -DNDEBUG -DWIN32 -D_USRDLL -D__CUDA_NO_HALF_OPERATORS__ -DWITH_CUDA -DUSE_DISTRIBUTED -DUSE_C10D_GLOO -D"CMAKE_INTDIR=\"Release\"" -DGenGraphTOP_EXPORTS -DNOMINMAX -DNDEBUG -DWIN32 -D_WINDOWS -D_USRDLL -D__CUDA_NO_HALF_OPERATORS__ -DWITH_CUDA -DUSE_DISTRIBUTED -DUSE_C10D_GLOO -D"CMAKE_INTDIR=\"Release\"" -DGenGraphTOP_EXPORTS -D_WINDLL -D_MBCS -Xcompiler "/EHsc /W3 /nologo /O2 /FdGenGraphTOP.dir\Release\vc142.pdb /FS /MD /GR" -o GenGraphTOP.dir\Release\segment_csr_cuda.obj "C:\repos\graph-project\pytorch_scatter\csrc\cuda\segment_csr_cuda.cu"" exited with code 1. GenGraphTOP C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\CUDA 11.3.targets 785

@jesserem

@zehuag or @rusty1s, could you show how you compiled the C++ code together with the source files of TorchSparse? I am running into the same issue, I am quite inexperienced in C++, and I also would not know how to add a namespace.

My C++ code

#include <torch/script.h>
#include <torch/torch.h>
#include <torchscatter/scatter.h>
#include <torchsparse/sparse.h>

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }


  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
} 

My CMakeLists.txt

cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
project(example-app)

find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

find_package(TorchScatter REQUIRED)
find_package(TorchSparse REQUIRED)

add_executable(example-app example-app.cpp)

target_compile_features(example-app PUBLIC cxx_range_for)
target_link_libraries(example-app "${TORCH_LIBRARIES}" TorchScatter::TorchScatter TorchSparse::TorchSparse)
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)

Hopefully, one of you can help me with it.

@rusty1s
Member

rusty1s commented Jan 30, 2022

Can you share some more information about the error you are experiencing?

@jesserem

Of course, @rusty1s. It is the same issue, namely that it cannot find torch_sparse::ptr2ind.

Terminate called after throwing an instance of 'torch::jit::ErrorReport'
  what():  
Unknown builtin op: torch_sparse::ptr2ind.
Could not find any similar ops to torch_sparse::ptr2ind. This op may not exist or may not be currently supported in TorchScript.
:
  File "/home/jesse/anaconda3/envs/dtkc_rl_19/lib/python3.9/site-packages/torch_sparse/storage.py", line 164
        rowptr = self._rowptr
        if rowptr is not None:
            row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            self._row = row
            return row
Serialized   File "code/__torch__/torch_sparse/storage.py", line 804
      if torch.__isnot__(rowptr, None):
        rowptr14 = unchecked_cast(Tensor, rowptr)
        row18 = ops.torch_sparse.ptr2ind(rowptr14, torch.numel(self._col))
                ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        self._row = row18
        _126 = row18
'SparseStorage.row' is being compiled since it was called from 'SparseStorage.__init__'
  File "/home/jesse/anaconda3/envs/dtkc_rl_19/lib/python3.9/site-packages/torch_sparse/storage.py", line 133
        if not is_sorted:
            idx = self._col.new_zeros(self._col.numel() + 1)
            idx[1:] = self.row()
                      ~~~~~~~~ <--- HERE
            idx[1:] *= self._sparse_sizes[1]
            idx[1:] += self._col
Serialized   File "code/__torch__/torch_sparse/storage.py", line 284
      _32 = torch.add(torch.numel(self._col), 1)
      idx = torch.new_zeros(_31, [_32])
      _33 = (self).row()
      ~~~~~~~~~~~~~~~~~ <--- HERE
      _34 = torch.copy_(torch.slice(idx, 0, 1), _33)
      _35 = torch.mul_(torch.slice(idx, 0, 1), (self._sparse_sizes)[1])

Aborted (core dumped)

I used the following versions for each library:

  • PyTorch and LibTorch: 1.9.0 (pre-cxx11 ABI)
  • PyTorch Geometric: 2.0.3
  • PyTorch Scatter: 2.0.7
  • PyTorch Sparse: 0.6.10
  • CUDA: 10.2
  • cuDNN: 8.2.2

I had to downgrade some libraries because of problems with half-precision float support (Link to issue in Torch Scatter). The output from CMake also shows no problems with finding the packages:

Output Cmake

-- The C compiler identification is GNU 7.5.0
-- The CXX compiler identification is GNU 7.5.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Failed
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE  
-- Found CUDA: /usr/local/cuda (found version "10.2") 
-- Caffe2: CUDA detected: 10.2
-- Caffe2: CUDA nvcc is: /usr/local/cuda/bin/nvcc
-- Caffe2: CUDA toolkit directory: /usr/local/cuda
-- Caffe2: Header version is: 10.2
-- Found CUDNN: /usr/local/cuda/lib64/libcudnn.so  
-- Found cuDNN: v8.2.2  (include: /usr/local/cuda/include, library: /usr/local/cuda/lib64/libcudnn.so)
-- /usr/local/cuda/lib64/libnvrtc.so shorthash is 08c4863f
-- Autodetected CUDA architecture(s):  7.5
-- Added CUDA NVCC flags for: -gencode;arch=compute_75,code=sm_75
-- Found Torch: /home/jesse/thesis_files/libtorch/lib/libtorch.so  
-- Found Python3: /home/jesse/anaconda3/envs/dtkc_rl_19/lib/libpython3.9.so (found version "3.9") found components: Development 
-- Configuring done
-- Generating done
-- Build files have been written to: /home/jesse/CLionProjects/pytorchTest/build

Hopefully, this is enough information for you; otherwise, please let me know what kind of information you need and I will add it to the post.

@btglr

btglr commented Feb 24, 2022

Same issue here with a Makefile project instead of CMake.

libc++abi: terminating with uncaught exception of type torch::jit::ErrorReport: 
Unknown builtin op: torch_sparse::ptr2ind.
Could not find any similar ops to torch_sparse::ptr2ind. This op may not exist or may not be currently supported in TorchScript.
:
  File "/opt/homebrew/Caskroom/miniforge/base/envs/work_env/lib/python3.8/site-packages/torch_sparse/storage.py", line 179
        rowptr = self._rowptr
        if rowptr is not None:
            row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
                  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            self._row = row
            return row
Serialized   File "code/__torch__/torch_sparse/storage.py", line 832
        rowptr = unchecked_cast(Tensor, _rowptr)
        _col = self._col
        row = ops.torch_sparse.ptr2ind(rowptr, torch.numel(_col))
              ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        self._row = row
        _110 = row
'SparseStorage.row' is being compiled since it was called from 'SparseStorage.__init__'
  File "/opt/homebrew/Caskroom/miniforge/base/envs/work_env/lib/python3.8/site-packages/torch_sparse/storage.py", line 148
        if not is_sorted:
            idx = self._col.new_zeros(self._col.numel() + 1)
            idx[1:] = self.row()
                      ~~~~~~~~ <--- HERE
            idx[1:] *= self._sparse_sizes[1]
            idx[1:] += self._col
Serialized   File "code/__torch__/torch_sparse/storage.py", line 346
      _39 = [torch.add(torch.numel(_col0), 1)]
      idx = torch.new_zeros(_col, _39)
      _40 = (self).row()
      ~~~~~~~~~~~~~~~~~ <--- HERE
      _41 = torch.copy_(torch.slice(idx, 0, 1), _40)
      _42 = torch.slice(idx, 0, 1)
  • pytorch/libtorch 1.10.0
  • pytorch_sparse compiled from master

@rusty1s
Member

rusty1s commented Feb 25, 2022

Thanks for the report and sorry for the delay in fixing this. @mananshah99 and I will look into this ASAP.

@Keagel

Keagel commented Mar 15, 2022

Do you have an ETA on a possible fix? I'm experiencing the exact same issue and I'm stuck on my work until I can load models made with PyG.

@rusty1s
Member

rusty1s commented Mar 15, 2022

I will try to look into this during the remainder of this week, but any help and assistance here is highly appreciated.

@btglr

btglr commented Mar 18, 2022

After fiddling for a while with my Makefile, I tried with a CMake project and started off using this CMakeLists.txt:

cmake_minimum_required(VERSION 3.20)
project(ptr_test CXX)

set(CMAKE_CXX_STANDARD 14)

find_package(TorchSparse REQUIRED)

add_executable(ptr_test test_ptr.cpp)
target_link_libraries(ptr_test TorchSparse::TorchSparse)

Trying with that file, I realized that the torchsparse include directory was never actually added to my compiler switches. I'm not sure if this is an issue on my side or something to do with the Config and Targets files provided in the cmake folder. Anyway, I had to tell CMake where to find my installation of torchsparse, like so:

cmake_minimum_required(VERSION 3.20)
project(ptr_test CXX)

set(CMAKE_CXX_STANDARD 14)

find_package(TorchSparse REQUIRED)

# The following two lines were required for me
include_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/include)
link_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/lib)

add_executable(ptr_test test_ptr.cpp)
target_link_libraries(ptr_test TorchSparse::TorchSparse)

Which left me with a second error:

libc++abi: terminating with uncaught exception of type torch::jit::ErrorReport: 
Unknown builtin op: torch_scatter::segment_sum_csr.
Could not find any similar ops to torch_scatter::segment_sum_csr. This op may not exist or may not be currently supported in TorchScript.
:
  File "/opt/homebrew/Caskroom/miniforge/base/envs/work_env/lib/python3.8/site-packages/torch_scatter/segment_csr.py", line 8
def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
                    out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Serialized   File "code/__torch__/torch_scatter/segment_csr.py", line 40
    indptr: Tensor,
    out: Optional[Tensor]=None) -> Tensor:
  _11 = ops.torch_scatter.segment_sum_csr(src, indptr, out)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  return _11
def segment_mean_csr(src: Tensor,
'segment_sum_csr' is being compiled since it was called from 'segment_csr'
Serialized   File "code/__torch__/torch_scatter/segment_csr.py", line 10
    out: Optional[Tensor]=None,
    reduce: str="sum") -> Tensor:
  _1 = __torch__.torch_scatter.segment_csr.segment_sum_csr
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  _2 = __torch__.torch_scatter.segment_csr.segment_mean_csr
  _3 = __torch__.torch_scatter.segment_csr.segment_min_csr
'segment_csr' is being compiled since it was called from 'SparseStorage.coalesce'
Serialized   File "code/__torch__/torch_sparse/storage.py", line 463
  def coalesce(self: __torch__.torch_sparse.storage.SparseStorage,
    reduce: str="add") -> __torch__.torch_sparse.storage.SparseStorage:
    _58 = __torch__.torch_scatter.segment_csr.segment_csr
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _col = self._col
    _col3 = self._col

So I did the same thing with TorchScatter:

cmake_minimum_required(VERSION 3.20)
project(ptr_test CXX)

set(CMAKE_CXX_STANDARD 14)

find_package(TorchSparse REQUIRED)
find_package(TorchScatter REQUIRED)

# The following two lines were required for me
include_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/include)
link_directories(/Users/mac-pro/Documents/deps/installs/arm64/libtorch-extensions/lib)

add_executable(ptr_test test_ptr.cpp)
target_link_libraries(ptr_test TorchSparse::TorchSparse)
target_link_libraries(ptr_test TorchScatter::TorchScatter)

Now both errors are gone.

@rusty1s
Member

rusty1s commented Mar 21, 2022

I can indeed reproduce the issue. This torchvision issue (+ fix) seems to be related: pytorch/vision#2915. As far as I can see, the fix involves making use of the TORCH_LIBRARY dispatcher.

Related as well: https://stackoverflow.com/questions/65705160/torch-vision-c-interface-error-unknown-builtin-op-torchvisionnms
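For context, a minimal sketch of how the TORCH_LIBRARY dispatcher registers an op so TorchScript can resolve it by name; the namespace my_ops and this ptr2ind implementation are purely illustrative, not the actual torch-sparse code:

#include <torch/library.h>
#include <torch/torch.h>

// Illustrative CPU kernel: expand a CSR row pointer into per-element row indices.
torch::Tensor ptr2ind_cpu(torch::Tensor ptr, int64_t E) {
  auto row = torch::empty({E}, ptr.options());
  auto ptr_a = ptr.accessor<int64_t, 1>();
  auto row_a = row.accessor<int64_t, 1>();
  for (int64_t i = 0; i < ptr.numel() - 1; i++)
    for (int64_t j = ptr_a[i]; j < ptr_a[i + 1]; j++)
      row_a[j] = i;
  return row;
}

// Declaring the schema and binding the kernel makes "my_ops::ptr2ind" visible
// to the JIT once this translation unit is linked in (and not dropped by the
// linker), which is what the missing-op error above is about.
TORCH_LIBRARY(my_ops, m) {
  m.def("ptr2ind(Tensor ptr, int E) -> Tensor", ptr2ind_cpu);
}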

@rusty1s
Member

rusty1s commented Mar 21, 2022

I finally fixed this issue, both for torch-scatter and torch-sparse:

  • Fixes for torch-scatter: CMake Fixes rusty1s/pytorch_scatter#278
  • Fixes for torch-sparse: CMake fixes rusty1s/pytorch_sparse#212

Finally, I added a fully-working "PyG in C++" example to pytorch_geometric/examples/cpp. This example saves a "jittable" GNN model in Python, and loads and executes it in C++.

Thanks for all the help and sorry that it took me so long to fix :) Hope that all issues are now resolved!

@rusty1s rusty1s closed this as completed Mar 21, 2022
@btglr

btglr commented Mar 22, 2022

Can confirm it's fixed for me. I was able to remove the include_directories and link_directories lines. Thank you!

@Keagel

Keagel commented May 5, 2022

I finally fixed this issue, both for torch-scatter and torch-sparse:

* Fixes for `torch-scatter`: [CMake Fixes rusty1s/pytorch_scatter#278](https://github.com/rusty1s/pytorch_scatter/pull/278)

* Fixes for `torch-sparse`: [CMake fixes rusty1s/pytorch_sparse#212](https://github.com/rusty1s/pytorch_sparse/pull/212)

Finally, I added a fully-working "PyG in C++" example to pytorch_geometric/examples/cpp. This example saves a "jittable" GNN model in Python, and loads and executes it in C++.

Thanks for all the help and sorry that it took me so long to fix :) Hope that all issues are now resolved!

Sorry, I know this issue is closed but for some reason I can't get it to work using a Makefile. CMake works fine but I still get the same error when I try to compile using a Makefile. Do you have a simple working example with a Makefile?

@rusty1s
Member

rusty1s commented May 6, 2022

I don't have a Makefile for you, but if you post yours here, there might be people who can help you out.

@btglr

btglr commented May 6, 2022

I finally fixed this issue, both for torch-scatter and torch-sparse:

* Fixes for `torch-scatter`: [CMake Fixes rusty1s/pytorch_scatter#278](https://github.com/rusty1s/pytorch_scatter/pull/278)

* Fixes for `torch-sparse`: [CMake fixes rusty1s/pytorch_sparse#212](https://github.com/rusty1s/pytorch_sparse/pull/212)

Finally, I added a fully-working "PyG in C++" example to pytorch_geometric/examples/cpp. This example saves a "jittable" GNN model in Python, and loads and executes it in C++.
Thanks for all the help and sorry that it took me so long to fix :) Hope that all issues are now resolved!

Sorry, I know this issue is closed but for some reason I can't get it to work using a Makefile. CMake works fine but I still get the same error when I try to compile using a Makefile. Do you have a simple working example with a Makefile?

If you get the same error, you likely forgot to link against the directory where your torch_scatter and torch_sparse libraries reside.
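In case it helps, a minimal Makefile sketch along those lines (LIBTORCH and EXT are placeholder paths to your libtorch and torch-scatter/torch-sparse installs; library names and flags may need adjusting for your platform, and the recipe line must start with a tab):

LIBTORCH := /path/to/libtorch
EXT      := /path/to/libtorch-extensions

CXX      := g++
CXXFLAGS := -std=c++14 -I$(LIBTORCH)/include -I$(LIBTORCH)/include/torch/csrc/api/include -I$(EXT)/include
# -Wl,--no-as-needed matters on linkers that default to --as-needed: the
# program never references torchsparse/torchscatter symbols directly, so the
# libraries would otherwise be dropped and their op registrations never run.
LDFLAGS  := -L$(LIBTORCH)/lib -L$(EXT)/lib -Wl,--no-as-needed -ltorchsparse -ltorchscatter -ltorch -ltorch_cpu -lc10 -Wl,-rpath,$(LIBTORCH)/lib -Wl,-rpath,$(EXT)/lib

example-app: example-app.cpp
	$(CXX) $(CXXFLAGS) -o $@ $< $(LDFLAGS)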

@bybeye

bybeye commented Feb 8, 2024

I finally fixed this issue, both for torch-scatter and torch-sparse:

Finally, I added a fully-working "PyG in C++" example to pytorch_geometric/examples/cpp. This example saves a "jittable" GNN model in Python, and loads and executes it in C++.

Thanks for all the help and sorry that it took me so long to fix :) Hope that all issues are now resolved!

When I ran the C++ examples, I got the errors below:

 terminate called after throwing an instance of 'torch::jit::ErrorReport'
  what():
Unknown builtin op: pyg::index_sort.
Could not find any similar ops to pyg::index_sort. This op may not exist or may not be currently supported in TorchScript.
:
  File "code/__torch__/pyg_lib/ops.py", line 11
    _1 = (_2, _3)
  else:
    _4, _5 = ops.pyg.index_sort(inputs, max_value)
             ~~~~~~~~~~~~~~~~~~ <--- HERE
    _1 = (_4, _5)
  return _1
'index_sort' is being compiled since it was called from 'index_sort'
  File "/home/edward/.local/lib/python3.9/site-packages/torch_sparse/utils.py", line 21
    if not torch_sparse.typing.WITH_INDEX_SORT:  # pragma: no cover
        return inputs.sort()
    return pyg_lib.ops.index_sort(inputs, max_value)
           ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Serialized   File "code/__torch__/torch_sparse/utils.py", line 3
def index_sort(inputs: Tensor,
    max_value: Optional[int]=None) -> Tuple[Tensor, Tensor]:
  _0 = __torch__.pyg_lib.ops.index_sort(inputs, max_value, )
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  return _0
'index_sort' is being compiled since it was called from 'SparseStorage.__init__'
Serialized   File "code/__torch__/torch_sparse/storage.py", line 25
    is_sorted: bool=False,
    trust_data: bool=False) -> NoneType:
    _0 = __torch__.torch_sparse.utils.index_sort
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _1 = uninitialized(int)
    _2 = uninitialized(Tensor)
'SparseStorage.__init__' is being compiled since it was called from 'is_sparse'
Serialized   File "code/__torch__/torch_geometric/utils/sparse.py", line 6
    _1 = True
  else:
    _2 = isinstance(src, __torch__.torch_sparse.tensor.SparseTensor)
                    ~~~ <--- HERE
    _1 = _2
  return _1
'is_sparse' is being compiled since it was called from 'GINConvJittable_7d0ee2._check_input__0'
Serialized   File "code/__torch__/GINConvJittable_7d0ee2.py", line 31
    edge_index: Tensor,
    size: Optional[Tuple[int, int]]) -> List[Optional[int]]:
    _2 = __torch__.torch_geometric.utils.sparse.is_sparse
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _3 = "Flow direction \"target_to_source\" is invalid for message propagation via `torch_sparse.SparseTensor` or `torch.sparse.Tensor`. If you really want to make use of a reverse message passing flow, pass in the transposed sparse tensor to the message passing module, e.g., `adj_t.t()`."
    _4 = "Expected \'edge_index\' to be of integer type (got \'{}\')"

@rusty1s
Member

rusty1s commented Feb 8, 2024

Please see my reply in #8882.
