Skip to content

Commit ed2b959

Browse files
committed
Manually specify flags if no arch set
stack-info: PR: #2219, branch: drisspg/stack/55
1 parent 5549da8 commit ed2b959

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

CUDA_ARCH_NOTES.md

Whitespace-only changes.

setup.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,35 @@ def use_debug_mode():
7979
_get_cuda_arch_flags,
8080
)
8181

82+
# =====================================================================================
83+
# CUDA Architecture Settings
84+
# =====================================================================================
85+
# If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
86+
# detect architectures from available GPUs. This can fail when:
87+
# 1. No GPU is visible to PyTorch
88+
# 2. CUDA is available but no device is detected
89+
#
90+
# To resolve this, you can manually set CUDA architecture targets:
91+
# export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
92+
#
93+
# Appending "+PTX" to the last architecture also embeds PTX for it, which the driver can JIT-compile for future GPUs.
94+
# =====================================================================================
95+
# First CUDA release (major, minor) for which SM 10.0 (Blackwell) is added.
BLACKWELL_MIN_CUDA_VERSION = (12, 8)


def parse_cuda_version(cuda_version):
    """Return the leading ``(major, minor)`` of a CUDA version string as ints.

    Args:
        cuda_version: A dotted version string such as ``"12.8"``.

    Returns:
        A tuple of up to two ints. An empty tuple is returned when the value
        is not a dotted numeric string, so it compares lower than any real
        version instead of raising during the build.
    """
    try:
        return tuple(int(part) for part in cuda_version.split(".")[:2])
    except (AttributeError, ValueError):
        return ()


def default_cuda_arch_list(cuda_version):
    """Build the default ``TORCH_CUDA_ARCH_LIST`` value for a CUDA version.

    Args:
        cuda_version: The CUDA version string reported by ``torch.version.cuda``.

    Returns:
        A ``;``-separated architecture list whose last entry carries ``+PTX``.
    """
    # Common architectures for CUDA 12.x compatibility.
    archs = ["7.0", "7.5", "8.0", "8.6", "8.9", "9.0"]

    # Compare numerically: a string prefix test on "12.8" would miss 12.9 and
    # 13.x even though the intent is "12.8 or newer".
    if parse_cuda_version(cuda_version) >= BLACKWELL_MIN_CUDA_VERSION:
        archs.append("10.0")

    # Add PTX to the last architecture for future compatibility.
    return ";".join(archs) + "+PTX"


# Only fill in a default when the user has not chosen architectures themselves
# and this PyTorch build reports a CUDA version.
if "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.version.cuda:
    if parse_cuda_version(torch.version.cuda) >= BLACKWELL_MIN_CUDA_VERSION:
        print(
            f"Detected CUDA {torch.version.cuda} - adding SM10.0 architectures to build list"
        )

    os.environ["TORCH_CUDA_ARCH_LIST"] = default_cuda_arch_list(torch.version.cuda)
    print(f"Setting default TORCH_CUDA_ARCH_LIST={os.environ['TORCH_CUDA_ARCH_LIST']}")
110+
82111
# True only when torch reports a HIP version and a ROCm home was located.
IS_ROCM = not (torch.version.hip is None or ROCM_HOME is None)
83112

84113

0 commit comments

Comments
 (0)