Building C++ API from source failed #239

Closed
juntingzh opened this issue Sep 21, 2021 · 17 comments

@juntingzh

I have CUDA 11.2, cuDNN 8.2.1, and PyTorch 1.10, and when I try to build pytorch_scatter from source I hit this error:

[  9%] Building CXX object CMakeFiles/torchscatter.dir/csrc/cpu/scatter_cpu.cpp.o
[ 18%] Building CXX object CMakeFiles/torchscatter.dir/csrc/cpu/segment_coo_cpu.cpp.o
[ 27%] Building CXX object CMakeFiles/torchscatter.dir/csrc/cpu/segment_csr_cpu.cpp.o
[ 36%] Building CUDA object CMakeFiles/torchscatter.dir/csrc/cuda/scatter_cuda.cu.o
/home/junting_zhang/projects/pytorch_scatter/csrc/cuda/utils.cuh(12): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

/home/junting_zhang/projects/pytorch_scatter/csrc/cuda/utils.cuh(18): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

2 errors detected in the compilation of "/home/junting_zhang/projects/pytorch_scatter/csrc/cuda/scatter_cuda.cu".
CMakeFiles/torchscatter.dir/build.make:120: recipe for target 'CMakeFiles/torchscatter.dir/csrc/cuda/scatter_cuda.cu.o' failed
make[2]: *** [CMakeFiles/torchscatter.dir/csrc/cuda/scatter_cuda.cu.o] Error 1
CMakeFiles/Makefile2:94: recipe for target 'CMakeFiles/torchscatter.dir/all' failed
make[1]: *** [CMakeFiles/torchscatter.dir/all] Error 2
Makefile:148: recipe for target 'all' failed
make: *** [all] Error 2

Any idea on how to fix?

@rusty1s
Owner

rusty1s commented Sep 21, 2021

Are you using PyTorch nightly? There might be some changes in nightly PyTorch internals that prevent us from building external packages. Those usually get fixed once PyTorch 1.10 is fully released.

@kotatsuyaki

kotatsuyaki commented Oct 2, 2021

This is probably not caused by using PyTorch nightly, though: I'm getting the same error with libtorch 1.9.0+cu111. The error also happens with torch_sparse >= 0.6.11, which was the release that introduced half-precision support.

I'm temporarily reverting to using release 2.0.7, because it builds without this issue.

Build log
Scanning dependencies of target torchscatter
[  9%] Building CXX object CMakeFiles/torchscatter.dir/csrc/cpu/scatter_cpu.cpp.o
[ 18%] Building CXX object CMakeFiles/torchscatter.dir/csrc/cpu/segment_coo_cpu.cpp.o
[ 27%] Building CXX object CMakeFiles/torchscatter.dir/csrc/cpu/segment_csr_cpu.cpp.o
[ 36%] Building CUDA object CMakeFiles/torchscatter.dir/csrc/cuda/segment_coo_cuda.cu.o
[ 45%] Building CUDA object CMakeFiles/torchscatter.dir/csrc/cuda/scatter_cuda.cu.o
[ 54%] Building CUDA object CMakeFiles/torchscatter.dir/csrc/cuda/segment_csr_cuda.cu.o
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
/build/source/csrc/cuda/utils.cuh(12): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

/build/source/csrc/cuda/utils.cuh(12): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

/build/source/csrc/cuda/utils.cuh(18): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

/build/source/csrc/cuda/utils.cuh(18): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

/build/source/csrc/cuda/utils.cuh(12): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

/build/source/csrc/cuda/utils.cuh(18): error: more than one user-defined conversion from "const c10::Half" to "__half" applies:
            function "__half::__half(float)"
            function "__half::__half(__half &&)"

2 errors detected in the compilation of "/build/source/csrc/cuda/scatter_cuda.cu".
2 errors detected in the compilation of "/build/source/csrc/cuda/segment_csr_cuda.cu".

@rusty1s
Owner

rusty1s commented Oct 4, 2021

Thanks. Which GPU are you using? Can you also check if setting the following environment variable helps to resolve this?

export TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6"

@kotatsuyaki

kotatsuyaki commented Oct 4, 2021

It's an RTX 3070. I'll try the environment variable out and report back.

@akshitgandhi

I am also facing the same issue. Did anyone figure out a solution to this problem? I also tried the latest pytorch_scatter release (2.0.9) and libtorch 1.10 with CUDA 11.1 and GCC 10, but it still fails with the same error.

@alexnikulkov

Are there any updates here? Do we think the half-precision operations were implemented robustly, or would it make sense to back them out if they are the source of the error?

@rusty1s
Owner

rusty1s commented Nov 19, 2021

I'm still not sure what's causing this TBH, and half-precision support should work fine when using our provided wheels. One option would be to add a USE_HALF flag for users who install from source and encounter this error.

@alepack89

alepack89 commented Mar 24, 2022

Hi

I don't know if this helps, but I think the problem may be that there are two possible conversion paths from at::Half to __half.
__half can be constructed either from a float or from another __half, and since at::Half defines conversion operators to both float and __half, the conversion is ambiguous.

If I call the correct conversion operator explicitly, the code builds without errors:

```cpp
#pragma once

#include <torch/extension.h>

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")

// Call at::Half's conversion operator to __half explicitly, so the compiler
// does not have to choose between __half(float) and __half(__half &&).
__device__ __inline__ at::Half __shfl_up_sync(const unsigned mask,
                                              const at::Half var,
                                              const unsigned int delta) {
  return __shfl_up_sync(mask, var.operator __half(), delta);
}

// Same explicit conversion for the shuffle-down variant.
__device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
                                                const at::Half var,
                                                const unsigned int delta) {
  return __shfl_down_sync(mask, var.operator __half(), delta);
}
```
Could this be the solution?

@rusty1s
Owner

rusty1s commented Mar 24, 2022

Thanks for sharing this trick, integrated via #280. I am wondering why this only results in an error when building with CMake though.

@alepack89

alepack89 commented Mar 24, 2022

Hi

Comparing the compiler commands that are issued, I think the CMake build is missing the following definitions:

add_definitions(-D__CUDA_NO_HALF_CONVERSIONS__)
add_definitions(-D__CUDA_NO_BFLOAT16_CONVERSIONS__)
add_definitions(-D__CUDA_NO_HALF2_OPERATORS__)

Adding them results in a correct build, even with the original code.

@alepack89

Hi

I think the same problem can be found in the pytorch_sparse project, in the file csrc/cuda/cuda_utils.cuh.

Is this the case?

Thank you

@rusty1s
Owner

rusty1s commented Apr 14, 2022

You are absolutely right. Let me know if you want to fix this :)

@rusty1s rusty1s closed this as completed Apr 14, 2022
@alepack89

Hi

I can do it if you want, but it's OK either way :)
Glad this helped.

By the way, I've noticed a problem when building the pytorch_cluster C++ API. In this case, it is the linker complaining that the function get_example_idx in pytorch_cluster/csrc/cuda/utils.cuh appears to be defined multiple times.
Adding the __forceinline__ keyword before the function declaration does seem to solve the problem.
Can you reproduce this problem too?

@rusty1s
Owner

rusty1s commented Apr 14, 2022

I cannot reproduce this, but I'm happy to fix it since it occurs on your end.

@alepack89

Sure, I will create pull requests for these issues.
Thank you very much!

@rusty1s
Owner

rusty1s commented Apr 20, 2022

Thank you! Also added the __forceinline__ keyword here.

@alepack89

alepack89 commented Apr 20, 2022 via email

6 participants