Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions examples/post_training_quantization/onnx/mobilenet_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,16 @@ def run_benchmark(path_to_model: Path, shape: list[int]) -> float:
return float(match.group(1))


def get_model_size(path: Path, m_type: str = "Mb") -> float:
    """
    Returns the size of the file at `path` converted to the requested unit
    and prints it.

    :param path: Path to the model file.
    :param m_type: Target unit, one of "bytes", "Kb" or "Mb".
    :return: File size expressed in `m_type` units.
    :raises ValueError: If `m_type` is not a supported unit.
    """
    units = ["bytes", "Kb", "Mb"]
    if m_type not in units:
        msg = f"Unknown size unit: {m_type}. Expected one of {units}."
        raise ValueError(msg)
    # Each step up the unit list divides by 1024 (bytes -> Kb -> Mb).
    model_size = path.stat().st_size / 1024 ** units.index(m_type)
    # Report the size in the unit actually requested (was hard-coded to "Mb").
    print(f"Model size: {model_size:.3f} {m_type}")
    return model_size


###############################################################################
# Create an ONNX model and dataset

Expand Down Expand Up @@ -134,10 +144,12 @@ def transform_fn(data_item):
fp32_model_path = ROOT / "mobilenet_v2_fp32.onnx"
onnx.save(model, fp32_model_path)
print(f"[1/7] Save FP32 model: {fp32_model_path}")
fp32_model_size = get_model_size(fp32_model_path)

int8_model_path = ROOT / "mobilenet_v2_int8.onnx"
onnx.save(onnx_quantized_model, int8_model_path)
print(f"[2/7] Save INT8 model: {int8_model_path}")
int8_model_size = get_model_size(int8_model_path)

print("[3/7] Benchmark FP32 model:")
fp32_fps = run_benchmark(fp32_model_path, shape=[1, 3, 224, 224])
Expand All @@ -154,5 +166,6 @@ def transform_fn(data_item):

print("[7/7] Report:")
print(f"Accuracy drop: {fp32_top1 - int8_top1:.3f}")
print(f"Model compression rate: {fp32_model_size / int8_model_size:.3f}")
# https://docs.openvino.ai/latest/openvino_docs_optimization_guide_dldt_optimization_guide.html
print(f"Performance speed up (throughput mode): {int8_fps / fp32_fps:.3f}")
27 changes: 26 additions & 1 deletion src/nncf/onnx/graph/onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import Iterator, Optional, Union
from typing import Any, Iterator, Optional, Union

import numpy as np
import onnx
Expand Down Expand Up @@ -373,3 +373,28 @@ def pack_int4_to_uint8(weight: np.ndarray, block_size: int, signed: bool) -> np.
packed_weight = packed.transpose(2, 0, 1)

return packed_weight


def get_node_attr_value(node: onnx.NodeProto, attr_name: str) -> Optional[Any]:
    """
    Looks up an attribute by name on the given node and returns its value.

    :param node: Node whose attributes are searched.
    :param attr_name: Name of the attribute to look up.
    :return: The attribute value if exactly one match exists; `None` when no
        attribute with that name is present.
    :raises ValueError: If the node carries more than one attribute with that name.
    """
    found = [attr for attr in node.attribute if attr.name == attr_name]

    if not found:
        return None

    if len(found) > 1:
        msg = f"Node has multiple attributes with name {attr_name}."
        raise ValueError(msg)

    return onnx.helper.get_attribute_value(found[0])
60 changes: 60 additions & 0 deletions src/nncf/onnx/graph/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# limitations under the License.

import onnx
from onnx.reference.ops import load_op

from nncf.onnx.graph.onnx_helper import get_children
from nncf.onnx.graph.onnx_helper import get_children_node_mapping
from nncf.onnx.graph.onnx_helper import get_node_attr_value


def eliminate_nop_cast(model: onnx.ModelProto) -> onnx.ModelProto:
Expand Down Expand Up @@ -77,3 +79,61 @@ def apply_preprocess_passes(model: onnx.ModelProto) -> onnx.ModelProto:
# Otherwise, not all no-op Cast nodes will be found.
preprocessed_model = eliminate_nop_cast(preprocessed_model)
return preprocessed_model


