Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Algorithms.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- NF4 compression mode
- Arbitrary look-up table (CODEBOOK) or predefined lookup table based on NF4 (CB4_F8E4M3)
- MX-compliant types - MXFP4 and MXFP8_E4M3
- FP types - FP8_E4M3 and FP4
- Mixed precision weights compression
- Grouped weights compression

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ NNCF can automatically distribute precision assignments based on quantization se
| CB4_F8E4M3 | E4M3 | FP16 | Per-channel / Group-wise | A fixed lookup table with 16 E4M3 values based on NF4 values |
| MXFP4 | E2M1 | E8M0 | Group-wise (32) | [MX-compliant FP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |
| MXFP8_E4M3 | E4M3 | E8M0 | Group-wise (32) | [MX-compliant FP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |
| FP8_E4M3 | E4M3 | FP16 | Per-channel / Group-wise | [FP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |
| FP4 | E2M1 | FP16 | Per-channel / Group-wise | [FP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) |

**Note**: Granularity refers to the scope of elements sharing quantization parameters. "Per-channel" applies different parameters for each output channel, while "Group-wise" divides weights into groups (e.g., group_size=128) that share the same parameters.

Expand Down
3 changes: 2 additions & 1 deletion src/nncf/openvino/optimized_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ def do_float_quantization(
:param config: Weight compression configuration.
:param reduction_axes: Axes, along which to reduce (collect) different statistics.
:param precomputed_scale: Optional precomputed scale.
:return: Returns quantized weight tensor and corresponding scale tensor.
:return: Quantized weight tensor (for MXFP8_E4M3, FP4 and FP8_E4M3 modes, the
normalized weight is returned instead) and the corresponding scale tensor.
"""
assert config.mode in [CompressWeightsMode.NF4, CompressWeightsMode.MXFP4]

Expand Down
4 changes: 4 additions & 0 deletions src/nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class CompressWeightsMode(StrEnum):
:param INT8: Mode is deprecated and will be removed in future releases. Please use `INT8_ASYM` instead.
:param MXFP4: MX-compliant FP4 format with E2M1 values sharing group-level E8M0 scale. The size of group is 32.
:param MXFP8_E4M3: MX-compliant FP8 format with E4M3 values sharing group-level E8M0 scale. The size of group is 32.
:param FP8_E4M3: An FP8 format with E4M3 values sharing a group-level FP16 scale.
:param FP4: An FP4 format with E2M1 values sharing a group-level FP16 scale.
:param CODEBOOK: Codebook (LUT) quantization format.
:param CB4_F8E4M3: Codebook (LUT) format with 16 fixed fp8 values in E4M3 format.
"""
Expand All @@ -105,6 +107,8 @@ class CompressWeightsMode(StrEnum):
INT8 = "int8" # Deprecated mode
MXFP4 = "mxfp4"
MXFP8_E4M3 = "mxfp8_e4m3"
FP8_E4M3 = "fp8_e4m3"
FP4 = "fp4"
CODEBOOK = "codebook"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
]
SUPPORTED_DATA_TYPES = [
TensorDataType.float16,
Expand Down Expand Up @@ -300,6 +302,8 @@ def __init__(
NF4 is the same as INT4_SYM mode, but primary precision is NF4 data type without zero point.
MXFP4 is MX-compliant FP4 with E2M1 values sharing group-level E8M0 scale. The size of group is 32.
MXFP8_E4M3 is MX-compliant FP8 with E4M3 values sharing group-level E8M0 scale. The size of group is 32.
FP8_E4M3 is FP8 with E4M3 values sharing group-level FP16 scale.
FP4 is FP4 with E2M1 values sharing group-level FP16 scale.
:param ratio: the ratio between primary and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
and the rest to backup_mode).
:param group_size: number of weights (e.g. 128) in the channel dimension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def is_integer(self):
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
CompressWeightsMode.CODEBOOK,
CompressWeightsMode.CB4_F8E4M3,
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def _create_compression_subgraph(
elif compression_config.mode == CompressWeightsMode.MXFP8_E4M3:
compression_dtype = ov.Type.f8e4m3
scale_dtype = ov.Type.f8e8m0
elif compression_config.mode == CompressWeightsMode.FP8_E4M3:
compression_dtype = ov.Type.f8e4m3
elif compression_config.mode == CompressWeightsMode.FP4:
compression_dtype = ov.Type.f4e2m1
elif compression_config.mode == CompressWeightsMode.INT4_SYM:
compression_dtype = ov.Type.i4
elif compression_config.mode == CompressWeightsMode.INT4_ASYM:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ def transform_model(
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
]:
msg = f"{compression_config.mode.value} is not supported."
raise nncf.ParameterNotSupportedError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def transform_model(
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
]:
msg = f"{compression_config.mode.value} is not supported."
raise nncf.ParameterNotSupportedError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def calculate_float_quantization_params(
weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig
) -> Tensor:
"""
Calculates the scale for nf4 or mxfp4/mxfp8_e4m3 quantization.
Calculates the scale for nf4 or mxfp8_e4m3/mxfp4/fp8_e4m3/fp4 quantization.

:param weight: Weight array to compress.
:param reduction_axes: Axes along which to reduce (collect) different statistics (e.g., min, max).
Expand All @@ -99,6 +99,8 @@ def calculate_float_quantization_params(
FP_MAX_VALS = {
CompressWeightsMode.MXFP4: 6.0,
CompressWeightsMode.MXFP8_E4M3: 448.0,
CompressWeightsMode.FP4: 6.0,
CompressWeightsMode.FP8_E4M3: 448.0,
}
if config.mode in [CompressWeightsMode.CODEBOOK, CompressWeightsMode.CB4_F8E4M3] + list(FP_MAX_VALS.keys()):
if config.mode in FP_MAX_VALS:
Expand Down Expand Up @@ -148,17 +150,16 @@ def do_float_quantization(
) -> tuple[Tensor, Tensor, Tensor]:
"""
Computes quantization scale if not provided,
and performs corresponding (nf4, MXFP4 and MXFP8_E4M3) weight quantization.
and performs corresponding (nf4, MXFP4, MXFP8_E4M3, FP4, FP8_E4M3) weight quantization.
NF4 format uses 16 levels in [-1, 1] range, while MXFP4 uses 16 levels in [-6, 6].
For MXFP8_E4M3 and CODEBOOK currently returns normalized weight without quantization.
For CODEBOOK currently returns normalized weight without quantization.
For MXFP8_E4M3, FP8_E4M3, FP4 and CODEBOOK currently returns normalized weight without quantization.

:param weight: Weight array to compress.
:param config: Weight compression configuration.
:param reduction_axes: Axes, along which to reduce (collect) different statistics.
:param precomputed_scale: Optional precomputed scale.
:return: Returns quantized (for MXFP8_E4M3 and codebook normalized) weight tensor and corresponding scale tensor and
optional indexes for codebook.
:return: Quantized weight tensor (for MXFP8_E4M3, FP4, FP8_E4M3 and codebook modes,
the normalized weight is returned instead), the corresponding scale tensor,
and optional indexes for the codebook.
"""
assert not config.is_integer

Expand Down Expand Up @@ -208,7 +209,7 @@ def float_quantize_dequantize_weight(
) -> Union[Tensor, tuple[Tensor, Tensor, Tensor]]:
"""
First quantizes the given weight tensor to float dtype and then dequantizes it back to obtain float32 values.
MXFP8_E4M3 mode is currently not supported.
MXFP8_E4M3, FP8_E4M3 and FP4 modes are currently not supported.

:param weight: The weight tensor to quantize-dequantize.
:param config: Compression configuration.
Expand Down
56 changes: 39 additions & 17 deletions src/nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,8 @@ def compress_weights(
MXFP4 is MX-compliant FP4 format with E2M1 values sharing group-level E8M0 scale. The size of group is 32.
MXFP8_E4M3 - is MX-compliant FP8 format with E4M3 values sharing a group-level E8M0 scale.
The size of group is 32.
FP8_E4M3 - is FP8 format with E4M3 values sharing a group-level FP16 scale.
FP4 - is FP4 format with E2M1 values sharing a group-level FP16 scale.
:type mode: nncf.CompressWeightsMode
:param ratio: the ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to NF4
and the rest to INT8_ASYM).
Expand Down Expand Up @@ -517,14 +519,19 @@ def compress_weights(
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl

if mode in [
not_supported_modes = [
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
CompressWeightsMode.CODEBOOK,
CompressWeightsMode.CB4_F8E4M3,
]:
msg = "Torch backend does not support NF4, MXFP4, MXFP8_E4M3 and CODEBOOK modes for weight compression."
]
if mode in not_supported_modes:
msg = (
f"Torch backend does not support {[m.value for m in not_supported_modes]} modes for weight compression."
)
raise nncf.ParameterNotSupportedError(msg)

options = {"gptq": gptq, "lora_correction": lora_correction}
Expand Down Expand Up @@ -567,14 +574,19 @@ def compress_weights(
compress_weights_impl as fx_compression_weights_impl,
)

if mode in [
not_supported_modes = [
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
CompressWeightsMode.CODEBOOK,
CompressWeightsMode.CB4_F8E4M3,
]:
msg = "Torch backend does not support NF4, MXFP4, MXFP8_E4M3 and CODEBOOK modes for weight compression."
]
if mode in not_supported_modes:
msg = (
f"Torch backend does not support {[m.value for m in not_supported_modes]} modes for weight compression."
)
raise nncf.ParameterNotSupportedError(msg)

options = {
Expand Down Expand Up @@ -610,14 +622,19 @@ def compress_weights(
msg = "Scale estimation, GPTQ or Lora Correction algorithm is defined, but dataset is None."
raise nncf.ParameterNotSupportedError(msg)

if any((awq, scale_estimation, gptq, lora_correction)) and mode in [
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
]:
msg = (
"AWQ, Scale estimation, GPTQ or Lora Correction algorithm is defined, but mode in [MXFP4, MXFP8_E4M3]."
)
raise nncf.ParameterNotSupportedError(msg)
if any((awq, scale_estimation, gptq, lora_correction)):
not_supported_modes = [
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
]
if mode in not_supported_modes:
msg = (
"AWQ, Scale estimation, GPTQ or Lora Correction algorithm is defined,"
f" but mode in {[m.value for m in not_supported_modes]}."
)
raise nncf.ParameterNotSupportedError(msg)

if gptq and lora_correction:
msg = "Simultaneous use of Lora correction and GPTQ algorithms is not supported. Select one of them."
Expand All @@ -632,14 +649,19 @@ def compress_weights(
elif backend == BackendType.ONNX:
from nncf.onnx.quantization.quantize_model import compress_weights_impl as onnx_compress_weights_impl

if mode in [
not_supported_modes = [
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
CompressWeightsMode.CODEBOOK,
CompressWeightsMode.CB4_F8E4M3,
]:
msg = "ONNX backend does not support NF4, MXFP4, MXFP8_E4M3 and CODEBOOK modes for weight compression."
]
if mode in not_supported_modes:
msg = (
f"ONNX backend does not support {[m.value for m in not_supported_modes]} modes for weight compression."
)
raise nncf.ParameterNotSupportedError(msg)

options = {
Expand Down
67 changes: 66 additions & 1 deletion tests/openvino/native/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def test_call_gptq_with_dataset_scale_estimation_neg_group_size(mode):
(CompressWeightsMode.MXFP4, ov.Type.f4e2m1),
],
)
def test_mixed_precision_fp(sensitivity_metric, all_layers, ratio, ref_ids, mode, ov_type, group_size):
def test_mixed_precision_mxfp(sensitivity_metric, all_layers, ratio, ref_ids, mode, ov_type, group_size):
# Use hidden dim % 32 == 0 to make it possible to quantize in MX format
model = SequentialMatmulModel(mm_hidden_dim=32).ov_model
dataset = Dataset([np.ones([1, 4, 32]), np.arange(128).reshape(1, 4, 32)])
Expand Down Expand Up @@ -1179,6 +1179,71 @@ def test_mixed_precision_fp(sensitivity_metric, all_layers, ratio, ref_ids, mode
assert ref_e8m0_nodes == names_e8m0


@pytest.mark.parametrize(
    ("sensitivity_metric", "all_layers", "ratio", "ref_ids", "group_size"),
    (
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, True, 1, [0, 1, 2, 3, 4], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, True, 0.8, [0, 1, 2], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, True, 0.4, [0], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, True, 0.2, [], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, False, 1, [0, 1, 2, 3], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, False, 0.8, [0, 1, 2], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, False, 0.4, [0], None),
        (SensitivityMetric.WEIGHT_QUANTIZATION_ERROR, False, 0.2, [], None),
        (SensitivityMetric.HESSIAN_INPUT_ACTIVATION, True, 0.8, [0, 1, 2], None),
        (SensitivityMetric.HESSIAN_INPUT_ACTIVATION, False, 0.8, [0, 1, 2], None),
        (SensitivityMetric.MEAN_ACTIVATION_VARIANCE, True, 0.8, [0, 1, 2], None),
        (SensitivityMetric.MEAN_ACTIVATION_VARIANCE, False, 0.8, [0, 1, 2], None),
        (SensitivityMetric.MAX_ACTIVATION_VARIANCE, True, 0.8, [0, 1, 2], None),
        (SensitivityMetric.MAX_ACTIVATION_VARIANCE, False, 0.8, [0, 1, 2], None),
        (SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE, True, 0.8, [0, 1, 2], None),
        (SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE, False, 0.8, [0, 1, 2], None),
        # One test to check manual group size setup is working as expected
        (SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE, False, 0.8, [0, 1, 2], 128),
    ),
)
@pytest.mark.parametrize(
    "mode, ov_type",
    [
        (CompressWeightsMode.FP8_E4M3, ov.Type.f8e4m3),
        (CompressWeightsMode.FP4, ov.Type.f4e2m1),
    ],
)
def test_mixed_precision_fp(sensitivity_metric, all_layers, ratio, ref_ids, mode, ov_type, group_size):
    """Checks mixed-precision compression with the FP8_E4M3 / FP4 primary modes.

    Verifies that exactly the layers in ``ref_ids`` are compressed to the primary
    (float) precision, that the effective group size is 128, and that every layer
    gets an FP16 scale constant.
    """
    model = SequentialMatmulModel(mm_hidden_dim=128).ov_model
    dataset = Dataset([np.ones([1, 4, 128]), np.arange(512).reshape(1, 4, 128)])

    # Pass group_size only when the test case sets it explicitly; otherwise the
    # default group size resolution is exercised.
    extra_kwargs = {} if group_size is None else {"group_size": group_size}
    compressed_model = compress_weights(
        model,
        mode=mode,
        ratio=ratio,
        all_layers=all_layers,
        sensitivity_metric=sensitivity_metric,
        dataset=dataset,
        **extra_kwargs,
    )

    fp_ops = [node for node in compressed_model.get_ordered_ops() if node.get_element_type() == ov_type]
    for fp_op in fp_ops:
        # Check effective default group size == 128
        assert tuple(fp_op.shape) == (128, 1, 128)

    expected_fp_names = {f"weights_{i}" for i in ref_ids}
    assert expected_fp_names == {node.get_friendly_name() for node in fp_ops}

    scale_names = set()
    for node in compressed_model.get_ordered_ops():
        if node.get_element_type() == ov.Type.f16 and "scale" in node.get_friendly_name():
            scale_names.add(node.get_friendly_name())
    expected_scale_names = {f"weights_{i}/scale" for i in range(5)}
    assert expected_scale_names == scale_names


@pytest.mark.parametrize(
("mode", "all_layers", "ratio", "ref_ids"),
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
CompressWeightsMode.NF4,
CompressWeightsMode.MXFP4,
CompressWeightsMode.MXFP8_E4M3,
CompressWeightsMode.FP8_E4M3,
CompressWeightsMode.FP4,
)


Expand Down