Skip to content

Commit

Permalink
Avoid errors in DeepSpeed build on ROCm
Browse files Browse the repository at this point in the history
  • Loading branch information
rraminen committed Aug 28, 2024
1 parent 1041c8a commit 2e1d1ce
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions op_builder/gds.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,13 @@ def is_compatible(self, verbose=False):
self.warning("Please install torch if trying to pre-compile GDS")
return False

CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME
CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64")
if not self.is_rocm_pytorch():
CUDA_HOME = torch.utils.cpp_extension.CUDA_HOME
CUDA_LIB64 = os.path.join(CUDA_HOME, "lib64")
else:
CUDA_HOME = torch.utils.cpp_extension.ROCM_HOME
CUDA_LIB64 = os.path.join(CUDA_HOME, "lib")

gds_compatible = self.has_function(funcname="cuFileDriverOpen",
libraries=("cufile", ),
library_dirs=(
Expand Down

0 comments on commit 2e1d1ce

Please sign in to comment.