torchao/prototype/mx_formats/kernels.py (19 additions, 4 deletions)
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional, Tuple
 
 import numpy as np
@@ -35,6 +36,8 @@
     F32_EXP_BIAS,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def get_bits(x: torch.Tensor) -> str:
     bits_per_byte = 8
@@ -1476,10 +1479,20 @@ def triton_quantize_nvfp4(
     raise AssertionError("needs torch version 2.8+ and triton")
 
 
-# MXFP8 CUDA kernel is only built on SM100+
 mxfp8_cuda_extension_available = False
-if is_sm_at_least_100():
-    from torchao.prototype import mxfp8_cuda
+
+try:
+    # MXFP8 CUDA kernel is only built on SM100+. Furthermore,
+    # currently our CI runners are not SM100+, so the user needs to build
+    # from source.
+    # TODO(#2932): improve this
+    from torchao.prototype import mxfp8_cuda
+
+    mxfp8_cuda_extension_available = True
+except ImportError:
+    logger.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
     # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
     # Currently we have to use an arbitrary string because custom ops don't support enum
     # params.
@@ -1599,4 +1612,6 @@ def mxfp8_quantize_cuda(
     colwise: bool = True,
     scaling_mode: str = "floor",
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    raise NotImplementedError("needs torch version 2.8+ and sm100")
+    raise NotImplementedError(
+        "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
+    )
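
For context, a minimal usage sketch, not part of this PR, of how a caller might branch on `mxfp8_cuda_extension_available` before invoking `mxfp8_quantize_cuda`. The leading parameters (`x`, `rowwise`) and the meaning of the four returned tensors are elided in the hunk above and are assumed here.

# Hypothetical usage sketch (not from this PR). The `x` and `rowwise`
# parameters, and the order of the four outputs, are assumptions; only
# `colwise` and `scaling_mode` appear in the hunk above.
import torch

from torchao.prototype.mx_formats.kernels import (
    mxfp8_cuda_extension_available,
    mxfp8_quantize_cuda,
)

if mxfp8_cuda_extension_available:
    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
    data_r, data_c, scale_r, scale_c = mxfp8_quantize_cuda(
        x, rowwise=True, colwise=True, scaling_mode="floor"
    )
else:
    # On machines without the extension, the stub above raises
    # NotImplementedError with the improved message.
    pass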
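
A second small sketch, also not from this PR, showing how to surface the new debug message when the extension is unavailable. Because the import attempt runs at module import time, logging must be configured before the module is imported.

# Configure DEBUG logging before importing the module so the
# "Skipping import of torchao.prototype.mxfp8_cuda" message emitted by
# the module-level logger is visible when the extension was not built.
import logging

logging.basicConfig(level=logging.DEBUG)

import torchao.prototype.mx_formats.kernels  # noqa: E402,F401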