@@ -0,0 +1,150 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao import quantize_
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerGroup,
    PerRow,
    PerTensor,
)
from torchao.quantization.utils import compute_error
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.utils import (
    torch_version_at_least,
)


def get_config(granularity):
    return Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=granularity,
        float8_packing_format="opaque",
    )


@common_utils.instantiate_parametrized_tests
class TestFloat8OpaqueTensor(TestCase):
"""Test cases for Float8OpaqueTensor on CPU"""

def setUp(self):
torch.set_grad_enabled(False)

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [1, 160])
@common_utils.parametrize(
"x_granularity",
[PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
)
@common_utils.parametrize(
"w_granularity",
[PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
)
def test_dynamic_float8_linear(
self, dtype, x_dim, bias, bs, x_granularity, w_granularity
):
if isinstance(x_granularity, PerGroup):
if not isinstance(w_granularity, PerGroup):
return
if w_granularity.group_size != x_granularity.group_size:
return
device = "cpu"
m = ToyTwoLinearModel(256, 256, 256, dtype, device, bias).eval()
example_inputs = m.example_inputs(batch_size=bs)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

quantize_(
m,
get_config([x_granularity, w_granularity]),
)
y1 = m(*example_inputs)
assert compute_error(y, y1) > 20
y2, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
assert compute_error(y, y2) > 20

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
@common_utils.parametrize("bias", [True, False])
@common_utils.parametrize("bs", [4, 128])
def test_dynamic_float8_linear_fallback_path(self, dtype, x_dim, bias, bs):
"""
Test the fallback implementation with a shape that is not supported by the optimized kernel
"""
device = "cpu"
m = ToyTwoLinearModel(120, 120, 120, dtype, device, bias).eval()
example_inputs = m.example_inputs(batch_size=bs)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)
y = m(*example_inputs)

quantize_(
m,
get_config(PerRow()),
)
y1 = m(*example_inputs)
assert compute_error(y, y1) > 20
y2, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
assert compute_error(y, y2) > 20

@unittest.skipIf(
"CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
reason="cpp kernels not built",
)
@common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_module_path(self, dtype):
linear = torch.nn.Linear(128, 256, dtype=dtype)
quantize_(linear, get_config(PerRow()))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Float8OpaqueTensor'>",
)

with tempfile.NamedTemporaryFile() as f:
torch.save(linear.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
self.assertEqual(
str(type(state_dict["weight"])),
"<class 'torchao.quantization.Float8OpaqueTensor'>",
)


if __name__ == "__main__":
run_tests()
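
# Minimal usage sketch of the opaque packing format exercised by the tests above
# (an assumption-labeled illustration, not part of the PR: it presumes a CPU-only
# environment where the torchao::float8_linear_cpu C++ kernels are built, and the
# toy model/shapes are hypothetical).
#
# import torch
# from torchao import quantize_
# from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, PerRow
#
# # Any module containing nn.Linear layers is handled the same way.
# model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16)).eval()
# config = Float8DynamicActivationFloat8WeightConfig(
#     activation_dtype=torch.float8_e4m3fn,
#     granularity=PerRow(),
#     float8_packing_format="opaque",  # routes weights to Float8OpaqueTensor (CPU path)
# )
# quantize_(model, config)
# with torch.no_grad():
#     out = model(torch.randn(4, 256, dtype=torch.bfloat16))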
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -92,6 +92,7 @@
    quantize_affine,
)
from .quantize_.workflows import (
    Float8OpaqueTensor,
    Float8Tensor,
    Int4MarlinSparseTensor,
    Int4OpaqueTensor,
@@ -174,6 +175,7 @@
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"Int4OpaqueTensor",
"Float8OpaqueTensor",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
108 changes: 68 additions & 40 deletions torchao/quantization/quant_api.py
@@ -73,6 +73,8 @@
    KernelPreference,
)
from torchao.quantization.quantize_.workflows import (
    Float8OpaqueTensor,
    Float8PackingFormat,
    Float8Tensor,
    Int4ChooseQParamsAlgorithm,
    Int4MarlinSparseTensor,
@@ -1774,14 +1776,23 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
    kernel_preference: KernelPreference = KernelPreference.AUTO
    set_inductor_config: bool = True
    version: int = 2
    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
        )
        activation_granularity, weight_granularity = _normalize_granularity(
            self.granularity
        )
        if (
            self.version == 2
            and self.float8_packing_format == Float8PackingFormat.OPAQUE
        ):
            activation_granularity, weight_granularity = (
                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
            )
        else:
            activation_granularity, weight_granularity = _normalize_granularity(
                self.granularity
            )
        self.granularity = [activation_granularity, weight_granularity]

        default_use_fast_accum = True
@@ -1811,44 +1822,48 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
    activation_value_lb = config.activation_value_lb
    activation_value_ub = config.activation_value_ub
    kernel_preference = config.kernel_preference
    float8_packing_format = config.float8_packing_format

    # Ensure works on device
    _check_hardware_support(granularity)
    activation_granularity, weight_granularity = granularity

    # Note: right now we assume it's weights of conv2d and conv3d purely based
    # on the dimension of weight, currently there is no conflict with linear 2d
    # and moe weights 3d
    # if we need to support conv1d, which also has 3d weight, we may have to
    # pass around the module as well to distinguish between conv1d and 3d moe weight
    if weight.dim() in [4, 5]:
        # weights for conv2d or 3d
        assert isinstance(activation_granularity, PerTensor) and isinstance(
            weight_granularity, PerTensor
        ), "4D/5D tensor only supports per tensor activation and weight quantization"

        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
        # conv2d weight dim: (C_out, C_in, K1, K2)
        # skip quantization when either C_out or C_in
        # is not a multiple of 16
        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
            return weight
    if float8_packing_format == Float8PackingFormat.PLAIN:
        # Note: right now we assume it's weights of conv2d and conv3d purely based
        # on the dimension of weight, currently there is no conflict with linear 2d
        # and moe weights 3d
        # if we need to support conv1d, which also has 3d weight, we may have to
        # pass around the module as well to distinguish between conv1d and 3d moe weight
        if weight.dim() in [4, 5]:
            # weights for conv2d or 3d
            assert isinstance(activation_granularity, PerTensor) and isinstance(
                weight_granularity, PerTensor
            ), (
                "4D/5D tensor only supports per tensor activation and weight quantization"
            )

    elif not _fp8_mm_compat(weight):
        # TODO(future PR): this should really throw an exception instead of silently
        # not doing what the user asked
        return weight
            # conv3d weight dim: (C_out, C_in, K1, K2, K3)
            # conv2d weight dim: (C_out, C_in, K1, K2)
            # skip quantization when either C_out or C_in
            # is not a multiple of 16
            if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
                return weight

    if isinstance(weight_granularity, PerRow):
        assert weight.dtype == torch.bfloat16, (
            "PerRow quantization only works for bfloat16 precision input weight"
        )
        elif not _fp8_mm_compat(weight):
            # TODO(future PR): this should really throw an exception instead of silently
            # not doing what the user asked
            return weight

        if isinstance(weight_granularity, PerRow):
            assert weight.dtype == torch.bfloat16, (
                "PerRow quantization only works for bfloat16 precision input weight"
            )

    if config.version == 1:
        warnings.warn(
            "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
        )

        _check_hardware_support(granularity)
        block_size = get_block_size(weight.shape[-2:], weight_granularity)
        if weight.dim() == 3:
            block_size = tuple([1] + list(block_size))
@@ -1879,14 +1894,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
        kernel_preference=kernel_preference,
    )

    quantized_weight = Float8Tensor.from_hp(
        weight,
        float8_dtype=weight_dtype,
        granularity=weight_granularity,
        mm_config=mm_config,
        kernel_preference=kernel_preference,
        act_quant_kwargs=act_quant_kwargs,
    )
    if float8_packing_format == Float8PackingFormat.PLAIN:
        quantized_weight = Float8Tensor.from_hp(
            weight,
            float8_dtype=weight_dtype,
            granularity=weight_granularity,
            mm_config=mm_config,
            kernel_preference=kernel_preference,
            act_quant_kwargs=act_quant_kwargs,
        )
    elif float8_packing_format == Float8PackingFormat.OPAQUE:
        block_size = get_block_size(weight.shape, weight_granularity)
        quantized_weight = Float8OpaqueTensor.from_hp(
            weight,
            block_size=block_size,
            act_quant_kwargs=act_quant_kwargs,
        )
    else:
        raise ValueError(
            f"Unsupported float8 packing format: {float8_packing_format}"
        )

    return quantized_weight

@@ -1898,9 +1925,10 @@ def _float8_dynamic_activation_float8_weight_transform(
    *,
    parameter_name: str = "weight",
):
    assert is_sm_at_least_89() or is_MI300(), (
        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
    )
    if config.float8_packing_format == Float8PackingFormat.PLAIN:
        assert is_sm_at_least_89() or is_MI300(), (
            "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
        )
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

6 changes: 6 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,7 @@
from .float8.float8_opaque_tensor import (
    Float8OpaqueTensor,
)
from .float8.float8_packing_format import Float8PackingFormat
from .float8.float8_tensor import (
    Float8Tensor,
    QuantizeTensorToFloat8Kwargs,
@@ -37,7 +41,9 @@
"Int4MarlinSparseTensor",
"Int4PlainInt32Tensor",
"Int4TilePackedTo4dTensor",
"Float8OpaqueTensor",
"Float8Tensor",
"Float8PackingFormat",
"QuantizeTensorToFloat8Kwargs",
"Int4OpaqueTensor",
"Int4ChooseQParamsAlgorithm",