@@ -395,6 +395,8 @@ def get_extensions():
395
395
cuda_arch_flags = _get_cuda_arch_flags ()
396
396
build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
397
397
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
398
+ build_for_sm100 = "-gencode=arch=compute_100,code=sm_100" in cuda_arch_flags
399
+ build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags
398
400
if build_for_sm90 and not build_for_sm90a :
399
401
cutlass_90a_sources = [
400
402
os .path .join (
@@ -418,6 +420,17 @@ def get_extensions():
418
420
)
419
421
)
420
422
sources = [s for s in sources if s not in cutlass_90a_sources ]
423
+
424
+ if build_for_sm100 and not build_for_sm100a :
425
+ cutlass_100a_sources = [
426
+ os .path .join (
427
+ extensions_cuda_dir ,
428
+ "mx_kernels" ,
429
+ "mx_fp_cutlass_kernels.cu" ,
430
+ ),
431
+ ]
432
+ sources = [s for s in sources if s not in cutlass_100a_sources ]
433
+
421
434
else :
422
435
# Remove CUTLASS-based kernels from the sources list. An
423
436
# assumption is that these files will have "cutlass" in its
@@ -448,14 +461,29 @@ def get_extensions():
448
461
)
449
462
ext_modules .append (
450
463
extension (
451
- "torchao._C " ,
464
+ "torchao._C_cutlass_90a " ,
452
465
cutlass_90a_sources ,
453
466
py_limited_api = True ,
454
467
extra_compile_args = cutlass_90a_extra_compile_args ,
455
468
extra_link_args = extra_link_args ,
456
469
)
457
470
)
458
471
472
+ if cutlass_100a_sources is not None and len (cutlass_100a_sources ) > 0 :
473
+ cutlass_100a_extra_compile_args = copy .deepcopy (extra_compile_args )
474
+ cutlass_100a_extra_compile_args ["nvcc" ].extend (
475
+ cuda_arch_flags + ["-gencode=arch=compute_100a,code=sm_100a" ]
476
+ )
477
+ ext_modules .append (
478
+ extension (
479
+ "torchao._C_cutlass_100a" ,
480
+ cutlass_100a_sources ,
481
+ py_limited_api = True ,
482
+ extra_compile_args = cutlass_100a_extra_compile_args ,
483
+ extra_link_args = extra_link_args ,
484
+ )
485
+ )
486
+
459
487
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
460
488
if build_macos_arm_auto or os .getenv ("BUILD_TORCHAO_EXPERIMENTAL" ) == "1" :
461
489
build_options = BuildOptions ()
0 commit comments