@@ -79,6 +79,35 @@ def use_debug_mode():
79
79
_get_cuda_arch_flags ,
80
80
)
81
81
82
+ # =====================================================================================
83
+ # CUDA Architecture Settings
84
+ # =====================================================================================
85
+ # If TORCH_CUDA_ARCH_LIST is not set during compilation, PyTorch tries to automatically
86
+ # detect architectures from available GPUs. This can fail when:
87
+ # 1. No GPU is visible to PyTorch
88
+ # 2. CUDA is available but no device is detected
89
+ #
90
+ # To resolve this, you can manually set CUDA architecture targets:
91
+ # export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6+PTX"
92
+ #
93
+ # Adding "+PTX" to the last architecture enables JIT compilation for future GPUs.
94
+ # =====================================================================================
95
+ if "TORCH_CUDA_ARCH_LIST" not in os .environ and torch .version .cuda :
96
+ # Set to common architectures for CUDA 12.x compatibility
97
+ cuda_arch_list = "7.0;7.5;8.0;8.6;8.9;9.0"
98
+
99
+ # Only add SM10.0 (Blackwell) flags when using CUDA 12.8 or newer
100
+ cuda_version = torch .version .cuda
101
+ if cuda_version and cuda_version .startswith ("12.8" ):
102
+ print ("Detected CUDA 12.8 - adding SM10.0 architectures to build list" )
103
+ cuda_arch_list += ";10.0"
104
+
105
+ # Add PTX to the last architecture for future compatibility
106
+ cuda_arch_list += "+PTX"
107
+
108
+ os .environ ["TORCH_CUDA_ARCH_LIST" ] = cuda_arch_list
109
+ print (f"Setting default TORCH_CUDA_ARCH_LIST={ os .environ ['TORCH_CUDA_ARCH_LIST' ]} " )
110
+
82
111
# A ROCm build requires both a HIP-enabled torch (torch.version.hip is set)
# and a discoverable ROCm installation (ROCM_HOME resolved by the import above).
IS_ROCM = torch.version.hip is not None and ROCM_HOME is not None
83
112
84
113
0 commit comments