@@ -55,10 +55,6 @@ def read_version(file_path="version.txt"):
5555 and platform .system () == "Darwin"
5656)
5757
58- use_cpp_avx512 = os .getenv ("USE_AVX512" , "1" ) == "1" and platform .system () == "Linux"
59-
60- from torchao .utils import TORCH_VERSION_AT_LEAST_2_7
61-
6258version_prefix = read_version ()
6359# Version is version.dev year month date if using nightlies and version if not
6460version = (
@@ -83,6 +79,8 @@ def use_debug_mode():
8379 _get_cuda_arch_flags ,
8480)
8581
82+ IS_ROCM = (torch .version .hip is not None ) and (ROCM_HOME is not None )
83+
8684
8785class BuildOptions :
8886 def __init__ (self ):
@@ -257,53 +255,35 @@ def get_extensions():
257255 print (
258256 "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
259257 )
260- if CUDA_HOME is None and torch .version .cuda :
261- print ("CUDA toolkit is not available. Skipping compilation of CUDA extensions" )
258+ if (CUDA_HOME is None and ROCM_HOME is None ) and torch .cuda .is_available ():
259+ print (
260+ "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
261+ )
262262 print (
263263 "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264264 )
265- if ROCM_HOME is None and torch .version .hip :
266- print ("ROCm is not available. Skipping compilation of ROCm extensions" )
267- print ("If you'd like to compile ROCm extensions locally please install ROCm" )
268-
269- use_cuda = torch .version .cuda and CUDA_HOME is not None
270- use_hip = torch .version .hip and ROCM_HOME is not None
271- extension = CUDAExtension if (use_cuda or use_hip ) else CppExtension
272-
273- nvcc_args = [
274- "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
275- "-O3" if not debug_mode else "-O0" ,
276- "-t=0" ,
277- "-std=c++17" ,
278- ]
279- hip_args = [
280- "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
281- "-O3" if not debug_mode else "-O0" ,
282- "-std=c++17" ,
283- ]
265+
266+ use_cuda = torch .cuda .is_available () and (
267+ CUDA_HOME is not None or ROCM_HOME is not None
268+ )
269+ extension = CUDAExtension if use_cuda else CppExtension
284270
285271 extra_link_args = []
286272 extra_compile_args = {
287273 "cxx" : [f"-DPy_LIMITED_API={ PY3_9_HEXCODE } " ],
288- "nvcc" : nvcc_args if use_cuda else hip_args ,
274+ "nvcc" : [
275+ "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
276+ "-O3" if not debug_mode else "-O0" ,
277+ "-t=0" ,
278+ "-std=c++17" ,
279+ ],
289280 }
290281
291282 if not IS_WINDOWS :
292283 extra_compile_args ["cxx" ].extend (
293284 ["-O3" if not debug_mode else "-O0" , "-fdiagnostics-color=always" ]
294285 )
295286
296- if use_cpp_avx512 and TORCH_VERSION_AT_LEAST_2_7 :
297- if torch ._C ._cpu ._is_avx512_supported ():
298- extra_compile_args ["cxx" ].extend (
299- [
300- "-DCPU_CAPABILITY_AVX512" ,
301- "-march=native" ,
302- "-mfma" ,
303- "-fopenmp" ,
304- ]
305- )
306-
307287 if debug_mode :
308288 extra_compile_args ["cxx" ].append ("-g" )
309289 if "nvcc" in extra_compile_args :
@@ -319,95 +299,48 @@ def get_extensions():
319299 extra_compile_args ["nvcc" ].append ("-g" )
320300 extra_link_args .append ("/DEBUG" )
321301
322- hip_sparse_marlin_supported = True
323- if use_hip :
324- # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
325- found_col16 = False
326- found_vec_ext = False
327- print ("ROCM_HOME" , ROCM_HOME )
328- hipblaslt_headers = list (
329- glob .glob (os .path .join (ROCM_HOME , "include" , "hipblaslt" , "hipblaslt.h" ))
330- )
331- print ("hipblaslt_headers" , hipblaslt_headers )
332- for header in hipblaslt_headers :
333- with open (header ) as f :
334- text = f .read ()
335- if "HIPBLASLT_ORDER_COL16" in text :
336- found_col16 = True
337- if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text :
338- found_vec_ext = True
339- if found_col16 :
340- extra_compile_args ["cxx" ].append ("-DHIPBLASLT_HAS_ORDER_COL16" )
341- print ("hipblaslt found extended col order enums" )
342- else :
343- print ("hipblaslt does not have extended col order enums" )
344- if found_vec_ext :
345- extra_compile_args ["cxx" ].append ("-DHIPBLASLT_VEC_EXT" )
346- print ("hipblaslt found vec ext" )
347- else :
348- print ("hipblaslt does not have vec ext" )
349-
350- # sparse_marlin depends on features in ROCm 6.4, __builtin_amdgcn_global_load_lds
351- ROCM_VERSION = tuple (int (v ) for v in torch .version .hip .split ("." )[:2 ])
352- hip_sparse_marlin_supported = ROCM_VERSION >= (6 , 4 )
353-
354302 # Get base directory and source paths
355303 curdir = os .path .dirname (os .path .curdir )
356304 extensions_dir = os .path .join (curdir , "torchao" , "csrc" )
357305
358306 # Collect C++ source files
359307 sources = list (glob .glob (os .path .join (extensions_dir , "**/*.cpp" ), recursive = True ))
360- if IS_WINDOWS :
361- # Remove csrc/cpu/*.cpp on Windows due to the link issue: unresolved external symbol PyInit__C
362- excluded_sources = list (
363- glob .glob (os .path .join (extensions_dir , "cpu/*.cpp" ), recursive = True )
364- )
365- sources = [s for s in sources if s not in excluded_sources ]
366308
367- # Collect CUDA source files
368309 extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
369310 cuda_sources = list (
370311 glob .glob (os .path .join (extensions_cuda_dir , "**/*.cu" ), recursive = True )
371312 )
372313
373- # Collect HIP source files
374314 extensions_hip_dir = os .path .join (
375315 extensions_dir , "cuda" , "tensor_core_tiled_layout"
376316 )
377317 hip_sources = list (
378318 glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
379319 )
380- if hip_sparse_marlin_supported :
381- extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "sparse_marlin" )
382- hip_sources += list (
383- glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
384- )
385- extensions_hip_dir = os .path .join (extensions_dir , "rocm" )
386- hip_sources += list (
387- glob .glob (os .path .join (extensions_hip_dir , "**/*.hip" ), recursive = True )
388- )
320+ extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "sparse_marlin" )
389321 hip_sources += list (
390- glob .glob (os .path .join (extensions_hip_dir , "**/*.cpp " ), recursive = True )
322+ glob .glob (os .path .join (extensions_hip_dir , "*.cu " ), recursive = True )
391323 )
392324
393- # Add CUDA source files if needed
394- if use_cuda :
325+ # Collect CUDA source files if needed
326+ if not IS_ROCM and use_cuda :
395327 sources += cuda_sources
396328
397- # TODO: Remove this and use what CUDA has once we fix all the builds.
398- # Add HIP source files if needed
399- if use_hip :
329+ # TODO: Remove this and use what CUDA has once we fix all the builds.
330+ if IS_ROCM and use_cuda :
400331 # Add ROCm GPU architecture check
401332 gpu_arch = torch .cuda .get_device_properties (0 ).name
402333 if gpu_arch != "gfx942" :
403334 print (f"Warning: Unsupported ROCm GPU architecture: { gpu_arch } " )
404- print ("Currently only gfx942 is supported. Compiling only for gfx942." )
405- extra_compile_args ["nvcc" ].append ("--offload-arch=gfx942" )
406- sources += hip_sources
335+ print (
336+ "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
337+ )
338+ else :
339+ sources += hip_sources
407340
408341 use_cutlass = False
409342 cutlass_90a_sources = None
410- if use_cuda and not IS_WINDOWS :
343+ if use_cuda and not IS_ROCM and not IS_WINDOWS :
411344 use_cutlass = True
412345 cutlass_dir = os .path .join (third_party_path , "cutlass" )
413346 cutlass_include_dir = os .path .join (cutlass_dir , "include" )
0 commit comments