Merged

Changes from all commits (32 commits)
70461d9: Support NVFP4 dynamic per tensor scale (andrewor14, Sep 23, 2025)
cce3b22: Improve QAT nvfp4 numerics (andrewor14, Sep 23, 2025)
4d7bb2a: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 23, 2025)
8764312: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 23, 2025)
20c36da: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 23, 2025)
18ee38a: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 23, 2025)
61dd09f: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 23, 2025)
cd07758: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 23, 2025)
cec1acd: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 24, 2025)
8519147: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 24, 2025)
6585a8c: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 24, 2025)
e16506d: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 24, 2025)
22ec72b: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
9dbde8f: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
e446a50: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
44abeab: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
c947099: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
05117d3: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
0f09378: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
d3fcfd4: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
843cbcf: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
4569aae: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 25, 2025)
bf2208c: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
6f00784: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
90bc7d4: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
1a7eae7: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
a48b7de: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
90d6af0: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
ef3682b: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
7f06046: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 26, 2025)
72586a1: Update base for Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 30, 2025)
fa4d9ee: Update on "Improve QAT nvfp4 numerics" (andrewor14, Sep 30, 2025)
11 changes: 8 additions & 3 deletions test/quantization/test_qat.py
@@ -1910,7 +1910,6 @@ def _test_quantize_api_against_ptq(
quantize_(m, QATConfig(base_config, step="prepare"), filter_fn)
out_prepared = m(*example_inputs)
prepare_sqnr = compute_error(out_prepared, out_baseline)

self.assertGreaterEqual(prepare_sqnr, target_prepare_sqnr)

# compare convert
@@ -2088,21 +2087,27 @@ def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):

self._test_quantize_api_against_ptq(
NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
target_prepare_sqnr=12,
target_prepare_sqnr=float("inf"),
target_convert_sqnr=float("inf"),
)

@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@parametrize("use_per_tensor_scale", [True, False])
def test_qat_nvfp4(self, use_per_tensor_scale: bool):
"""
Test QAT with `NVFP4FakeQuantizeConfig`.
"""
from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

torch.manual_seed(self.SEED)
m = M().cuda()
baseline_model = copy.deepcopy(m)
quantize_(
baseline_model,
NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
)
qat_config = QATConfig(
activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
@@ -2116,7 +2121,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
out = m(*x)
baseline_out = baseline_model(*x)
sqnr = compute_error(out, baseline_out).item()
self.assertGreater(sqnr, 24)
self.assertGreaterEqual(sqnr, float("inf"))

@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
@unittest.skipIf(
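The tightened assertions above now require the QAT prepare and convert paths to reproduce the PTQ NVFP4 numerics exactly (infinite SQNR), rather than merely clearing a dB threshold. A minimal sketch of that check, assuming compute_error is the SQNR helper from torchao.quantization.utils used elsewhere in this test file:

import torch
from torchao.quantization.utils import compute_error  # assumed import path

def assert_exact_match(out_qat: torch.Tensor, out_ptq: torch.Tensor) -> None:
    # compute_error returns SQNR in dB; identical tensors give +inf,
    # which is what the updated assertions require.
    sqnr = compute_error(out_qat, out_ptq).item()
    assert sqnr == float("inf"), f"QAT output diverged from PTQ, SQNR={sqnr} dB"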
47 changes: 11 additions & 36 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -771,37 +771,13 @@ def nvfp4_quantize(
AssertionError: If input dtype is not supported, tensor size is not
divisible by block_size, tensor is not contiguous, or block_size != 16
"""
return _nvfp4_quantize(data_hp, block_size, per_tensor_scale)


class _Float8Round(torch.autograd.Function):
"""
Cast a tensor to float8 and back to float32 with backward STE.
"""

@staticmethod
def forward(ctx, x: torch.Tensor) -> torch.Tensor:
return x.to(torch.float8_e4m3fn).to(torch.float32)

@staticmethod
def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
return gy


def _nvfp4_quantize(
data_hp: torch.Tensor,
block_size: int = 16,
per_tensor_scale: Optional[torch.Tensor] = None,
skip_dtype_cast_and_packing: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
assert data_hp.dtype in (torch.bfloat16, torch.float), (
f"{data_hp.dtype} not supported"
)
assert data_hp.size(-1) % block_size == 0, "K dim must be divisible by block_size"
assert data_hp.is_contiguous(), "Only support contiguous data for now"
assert block_size == 16, "NVFP4 requires block_size=16"

orig_dtype = data_hp.dtype
orig_shape = data_hp.shape
# Convert to float32 early for consistent precision with Triton implementation
data_hp = data_hp.float().reshape(orig_shape[0], -1, block_size)
@@ -813,8 +789,10 @@ def _nvfp4_quantize(
out_scales = None
if per_tensor_scale is None:
# We are doing single level scaling
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX)
block_scale_fp32 = _Float8Round.apply(block_scale_fp8)
block_scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(
torch.float8_e4m3fn
)
block_scale_fp32 = block_scale_fp8.to(torch.float32)
data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
out_scales = block_scale_fp8
else:
@@ -826,8 +804,8 @@ def _nvfp4_quantize(
scaled_block_scales = block_scale_fp32 / per_tensor_scale
scaled_block_scales_fp8 = torch.clamp(
scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
)
scaled_block_scales_fp32 = _Float8Round.apply(scaled_block_scales_fp8)
).to(torch.float8_e4m3fn)
scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
# We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
# To apply to data
total_scale = per_tensor_scale * scaled_block_scales_fp32
@@ -836,11 +814,8 @@ def _nvfp4_quantize(

data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
data_scaled = data_scaled.view(orig_shape)
if skip_dtype_cast_and_packing:
return out_scales.to(torch.float32), data_scaled.to(orig_dtype)
else:
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
return out_scales.to(torch.float8_e4m3fn), data_lp
data_lp = f32_to_f4_unpacked(data_scaled)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
return out_scales, data_lp
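These hunks remove the _Float8Round straight-through-estimator helper and the skip_dtype_cast_and_packing flag: block scales are now clamped, cast to fp8 (e4m3), and cast straight back to fp32 before being applied to the data. A self-contained sketch of that round trip; the eps/max constants below are stand-ins derived from torch.finfo, not necessarily the exact E4M3_EPS / F8E4M3_MAX values torchao defines:

import torch

# Assumed stand-ins for torchao's E4M3_EPS / F8E4M3_MAX constants
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def round_block_scale(block_scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Clamp into the representable e4m3 range, keep the fp8 scale for storage,
    # and cast back to fp32 so the scale applied to the data matches the
    # stored scale exactly (no separate STE autograd function needed).
    scale_fp8 = torch.clamp(block_scale, min=E4M3_EPS, max=F8E4M3_MAX).to(torch.float8_e4m3fn)
    scale_fp32 = scale_fp8.to(torch.float32)
    return scale_fp8, scale_fp32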
4 changes: 2 additions & 2 deletions torchao/prototype/qat/__init__.py
@@ -3,10 +3,10 @@

from .nvfp4 import (
NVFP4FakeQuantizeConfig,
NVFP4FakeQuantizer,
NVFP4FakeQuantizedLinear,
)

__all__ = [
"NVFP4FakeQuantizeConfig",
"NVFP4FakeQuantizer",
"NVFP4FakeQuantizedLinear",
]
186 changes: 152 additions & 34 deletions torchao/prototype/qat/nvfp4.py
@@ -1,15 +1,14 @@
from dataclasses import dataclass
from typing import Optional

import torch

from torchao.prototype.mx_formats.nvfp4_tensor import (
_nvfp4_quantize,
NVFP4Tensor,
_addmm_nvfp4_dispatch,
per_tensor_amax_to_scale,
)
from torchao.quantization.qat import (
FakeQuantizeConfigBase,
FakeQuantizerBase,
)
from torchao.quantization.qat import FakeQuantizeConfigBase


@dataclass
@@ -23,47 +22,166 @@ class NVFP4FakeQuantizeConfig(FakeQuantizeConfigBase):
Args:
use_per_tensor_scale (bool): Whether to use two-level per-tensor fp32 scaling
after the initial fp8 (e4m3) block-wise scaling (default True)
use_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
use_triton_kernel (bool): Whether to use triton kernels during fake quantization
"""

use_per_tensor_scale: bool = True
use_swizzled_scales: bool = False
use_triton_kernel: bool = False


# TODO: support emulation on non-Blackwell GPUs
class _NVFP4QuantizedForwardFakeQuantizedBackward(torch.autograd.Function):
"""
Autograd function for NVFP4 quantization + addmm in low precision during forward,
and fake quantization in high precision during backward.
"""

@staticmethod
def forward(
ctx,
_input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
activation_config: NVFP4FakeQuantizeConfig,
weight_config: NVFP4FakeQuantizeConfig,
) -> torch.Tensor:
# quantize input activations
if activation_config.use_per_tensor_scale:
tensor_amax = torch.max(torch.abs(_input))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = None
_input = NVFP4Tensor.to_nvfp4(
_input,
per_tensor_scale=per_tensor_scale,
is_swizzled_scales=activation_config.use_swizzled_scales,
use_triton_kernel=activation_config.use_triton_kernel,
)

# quantize weights
if weight_config.use_per_tensor_scale:
tensor_amax = torch.max(torch.abs(weight))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = None
weight = NVFP4Tensor.to_nvfp4(
weight,
per_tensor_scale=per_tensor_scale,
is_swizzled_scales=weight_config.use_swizzled_scales,
use_triton_kernel=False,
)

# Follow `NVFP4InferenceConfig`, always use traditional construction
# for weights and set `use_triton_kernel` afterwards
weight.use_triton_kernel = weight_config.use_triton_kernel

class NVFP4FakeQuantizer(FakeQuantizerBase):
ctx.save_for_backward(_input, weight)

return _addmm_nvfp4_dispatch(
_input,
weight.t(),
None, # aten_op, not used
bias,
)

@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
_input, weight = ctx.saved_tensors
assert isinstance(_input, NVFP4Tensor)
assert isinstance(weight, NVFP4Tensor)
_input = _input.to_dtype(_input._orig_dtype)
weight = weight.to_dtype(weight._orig_dtype)
grad_input = torch.mm(grad_output, weight)
grad_weight = torch.mm(grad_output.t(), _input)
return grad_input, grad_weight, None, None, None


class NVFP4FakeQuantizedLinear(torch.nn.Linear):
"""
(Prototype) Generic module for applying NVFP4 fake quantization to a tensor, as specified in the config.
Linear module for fake quantized NVFP4 weights and/or activations.

The forward pass follows quantization and addmm numerics in `NVFP4Tensor`
in lower precision exactly, while the backward pass uses dequantized
(fake quantized) values in high precision.

Currently this is only applicable on Blackwell and future generations.
See https://github.com/pytorch/ao/issues/3102 for more details.

Example usage::

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

base_config = NVFP4InferenceConfig()
quantize_(model, QATConfig(base_config, step="prepare"))
# Model contains `NVFP4FakeQuantizedLinear` now

train_loop(model)
quantize_(model, QATConfig(base_config, step="convert"))
# Model contains `nn.Linear` with `NVFP4Tensor` weights now
"""

def __init__(self, config: NVFP4FakeQuantizeConfig):
super().__init__()
torch._C._log_api_usage_once("torchao.quantization.qat.NVFP4FakeQuantizer")
self.config = config
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = False,
activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
*args,
**kwargs,
):
super().__init__(
in_features,
out_features,
bias,
*args,
**kwargs,
)
if weight_config is None:
raise ValueError("Must specify `weight_config`")
if activation_config is None:
raise ValueError("Weight only NVFP4 QAT not supported yet")
self.activation_config = activation_config
self.weight_config = weight_config

def forward(self, x: torch.Tensor) -> torch.Tensor:
block_size = 16
original_shape = x.shape
if x.dim() == 3:
batch_size = x.shape[0]
x = x.view(-1, x.shape[-1])
if self.config.use_per_tensor_scale:
tensor_amax = torch.max(torch.abs(x))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = None
batch_size = None
fq = _NVFP4QuantizedForwardFakeQuantizedBackward.apply(
x, self.weight, self.bias, self.activation_config, self.weight_config
)
assert fq.dtype == x.dtype
if batch_size is not None:
return fq.view(batch_size, -1, fq.shape[-1])
else:
return fq

# quantize
scale, q = _nvfp4_quantize(
x,
block_size=block_size,
per_tensor_scale=per_tensor_scale,
skip_dtype_cast_and_packing=True,
@classmethod
def from_linear(
cls,
mod: torch.nn.Linear,
activation_config: Optional[NVFP4FakeQuantizeConfig] = None,
weight_config: Optional[NVFP4FakeQuantizeConfig] = None,
):
new_linear = NVFP4FakeQuantizedLinear(
mod.in_features,
mod.out_features,
mod.bias is not None,
activation_config=activation_config,
weight_config=weight_config,
device=mod.weight.device,
dtype=mod.weight.dtype,
)
if self.config.use_per_tensor_scale:
scale = scale * per_tensor_scale
assert q.dtype == x.dtype
assert scale.dtype == torch.float32

# dequantize
M, K = q.shape[0], q.shape[1]
q = q.view(M, K // block_size, block_size)
scale = scale.view(M, K // block_size, 1)
dq = q * scale
return dq.view(original_shape).to(x.dtype)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if mod.weight.device != torch.device("meta"):
new_linear.weight = mod.weight
new_linear.bias = mod.bias
return new_linear
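A hedged usage sketch for the new NVFP4FakeQuantizedLinear, pieced together from the signatures added in this file; it assumes this PR's branch and a GPU that supports NVFP4 matmul (Blackwell or newer), and the shapes are illustrative only:

import torch
from torchao.prototype.qat import NVFP4FakeQuantizeConfig, NVFP4FakeQuantizedLinear

linear = torch.nn.Linear(128, 64, bias=False, dtype=torch.bfloat16, device="cuda")
qat_linear = NVFP4FakeQuantizedLinear.from_linear(
    linear,
    activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
    weight_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale=True),
)

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = qat_linear(x)      # forward runs real NVFP4 quantization + addmm in low precision
y.sum().backward()     # backward uses dequantized (fake quantized) values in high precision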
19 changes: 18 additions & 1 deletion torchao/quantization/qat/api.py
@@ -208,7 +208,24 @@ def _qat_config_transform(
act_config = config.activation_config
weight_config = config.weight_config
if isinstance(module, torch.nn.Linear):
return FakeQuantizedLinear.from_linear(module, act_config, weight_config)
# TODO: rewrite this using a registration API so
# specific quantization schemes do not leak here
from torchao.prototype.qat import (
NVFP4FakeQuantizeConfig,
NVFP4FakeQuantizedLinear,
)

if isinstance(weight_config, NVFP4FakeQuantizeConfig):
assert act_config is None or isinstance(
act_config, NVFP4FakeQuantizeConfig
)
return NVFP4FakeQuantizedLinear.from_linear(
module, act_config, weight_config
)
else:
return FakeQuantizedLinear.from_linear(
module, act_config, weight_config
)
elif isinstance(module, torch.nn.Embedding):
if act_config is not None:
raise ValueError(
8 changes: 6 additions & 2 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -444,12 +444,16 @@ def _infer_fake_quantize_configs(
elif isinstance(base_config, NVFP4InferenceConfig):
if NVFP4MMConfig.DYNAMIC:
Reviewer (Contributor): nit: if this is a boolean, might be good to say NVFP4MMConfig.is_dynamic I think, although probably not relevant to this PR

Author (Contributor Author): ok, can fix separately since this is a PTQ config

act_config = NVFP4FakeQuantizeConfig(
use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
use_swizzled_scales=False,
use_triton_kernel=False,
)
else:
act_config = None
weight_config = NVFP4FakeQuantizeConfig(
use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
use_swizzled_scales=True,
use_triton_kernel=base_config.use_triton_kernel,
)
elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
assert base_config.version >= 2, "Only version 2+ is supported"
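For reference, with NVFP4MMConfig.DYNAMIC set (the branch shown above), the inference config is expanded into a pair of fake-quantize configs roughly as follows; this is a sketch only, with the field values copied from the diff:

from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)

# Activations: row-major (non-swizzled) scales, no triton kernel during fake quantization
act_config = NVFP4FakeQuantizeConfig(
    use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
    use_swizzled_scales=False,
    use_triton_kernel=False,
)

# Weights: swizzled (blocked) scales, triton kernel choice follows the base config
weight_config = NVFP4FakeQuantizeConfig(
    use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale,
    use_swizzled_scales=True,
    use_triton_kernel=base_config.use_triton_kernel,
)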