Skip to content

Commit 244b7eb

Browse files
bottlerfacebook-github-bot
authored andcommitted
allow packaging tools to override CUDA settings
Summary: This makes sure circle ci builds work with cuda even on machines with no gpu. Reviewed By: gkioxari Differential Revision: D19543957 fbshipit-source-id: 9cbfcd4fca22ebe89434ffa71c25d75dd18d2eb6
1 parent 674ee44 commit 244b7eb

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

setup.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,31 @@ def get_extensions():
2323
extra_compile_args = {"cxx": ["-std=c++17"]}
2424
define_macros = []
2525

26-
if torch.cuda.is_available() and CUDA_HOME is not None:
26+
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
27+
if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
2728
extension = CUDAExtension
2829
sources += source_cuda
2930
define_macros += [("WITH_CUDA", None)]
30-
extra_compile_args["nvcc"] = [
31+
nvcc_args = [
3132
"-DCUDA_HAS_FP16=1",
3233
"-D__CUDA_NO_HALF_OPERATORS__",
3334
"-D__CUDA_NO_HALF_CONVERSIONS__",
3435
"-D__CUDA_NO_HALF2_OPERATORS__",
3536
]
37+
nvcc_flags_env = os.getenv("NVCC_FLAGS", "")
38+
if nvcc_flags_env != "":
39+
nvcc_args.extend(nvcc_flags_env.split(" "))
3640

3741
# It's better if pytorch can do this by default ..
3842
CC = os.environ.get("CC", None)
3943
if CC is not None:
40-
extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
44+
CC_arg = "-ccbin={}".format(CC)
45+
if CC_arg not in nvcc_args:
46+
if any(arg.startswith("-ccbin") for arg in nvcc_args):
47+
raise ValueError("Inconsistent ccbins")
48+
nvcc_args.append(CC_arg)
49+
50+
extra_compile_args["nvcc"] = nvcc_args
4151

4252
sources = [os.path.join(extensions_dir, s) for s in sources]
4353

0 commit comments

Comments
 (0)