@@ -55,10 +55,6 @@ def read_version(file_path="version.txt"):
55
55
and platform .system () == "Darwin"
56
56
)
57
57
58
- use_cpp_avx512 = os .getenv ("USE_AVX512" , "1" ) == "1" and platform .system () == "Linux"
59
-
60
- from torchao .utils import TORCH_VERSION_AT_LEAST_2_7
61
-
62
58
version_prefix = read_version ()
63
59
# Version is version.dev year month date if using nightlies and version if not
64
60
version = (
@@ -83,6 +79,8 @@ def use_debug_mode():
83
79
_get_cuda_arch_flags ,
84
80
)
85
81
82
+ IS_ROCM = (torch .version .hip is not None ) and (ROCM_HOME is not None )
83
+
86
84
87
85
class BuildOptions :
88
86
def __init__ (self ):
@@ -257,53 +255,35 @@ def get_extensions():
257
255
print (
258
256
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
259
257
)
260
- if CUDA_HOME is None and torch .version .cuda :
261
- print ("CUDA toolkit is not available. Skipping compilation of CUDA extensions" )
258
+ if (CUDA_HOME is None and ROCM_HOME is None ) and torch .cuda .is_available ():
259
+ print (
260
+ "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
261
+ )
262
262
print (
263
263
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264
264
)
265
- if ROCM_HOME is None and torch .version .hip :
266
- print ("ROCm is not available. Skipping compilation of ROCm extensions" )
267
- print ("If you'd like to compile ROCm extensions locally please install ROCm" )
268
-
269
- use_cuda = torch .version .cuda and CUDA_HOME is not None
270
- use_hip = torch .version .hip and ROCM_HOME is not None
271
- extension = CUDAExtension if (use_cuda or use_hip ) else CppExtension
272
-
273
- nvcc_args = [
274
- "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
275
- "-O3" if not debug_mode else "-O0" ,
276
- "-t=0" ,
277
- "-std=c++17" ,
278
- ]
279
- hip_args = [
280
- "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
281
- "-O3" if not debug_mode else "-O0" ,
282
- "-std=c++17" ,
283
- ]
265
+
266
+ use_cuda = torch .cuda .is_available () and (
267
+ CUDA_HOME is not None or ROCM_HOME is not None
268
+ )
269
+ extension = CUDAExtension if use_cuda else CppExtension
284
270
285
271
extra_link_args = []
286
272
extra_compile_args = {
287
273
"cxx" : [f"-DPy_LIMITED_API={ PY3_9_HEXCODE } " ],
288
- "nvcc" : nvcc_args if use_cuda else hip_args ,
274
+ "nvcc" : [
275
+ "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
276
+ "-O3" if not debug_mode else "-O0" ,
277
+ "-t=0" ,
278
+ "-std=c++17" ,
279
+ ],
289
280
}
290
281
291
282
if not IS_WINDOWS :
292
283
extra_compile_args ["cxx" ].extend (
293
284
["-O3" if not debug_mode else "-O0" , "-fdiagnostics-color=always" ]
294
285
)
295
286
296
- if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7 :
297
- if torch ._C ._cpu ._is_avx512_supported ():
298
- extra_compile_args ["cxx" ].extend (
299
- [
300
- "-DCPU_CAPABILITY_AVX512" ,
301
- "-march=native" ,
302
- "-mfma" ,
303
- "-fopenmp" ,
304
- ]
305
- )
306
-
307
287
if debug_mode :
308
288
extra_compile_args ["cxx" ].append ("-g" )
309
289
if "nvcc" in extra_compile_args :
@@ -319,95 +299,48 @@ def get_extensions():
319
299
extra_compile_args ["nvcc" ].append ("-g" )
320
300
extra_link_args .append ("/DEBUG" )
321
301
322
- hip_sparse_marlin_supported = True
323
- if use_hip :
324
- # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
325
- found_col16 = False
326
- found_vec_ext = False
327
- print ("ROCM_HOME" , ROCM_HOME )
328
- hipblaslt_headers = list (
329
- glob .glob (os .path .join (ROCM_HOME , "include" , "hipblaslt" , "hipblaslt.h" ))
330
- )
331
- print ("hipblaslt_headers" , hipblaslt_headers )
332
- for header in hipblaslt_headers :
333
- with open (header ) as f :
334
- text = f .read ()
335
- if "HIPBLASLT_ORDER_COL16" in text :
336
- found_col16 = True
337
- if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text :
338
- found_vec_ext = True
339
- if found_col16 :
340
- extra_compile_args ["cxx" ].append ("-DHIPBLASLT_HAS_ORDER_COL16" )
341
- print ("hipblaslt found extended col order enums" )
342
- else :
343
- print ("hipblaslt does not have extended col order enums" )
344
- if found_vec_ext :
345
- extra_compile_args ["cxx" ].append ("-DHIPBLASLT_VEC_EXT" )
346
- print ("hipblaslt found vec ext" )
347
- else :
348
- print ("hipblaslt does not have vec ext" )
349
-
350
- # sparse_marlin depends on features in ROCm 6.4, __builtin_amdgcn_global_load_lds
351
- ROCM_VERSION = tuple (int (v ) for v in torch .version .hip .split ("." )[:2 ])
352
- hip_sparse_marlin_supported = ROCM_VERSION >= (6 , 4 )
353
-
354
302
# Get base directory and source paths
355
303
curdir = os .path .dirname (os .path .curdir )
356
304
extensions_dir = os .path .join (curdir , "torchao" , "csrc" )
357
305
358
306
# Collect C++ source files
359
307
sources = list (glob .glob (os .path .join (extensions_dir , "**/*.cpp" ), recursive = True ))
360
- if IS_WINDOWS :
361
- # Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C
362
- excluded_sources = list (
363
- glob .glob (os .path .join (extensions_dir , "cpu/*.cpp" ), recursive = True )
364
- )
365
- sources = [s for s in sources if s not in excluded_sources ]
366
308
367
- # Collect CUDA source files
368
309
extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
369
310
cuda_sources = list (
370
311
glob .glob (os .path .join (extensions_cuda_dir , "**/*.cu" ), recursive = True )
371
312
)
372
313
373
- # Collect HIP source files
374
314
extensions_hip_dir = os .path .join (
375
315
extensions_dir , "cuda" , "tensor_core_tiled_layout"
376
316
)
377
317
hip_sources = list (
378
318
glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
379
319
)
380
- if hip_sparse_marlin_supported :
381
- extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "sparse_marlin" )
382
- hip_sources += list (
383
- glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
384
- )
385
- extensions_hip_dir = os .path .join (extensions_dir , "rocm" )
386
- hip_sources += list (
387
- glob .glob (os .path .join (extensions_hip_dir , "**/*.hip" ), recursive = True )
388
- )
320
+ extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "sparse_marlin" )
389
321
hip_sources += list (
390
- glob .glob (os .path .join (extensions_hip_dir , "**/*.cpp " ), recursive = True )
322
+ glob .glob (os .path .join (extensions_hip_dir , "*.cu " ), recursive = True )
391
323
)
392
324
393
- # Add CUDA source files if needed
394
- if use_cuda :
325
+ # Collect CUDA source files if needed
326
+ if not IS_ROCM and use_cuda :
395
327
sources += cuda_sources
396
328
397
- # TODO: Remove this and use what CUDA has once we fix all the builds.
398
- # Add HIP source files if needed
399
- if use_hip :
329
+ # TODO: Remove this and use what CUDA has once we fix all the builds.
330
+ if IS_ROCM and use_cuda :
400
331
# Add ROCm GPU architecture check
401
332
gpu_arch = torch .cuda .get_device_properties (0 ).name
402
333
if gpu_arch != "gfx942" :
403
334
print (f"Warning: Unsupported ROCm GPU architecture: { gpu_arch } " )
404
- print ("Currently only gfx942 is supported. Compiling only for gfx942." )
405
- extra_compile_args ["nvcc" ].append ("--offload-arch=gfx942" )
406
- sources += hip_sources
335
+ print (
336
+ "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
337
+ )
338
+ else :
339
+ sources += hip_sources
407
340
408
341
use_cutlass = False
409
342
cutlass_90a_sources = None
410
- if use_cuda and not IS_WINDOWS :
343
+ if use_cuda and not IS_ROCM and not IS_WINDOWS :
411
344
use_cutlass = True
412
345
cutlass_dir = os .path .join (third_party_path , "cutlass" )
413
346
cutlass_include_dir = os .path .join (cutlass_dir , "include" )
0 commit comments