remove torch.cuda.is_available() check when compiling ops #3085

Merged: 12 commits, Apr 19, 2023
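At its core, the change keys op compilation off the installed CUDA toolkit (via CUDA_HOME) instead of torch.cuda.is_available(), which reports False on build machines that have the toolkit but no visible GPU. A minimal sketch of that distinction, assuming only that torch is installed (not part of the PR itself):

```python
import torch
import torch.utils.cpp_extension

# torch.cuda.is_available() probes for a usable GPU at runtime, so it can be False
# on a GPU-less build machine even when the CUDA toolkit is present.
print(torch.cuda.is_available())

# CUDA_HOME only requires the toolkit itself; torch uses this path to invoke nvcc
# when compiling extensions, and it is None when no toolkit can be found.
print(torch.utils.cpp_extension.CUDA_HOME)
```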
op_builder/builder.py: 27 changes (16 additions, 11 deletions)
@@ -36,9 +36,6 @@
 
 
 def installed_cuda_version(name=""):
-    import torch.cuda
-    if not torch.cuda.is_available():
-        return 0, 0
     import torch.utils.cpp_extension
     cuda_home = torch.utils.cpp_extension.CUDA_HOME
     assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
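With the is_available() guard removed, a machine without the CUDA toolkit now trips the CUDA_HOME assertion instead of silently getting (0, 0) back, which is why the callers below switch to try/except handling. A hedged illustration (the op name is made up):

```python
try:
    # Raises AssertionError when torch cannot resolve CUDA_HOME, i.e. no toolkit installed.
    major, minor = installed_cuda_version("example_op")
    print(f"CUDA toolkit {major}.{minor} detected; CUDA ops can be compiled")
except BaseException:
    print("no usable CUDA toolkit; falling back to CPU-only ops")
```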
@@ -78,8 +75,6 @@ def get_default_compute_capabilities():
 
 def assert_no_cuda_mismatch(name=""):
     cuda_major, cuda_minor = installed_cuda_version(name)
-    if cuda_minor == 0 and cuda_major == 0:
-        return False
     sys_cuda_version = f'{cuda_major}.{cuda_minor}'
     torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
     # This is a show-stopping error, should probably not proceed past this
@@ -344,10 +339,11 @@ def cpu_arch(self):
 
     def is_cuda_enable(self):
         try:
-            if torch.cuda.is_available():
-                return '-D__ENABLE_CUDA__'
-        except:
-            print(f"{WARNING} {self.name} torch.cuda is missing, only cpu ops can be compiled!")
+            assert_no_cuda_mismatch(self.name)
+            return '-D__ENABLE_CUDA__'
+        except BaseException:
+            print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
+                  "only cpu ops can be compiled!")
             return '-D__DISABLE_CUDA__'
         return '-D__DISABLE_CUDA__'
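For context (not shown in this hunk), the string returned by is_cuda_enable() is a preprocessor define that ends up in the compile arguments, letting the C++ sources choose between CUDA and CPU-only code paths. A hedged usage sketch, assuming builder is an instance of a concrete OpBuilder subclass and that the flag is appended to the C++ args:

```python
cuda_flag = builder.is_cuda_enable()   # '-D__ENABLE_CUDA__' or '-D__DISABLE_CUDA__'
cxx_args = ['-O3', cuda_flag]          # illustrative: later handed to the extension build
```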

@@ -459,7 +455,11 @@ def jit_load(self, verbose=True):
             raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")
 
         if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
-            self.build_for_cpu = not assert_no_cuda_mismatch(self.name)
+            try:
+                assert_no_cuda_mismatch(self.name)
+                self.build_for_cpu = False
+            except BaseException:
+                self.build_for_cpu = True
 
         self.jit_mode = True
         from torch.utils.cpp_extension import load
@@ -579,7 +579,12 @@ def is_compatible(self, verbose=True):
         return super().is_compatible(verbose)
 
     def builder(self):
-        self.build_for_cpu = not assert_no_cuda_mismatch(self.name)
+        try:
+            assert_no_cuda_mismatch(self.name)
+            self.build_for_cpu = False
+        except BaseException:
+            self.build_for_cpu = True
+
         if self.build_for_cpu:
             from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
         else:
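The jit_load() and builder() hunks apply the same pattern: build_for_cpu is now derived from whether assert_no_cuda_mismatch() raises. A sketch of the resulting builder(), with the truncated else branch filled in as an assumption (the real method also wires up sources, include paths and compile arguments):

```python
def builder(self):
    try:
        assert_no_cuda_mismatch(self.name)
        self.build_for_cpu = False
    except BaseException:
        self.build_for_cpu = True

    if self.build_for_cpu:
        from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
    else:
        # Assumed continuation (cut off by the diff view): the CUDA-capable counterpart.
        from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder

    # Construction of the actual extension object is omitted in this sketch.
    ...
```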