def compress_quantize_weights_transformation(model: onnx.ModelProto):
    """
    Transforms the model by folding `QuantizeLinear` nodes with constant inputs
    into precomputed, quantized initializers.

    This transformation finds `QuantizeLinear` nodes with constant inputs
    (i.e., inputs present in the model's initializers), precomputes their quantized values,
    updates the initializer with these results, and removes the corresponding
    `QuantizeLinear` nodes from the graph.

    :param model: The model to be transformed. The model is modified in place.
    """
    # Name -> TensorProto map of the graph's constant tensors; the protos are
    # mutated in place below via CopyFrom().
    initializer = {x.name: x for x in model.graph.initializer}
    nodes_to_remove = []

    # Use at least the opset-19 reference implementation of QuantizeLinear.
    # NOTE(review): assumes opset_import[0] is the default ONNX domain — TODO confirm.
    version = max(model.opset_import[0].version, 19)
    QuantizeLinear = load_op("", "QuantizeLinear", version)

    for node in model.graph.node:
        if node.op_type != "QuantizeLinear":
            continue

        x_name, y_scale_name = node.input[:2]
        # `y_zero_point` is an optional input for the `QuantizeLinear` operation.
        y_zero_point_name = node.input[2] if len(node.input) > 2 else None

        # Only fold nodes whose data input is a constant (an initializer);
        # activation quantizers are left untouched.
        if x_name not in initializer:
            continue

        nodes_to_remove.append(node)

        # Quantize
        x = onnx.numpy_helper.to_array(initializer[x_name])
        y_scale = onnx.numpy_helper.to_array(initializer[y_scale_name])

        y_zero_point = None
        if y_zero_point_name:
            y_zero_point = onnx.numpy_helper.to_array(initializer[y_zero_point_name])

        axis = get_node_attr_value(node, "axis")
        if version < 21:
            # onnx.reference.ops.op_quantize_linear.QuantizeLinear_19
            y = QuantizeLinear.eval(x, y_scale, y_zero_point, axis=axis)
        else:
            # onnx.reference.ops.op_quantize_linear.QuantizeLinear_21
            # Opset 21 added blocked quantization via the `block_size` attribute.
            block_size = get_node_attr_value(node, "block_size")
            y = QuantizeLinear.eval(x, y_scale, y_zero_point, axis=axis, block_size=block_size)

        # Update an existing initializer. The new name is the name of the `QuantizeLinear` output.
        # Renaming it to the node's output means downstream consumers of that
        # output now read the precomputed quantized tensor directly.
        tensor_proto = onnx.numpy_helper.from_array(y, name=node.output[0])
        initializer[x_name].CopyFrom(tensor_proto)

    # `QuantizeLinear` and `DequantizeLinear` nodes share initializers on ports 1 and 2,
    # so these initializers should not be removed.
    for x in nodes_to_remove:
        model.graph.node.remove(x)
17 changes: 17 additions & 0 deletions src/nncf/onnx/quantization/backend_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,28 @@ class BackendParameters:
"""
:param EXTERNAL_DATA_DIR: An absolute path to the directory where the external data
files are stored. All external data files must be located in the same folder.
:param COMPRESS_WEIGHTS: If `True` compresses constant quantized weights by folding
`QuantizeLinear` nodes into pre-quantized initializers. If `False`, this transformation
is skipped.
"""

COMPRESS_WEIGHTS = "compress_weights"
EXTERNAL_DATA_DIR = "external_data_dir"


def is_weight_compression_needed(advanced_parameters: Optional[AdvancedQuantizationParameters]) -> bool:
    """
    Determines whether weight compression is needed based on the provided
    advanced quantization parameters.

    :param advanced_parameters: Advanced quantization parameters.
    :return: `True` if weight compression is needed, `False` otherwise.
    """
    if advanced_parameters is None or advanced_parameters.backend_params is None:
        # Compression is enabled by default when no explicit setting is given.
        return True
    return advanced_parameters.backend_params.get(BackendParameters.COMPRESS_WEIGHTS, True)


def get_external_data_dir(
advanced_parameters: Optional[Union[AdvancedQuantizationParameters, AdvancedCompressionParameters]],
) -> Optional[str]:
Expand Down
21 changes: 18 additions & 3 deletions src/nncf/onnx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import sys
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

