Skip to content

Commit a44ff89

Browse files
committed
Fixes MX formats build for Blackwell
1 parent 554cb60 commit a44ff89

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

setup.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ def get_extensions():
395395
cuda_arch_flags = _get_cuda_arch_flags()
396396
build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
397397
build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
398+
build_for_sm100 = "-gencode=arch=compute_100,code=sm_100" in cuda_arch_flags
399+
build_for_sm100a = "-gencode=arch=compute_100a,code=sm_100a" in cuda_arch_flags
398400
if build_for_sm90 and not build_for_sm90a:
399401
cutlass_90a_sources = [
400402
os.path.join(
@@ -418,6 +420,17 @@ def get_extensions():
418420
)
419421
)
420422
sources = [s for s in sources if s not in cutlass_90a_sources]
423+
424+
if build_for_sm100 and not build_for_sm100a:
425+
cutlass_100a_sources = [
426+
os.path.join(
427+
extensions_cuda_dir,
428+
"mx_kernels",
429+
"mx_fp_cutlass_kernels.cu",
430+
),
431+
]
432+
sources = [s for s in sources if s not in cutlass_100a_sources]
433+
421434
else:
422435
# Remove CUTLASS-based kernels from the sources list. An
423436
# assumption is that these files will have "cutlass" in its
@@ -448,14 +461,29 @@ def get_extensions():
448461
)
449462
ext_modules.append(
450463
extension(
451-
"torchao._C",
464+
"torchao._C_cutlass_90a",
452465
cutlass_90a_sources,
453466
py_limited_api=True,
454467
extra_compile_args=cutlass_90a_extra_compile_args,
455468
extra_link_args=extra_link_args,
456469
)
457470
)
458471

472+
if cutlass_100a_sources is not None and len(cutlass_100a_sources) > 0:
473+
cutlass_100a_extra_compile_args = copy.deepcopy(extra_compile_args)
474+
cutlass_100a_extra_compile_args["nvcc"].extend(
475+
cuda_arch_flags + ["-gencode=arch=compute_100a,code=sm_100a"]
476+
)
477+
ext_modules.append(
478+
extension(
479+
"torchao._C_cutlass_100a",
480+
cutlass_100a_sources,
481+
py_limited_api=True,
482+
extra_compile_args=cutlass_100a_extra_compile_args,
483+
extra_link_args=extra_link_args,
484+
)
485+
)
486+
459487
# Build CMakeLists from /torchao/experimental; additional options become available: TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
460488
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
461489
build_options = BuildOptions()

torchao/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
so_files = list(Path(__file__).parent.glob("_C*.so"))
2727
if len(so_files) > 0:
28-
assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
29-
torch.ops.load_library(str(so_files[0]))
28+
for file in so_files:
29+
torch.ops.load_library(str(file))
3030
from . import ops
3131

3232
# The following library contains CPU kernels from torchao/experimental

0 commit comments

Comments
 (0)