
remove torch.cuda.is_available() check when compiling ops #3085

Merged (12 commits, Apr 19, 2023)

Conversation

jinzhen-lin
Contributor

torch.cuda.is_available() is not necessary here, and it causes #2858 when compiling deepspeed >= 0.8.1 on a machine without a GPU (e.g. during a docker image build).

@tjruwase
Contributor

@jinzhen-lin, thanks for your contribution. But can you please provide some more details on the issue fixed by this PR? In my experience, the commented code works fine on machines without a GPU, including our CI. Thanks!

@jinzhen-lin
Contributor Author

jinzhen-lin commented Mar 24, 2023

@tjruwase To clarify, I mean compiling the CUDA ops on a machine without a GPU. The CI doesn't build the ops, so it never hits this path.

In the mentioned issue, we encountered an error because the quantizer op (introduced in v0.8.1) needs the CUDA half operators, but the compilation arguments -D__CUDA_NO_HALF_OPERATORS__, -D__CUDA_NO_HALF_CONVERSIONS__, -D__CUDA_NO_BFLOAT16_CONVERSIONS__, and -D__CUDA_NO_HALF2_OPERATORS__ are set. (You can search for those arguments on the mentioned issue page.)

So we need these nvcc arguments:

args += [
    '-allow-unsupported-compiler' if sys.platform == "win32" else '',
    '--use_fast_math',
    '-std=c++17' if sys.platform == "win32" and cuda_major > 10 else '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__',
    '-U__CUDA_NO_HALF_CONVERSIONS__',
    '-U__CUDA_NO_HALF2_OPERATORS__'
]

But those arguments are ignored because of the CUDA check here:

def builder(self):
    self.build_for_cpu = not assert_no_cuda_mismatch(self.name)
    if self.build_for_cpu:
        from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
    else:
        from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder

    compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \
        {'cxx': self.strip_empty_entries(self.cxx_args()),
         'nvcc': self.strip_empty_entries(self.nvcc_args())}

The CUDA check fails because installed_cuda_version cannot determine the true CUDA version on such a machine. With this PR, we get the true CUDA version and the issue should be fixed.

I think installed_cuda_version should always return the CUDA toolkit version installed on the system; it should work even on a machine that has the CUDA toolkit but no GPU.
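For illustration, here is a minimal sketch of detecting the installed toolkit version directly from nvcc rather than through torch.cuda, which works with or without a GPU. This is an assumption about how such a check could look, not DeepSpeed's actual implementation; parse_nvcc_version is a hypothetical helper, and the sketch assumes nvcc is on PATH.

```python
import re
import subprocess

def parse_nvcc_version(output: str):
    """Extract (major, minor) from `nvcc --version` output.

    nvcc prints a line like: "Cuda compilation tools, release 11.2, V11.2.152".
    """
    match = re.search(r"release (\d+)\.(\d+)", output)
    if match is None:
        raise RuntimeError("could not parse nvcc version output")
    return int(match.group(1)), int(match.group(2))

def installed_cuda_version():
    """Return the CUDA toolkit version reported by nvcc.

    Queries the compiler, not the driver, so it does not require a GPU.
    """
    out = subprocess.check_output(["nvcc", "--version"], text=True)
    return parse_nvcc_version(out)
```

Because the version comes from the compiler binary itself, the same code path works inside a GPU-less docker build as long as the CUDA toolkit is installed.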

@tjruwase
Contributor

tjruwase commented Mar 24, 2023

@jinzhen-lin, thanks for your helpful explanation. It seems the problem is that we assume the build and target environments are the same. We recently started enabling DeepSpeed for CPU-only target environments, and we distinguish them from GPU target environments by testing for GPU availability with torch.cuda.is_available(). It is now clear that this approach does not work for your scenario, where you are building CUDA ops in an environment with CUDA libraries but no GPUs. The problem with this PR is that it will break builds for CPU-only environments. A more robust solution would be cross-compilation, and a key challenge there would be letting users conveniently specify the target environment, implicitly or explicitly.

Please share your thoughts on this. Thanks!

@jeffra, @mrwyattii FYI

@jinzhen-lin
Contributor Author

@tjruwase Sorry for not checking CPU builds before submitting this PR.

I notice that CPU-only target environments were introduced recently (after v0.8.0), and DeepSpeed mainly targets GPUs for now. So we should always assume the user wants a CUDA build, and fall back to a CPU build only when:

  • we cannot find CUDA in the build environment, or the CUDA version is incompatible with the torch CUDA version
  • the user specifies an environment variable (e.g. DS_BUILD_OPS_CPU)
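The fallback policy described above could be sketched as follows. This is only an illustration of the proposal, not DeepSpeed code: should_build_for_cpu is a hypothetical helper, and DS_BUILD_OPS_CPU is the example variable name from the comment, not an existing DeepSpeed flag.

```python
import os
import shutil

def should_build_for_cpu() -> bool:
    """Default to a CUDA build; fall back to CPU only in the two listed cases."""
    # Case 2: the user explicitly requested a CPU build via an env var
    # (DS_BUILD_OPS_CPU is a hypothetical name suggested in the comment).
    if os.environ.get("DS_BUILD_OPS_CPU", "0") == "1":
        return True
    # Case 1 (simplified): no CUDA toolkit found in the build environment.
    # A fuller check would also compare the toolkit version against torch's
    # CUDA version, as assert_no_cuda_mismatch does.
    if shutil.which("nvcc") is None:
        return True
    # Otherwise, assume the user wants a CUDA build.
    return False
```

The key design point is that GPU presence (torch.cuda.is_available()) never enters the decision; only the toolkit in the build environment and an explicit user override do.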

@jinzhen-lin
Contributor Author

@microsoft-github-policy-service agree

@tjruwase
Contributor

@jinzhen-lin, thanks for updating the PR. This is an improvement but not quite cross-compilation. Nevertheless, this will suffice for now.

jeffra added a commit that referenced this pull request Apr 18, 2023
@jeffra jeffra added the merge-queue PRs ready to merge label Apr 18, 2023
@loadams loadams linked an issue Apr 18, 2023 that may be closed by this pull request
@loadams loadams enabled auto-merge (squash) April 18, 2023 22:49
@jeffra jeffra disabled auto-merge April 19, 2023 00:45
@jeffra jeffra merged commit 036c5d6 into microsoft:master Apr 19, 2023
@conglongli conglongli added deepspeed-chat Related to DeepSpeed-Chat and removed deepspeed-chat Related to DeepSpeed-Chat labels Apr 30, 2023
Successfully merging this pull request may close these issues.

Compilation error for 0.8.1 with CUDA 11.2
5 participants