Building C++ API from source failed #239
Are you using PyTorch nightly? There might be some changes in nightly PyTorch internals that prevent us from building external packages. These usually get fixed once PyTorch 1.10 is fully released.
This is probably not caused by the use of PyTorch nightly, though: I'm getting the same error with libtorch 1.9.0+cu111. The error also happens with torch_sparse >= 0.6.11, the release that introduced half-precision support. I'm temporarily reverting to release 2.0.7, because it builds without this issue. Build log
Thanks. Which GPU are you using? Can you also check whether specifying a compiler flag helps resolve this?
It's an RTX 3070. I'll try the environment variable and report back.
I am also facing the same issue. Did anyone find a solution? I also tried the latest pytorch_scatter release (2.0.9) and libtorch 1.10 with CUDA 11.1 and GCC 10, but it still fails with the same error.
Are there any updates here? Do we think the half-precision operations were implemented robustly, or would it make sense to back them out if they are the source of the error?
I'm still not sure what's causing this, TBH, and
Hi, I don't know if this helps, but I think the problem may be that there are two possible conversion methods from at::Half to __half. If I call the correct conversion operator explicitly, the code builds without errors.
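A minimal C++ sketch of that kind of ambiguity, using hypothetical stand-in types: HostHalf plays at::Half, DeviceHalf plays __half, and accumulate is an invented overload pair. None of these are the real PyTorch or CUDA types, and the exact call site that failed in pytorch_scatter is not shown in this thread. When a type converts implicitly to two different targets and a function is overloaded on both, the call is ambiguous; naming the conversion operator explicitly picks one path. The checks are static_asserts, so compiling the file is the test.

```cpp
#include <cstdint>

// Hypothetical stand-in for CUDA's __half (NOT the real type).
struct DeviceHalf {
  std::uint16_t bits;
};

// Hypothetical stand-in for at::Half (NOT the real type). It converts
// implicitly to DeviceHalf and also to float.
struct HostHalf {
  std::uint16_t bits;
  constexpr operator DeviceHalf() const { return DeviceHalf{bits}; }
  // Placeholder only: a real fp16-to-float decode is omitted for brevity.
  constexpr operator float() const { return static_cast<float>(bits); }
};

// Reports which overload was chosen.
enum class Overload { kDeviceHalf, kFloat };

// Hypothetical overload pair, loosely modeled on an API that accepts
// either a half or a float argument.
constexpr Overload accumulate(DeviceHalf) { return Overload::kDeviceHalf; }
constexpr Overload accumulate(float) { return Overload::kFloat; }

constexpr HostHalf kOne{0x3C00};  // 0x3C00 is the fp16 bit pattern of 1.0

// accumulate(kOne);  // would not compile: each overload needs one
//                    // user-defined conversion, so the call is ambiguous.

// Calling the conversion operator by name removes the ambiguity.
static_assert(accumulate(kOne.operator DeviceHalf()) == Overload::kDeviceHalf,
              "explicit operator DeviceHalf() selects the half overload");
static_assert(accumulate(static_cast<float>(kOne)) == Overload::kFloat,
              "an explicit float cast selects the float overload");
```

The explicit `.operator DeviceHalf()` call mirrors the "call the correct conversion operator explicitly" workaround; it changes nothing at runtime, it only tells overload resolution which user-defined conversion to use.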
Thanks for sharing this trick; integrated via #280. I am wondering why this only results in an error when building with CMake, though.
Hi, checking the make commands that are issued, I think the CMake build is missing the following definitions:
add_definitions(-D__CUDA_NO_HALF_CONVERSIONS__)
Adding them results in a correct build, even with the original code.
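A sketch of why a definition like this can make the error disappear, again with hypothetical stand-in types rather than the real at::Half and __half. The assumption being modeled: as I understand cuda_fp16.hpp, __CUDA_NO_HALF_CONVERSIONS__ compiles out __half's conversions to and from float (and other arithmetic types), which leaves at::Half's operator __half() as the only usable path. MODEL_NO_HALF_CONVERSIONS is a hypothetical macro playing that role, and the float assignment overload is an illustrative assumption, not necessarily the overload that failed here. Requires C++14 or later.

```cpp
#include <cstdint>

// Hypothetical switch standing in for -D__CUDA_NO_HALF_CONVERSIONS__.
// Comment this line out and assign_implicitly() stops compiling.
#define MODEL_NO_HALF_CONVERSIONS

// Hypothetical stand-in for CUDA's __half (NOT the real type).
struct DeviceHalf {
  std::uint16_t bits = 0;
#ifndef MODEL_NO_HALF_CONVERSIONS
  // Extra float overload, analogous to the float conversions that the real
  // macro compiles out of __half.
  DeviceHalf& operator=(float f) {
    bits = static_cast<std::uint16_t>(f);
    return *this;
  }
#endif
};

// Hypothetical stand-in for at::Half (NOT the real type).
struct HostHalf {
  std::uint16_t bits = 0;
  constexpr operator DeviceHalf() const { return DeviceHalf{bits}; }
  // Placeholder only: a real fp16-to-float decode is omitted for brevity.
  constexpr operator float() const { return static_cast<float>(bits); }
};

// Implicit assignment of a HostHalf into a DeviceHalf. With the macro
// defined, only DeviceHalf's implicit copy/move assignment is viable, reached
// through HostHalf::operator DeviceHalf(). Without it, `d = h` is ambiguous
// with operator=(float), since each path uses a different conversion function.
constexpr DeviceHalf assign_implicitly(HostHalf h) {
  DeviceHalf d{};
  d = h;
  return d;
}

static_assert(assign_implicitly(HostHalf{0x3C00}).bits == 0x3C00,
              "the assignment goes through operator DeviceHalf()");
```

This may also explain why only the CMake build failed: as far as I know, PyTorch's torch.utils.cpp_extension (used by setup.py builds) passes -D__CUDA_NO_HALF_CONVERSIONS__ to nvcc by default, while a plain CMake build only gets it if the definition is added explicitly.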
Hi, I think the same problem exists in the pytorch_sparse project. Is that the case? Thank you!
You are absolutely right. Let me know if you want to fix this :)
Hi, I can do it if you want, but it's OK either way :) By the way, I've noticed a problem when building the pytorch_cluster C++ API. In this case it is the linker complaining that the function
I cannot reproduce this, but I'm happy to fix this on your end.
Sure, I will create pull requests for these issues.
Thank you! Also added the __forceinline__ keyword here <rusty1s/pytorch_cluster@9472aef>. (Matthias Fey, Wednesday, April 20, 2022)
Great! Thank you!
I have CUDA 11.2, cuDNN 8.2.1, and PyTorch 1.10, and when I try to build pytorch_scatter from source I hit this error:
Any idea how to fix it?