Expand All @@ -30,7 +31,10 @@
from nncf.onnx.graph.model_metadata import set_metadata
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.onnx.graph.passes import apply_preprocess_passes
from nncf.onnx.graph.passes import compress_quantize_weights_transformation
from nncf.onnx.quantization.backend_parameters import BackendParameters
from nncf.onnx.quantization.backend_parameters import get_external_data_dir
from nncf.onnx.quantization.backend_parameters import is_weight_compression_needed
from nncf.parameters import BackupMode
from nncf.parameters import CompressionFormat
from nncf.parameters import CompressWeightsMode
Expand Down Expand Up @@ -177,6 +181,9 @@ def quantize_impl(
remove_metadata(quantized_model, MetadataKey.EXTERNAL_DATA_DIR)
load_external_data_for_model(quantized_model, external_data_dir)

if is_weight_compression_needed(advanced_parameters):
compress_quantize_weights_transformation(quantized_model)

return quantized_model


Expand All @@ -202,8 +209,13 @@ def quantize_with_accuracy_control_impl(
if advanced_accuracy_restorer_parameters is None:
advanced_accuracy_restorer_parameters = AdvancedAccuracyRestorerParameters()

compress_weights = is_weight_compression_needed(advanced_quantization_parameters)

if advanced_quantization_parameters is None:
advanced_quantization_parameters = AdvancedQuantizationParameters()
copied_parameters = AdvancedQuantizationParameters()
else:
copied_parameters = deepcopy(advanced_quantization_parameters)
copied_parameters.backend_params[BackendParameters.COMPRESS_WEIGHTS] = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should we update this parameter here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to disable COMPRESS_WEIGHTS here to properly remove Quantize-Dequantize pairs during the quantize_with_accuracy_control() pipeline. For reference, we do the same for the OpenVINO backend.

copied_parameters.backend_params[BackendParameters.COMPRESS_WEIGHTS] = False


quantized_model = quantize_impl(
model=model,
Expand All @@ -214,7 +226,7 @@ def quantize_with_accuracy_control_impl(
fast_bias_correction=fast_bias_correction,
model_type=model_type,
ignored_scope=ignored_scope,
advanced_parameters=advanced_quantization_parameters,
advanced_parameters=copied_parameters,
)

if advanced_accuracy_restorer_parameters.intermediate_model_dir:
Expand Down Expand Up @@ -254,7 +266,7 @@ def quantize_with_accuracy_control_impl(
fast_bias_correction,
model_type,
ignored_scope,
advanced_quantization_parameters,
copied_parameters,
)
tuned_quantized_metric_results = evaluator.collect_metric_results(
tuned_quantized_model, validation_dataset, model_name="tuned"
Expand Down Expand Up @@ -292,6 +304,9 @@ def quantize_with_accuracy_control_impl(
evaluator,
)

if compress_weights:
compress_quantize_weights_transformation(quantized_model)

return quantized_model


Expand Down
5 changes: 5 additions & 0 deletions tests/cross_fw/examples/example_scope.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
"fp32_fps": 1732.8,
"int8_fps": 6013.68,
"performance_speed_up": 3.470498614958449
},
"model_size_metrics": {
"fp32_model_size": 9.01850414276123,
"int8_model_size": 2.9374523162841797,
"model_compression_rate": 3.0701789073360906
}
},
"post_training_quantization_openvino_mobilenet_v2_quantize": {
Expand Down
46 changes: 46 additions & 0 deletions tests/onnx/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import onnx

import nncf
from nncf.onnx.graph.passes import apply_preprocess_passes
from nncf.onnx.graph.passes import compress_quantize_weights_transformation
from nncf.onnx.quantization.backend_parameters import BackendParameters
from tests.onnx.common import ModelBuilder
from tests.onnx.models import build_matmul_model_with_nop_cast


Expand All @@ -21,3 +27,43 @@ def test_apply_preprocess_passes():

assert set(after_nodes) - set(before_nodes) == set()
assert set(before_nodes) - set(after_nodes) == set(["cast"])


def _build_model():
    """Builds a small FP32 model with a single Gemm node (X @ W + B)."""
    weight = np.array(
        [[0.1, 0.3, 0.2, -0.1], [-0.9, 0.1, 0.5, -0.3], [0.0, -0.1, -0.4, -0.9]],
        dtype=np.float32,
    )
    bias = np.full((4,), 0.1, dtype=np.float32)

    builder = ModelBuilder()
    tensor = builder.add_input("X", (2, 3))
    tensor = builder.add_gemm(tensor, weight.shape, weight_data=weight, bias_data=bias)
    builder.add_output(tensor, (2, 4))
    return builder.build(opset_version=19, ir_version=9)


def check_operation_count(model: onnx.ModelProto, op_type_to_count: dict[str, int]):
    """
    Asserts that the model contains exactly the expected number of nodes for
    each listed op type. Op types absent from the graph are not counted, so an
    expectation of zero occurrences will fail.

    :param model: Model whose graph nodes are counted.
    :param op_type_to_count: Expected op type -> occurrence count mapping.
    """
    observed: dict[str, int] = {}
    for graph_node in model.graph.node:
        op = graph_node.op_type
        if op in op_type_to_count:
            observed[op] = observed.get(op, 0) + 1
    assert observed == op_type_to_count


def test_compress_quantize_weights_transformation():
    """
    Checks that the transformation folds constant-input `QuantizeLinear`
    nodes into precomputed initializers while leaving activation quantizers
    (and all `DequantizeLinear` nodes) in place.
    """
    model = _build_model()

    x = np.array([[0.2, -0.1, 0.9], [-0.1, -0.9, 0.5]], dtype=np.float32)

    # Prepare quantized model without weight compression
    calibration_dataset = nncf.Dataset([{"X": x}])
    quantized_model = nncf.quantize(
        model,
        calibration_dataset,
        advanced_parameters=nncf.AdvancedQuantizationParameters(
            backend_params={BackendParameters.COMPRESS_WEIGHTS: False}
        ),
    )

    # One QuantizeLinear each for the weight and the activation before folding.
    check_operation_count(quantized_model, {"QuantizeLinear": 2, "DequantizeLinear": 2})
    compress_quantize_weights_transformation(quantized_model)
    # The weight QuantizeLinear is folded away; the activation one remains.
    check_operation_count(quantized_model, {"QuantizeLinear": 1, "DequantizeLinear": 2})
15 changes: 15 additions & 0 deletions tests/post_training/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ class RunInfo:
time_compression: Optional[float] = None
num_compress_nodes: Optional[NumCompressNodes] = None
stats_from_output = StatsFromOutput()
fp32_model_size: Optional[float] = None
int8_model_size: Optional[float] = None

@staticmethod
def format_time(time_elapsed):
Expand Down Expand Up @@ -212,6 +214,9 @@ def get_result_dict(self) -> dict[str, str]:
"Metric name": self.metric_name,
"Metric value": self.metric_value,
"Metric diff": self.metric_diff,
"Model size Mb (FP32)": self.fp32_model_size,
"Model size Mb (INT8)": self.int8_model_size,
"Compression rate:": self.fp32_model_size / self.int8_model_size,
**self.num_compress_nodes.get_data(),
"Compr. time": self.format_time(self.time_compression),
**self.stats_from_output.get_stats(),
Expand All @@ -224,6 +229,15 @@ def get_result_dict(self) -> dict[str, str]:
return result


def get_model_size(path: Path, m_type: str = "Mb") -> float:
    """
    Returns the size of the file at `path` converted to the requested unit.

    :param path: Path to the file to measure.
    :param m_type: Target unit, one of "bytes", "Kb" or "Mb".
    :return: File size expressed in `m_type` units.
    :raises ValueError: If `m_type` is not a supported unit.
    """
    units = ["bytes", "Kb", "Mb"]
    if m_type not in units:
        msg = f"Unknown size unit: {m_type}. Expected one of {units}."
        raise ValueError(msg)
    # Each step up the unit list divides by 1024 (bytes -> Kb -> Mb).
    return path.stat().st_size / 1024 ** units.index(m_type)


class BaseTestPipeline(ABC):
"""
Base class to test compression algorithms.
Expand Down Expand Up @@ -534,6 +548,7 @@ def save_compressed_model(self) -> None:
elif self.backend == BackendType.ONNX:
onnx_path = self.output_model_dir / "model.onnx"
onnx.save(self.compressed_model, str(onnx_path))
self.run_info.int8_model_size = get_model_size(onnx_path)
ov_model = ov.convert_model(onnx_path)
ov.serialize(ov_model, self.path_compressed_ir)
elif self.backend in OV_BACKENDS:
Expand Down
2 changes: 2 additions & 0 deletions tests/post_training/pipelines/image_classification_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tests.post_training.pipelines.base import OV_BACKENDS
from tests.post_training.pipelines.base import PT_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.base import get_model_size
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase

# Disable using aten::scaled_dot_product_attention
Expand Down Expand Up @@ -50,6 +51,7 @@ def prepare_model(self) -> None:
torch.onnx.export(
timm_model, self.dummy_tensor, onnx_path, export_params=True, opset_version=13, **additional_kwargs
)
self.run_info.fp32_model_size = get_model_size(onnx_path)
self.model = onnx.load(onnx_path)
self.input_name = self.model.graph.input[0].name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tests.post_training.pipelines.base import FX_BACKENDS
from tests.post_training.pipelines.base import PT_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.base import get_model_size
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase


Expand Down Expand Up @@ -97,6 +98,7 @@ def prepare_model(self) -> None:
torch.onnx.export(
model, self.dummy_tensor, onnx_path, export_params=True, opset_version=13, **additional_kwargs
)
self.run_info.fp32_model_size = get_model_size(onnx_path)
self.model = onnx.load(onnx_path)
self.input_name = self.model.graph.input[0].name

Expand Down