Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] AWQ support for WeightCompression #3279

Open
wants to merge 14 commits into
base: develop
Choose a base branch
from
30 changes: 13 additions & 17 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,16 @@ def __init__(
criterion_cls = MIXED_PRECISION_CRITERIA.get(self._sensitivity_metric)
self._mixed_precision_algo = criterion_cls(primary_config, self._ratio, self._subset_size)
self._statistics_path = self._advanced_parameters.statistics_path

if self._awq:
awq_params = self._advanced_parameters.awq_params
self.awq_algo = AWQ(
awq_params.subset_size,
awq_params.percent_to_apply,
awq_params.alpha_min,
awq_params.alpha_max,
awq_params.steps,
)
if self._gptq:
gptq_params = self._advanced_parameters.gptq_params
self._gptq_algo = GPTQ(
Expand Down Expand Up @@ -586,26 +596,12 @@ def apply(
nodes_to_compress = list(
filter(lambda node: node.node_name not in nodes_names_to_exclude, nodes_to_compress)
)

if self._awq:
awq_params = self._advanced_parameters.awq_params
awq_algo = AWQ(
model,
self._backend_entity.name_to_node_mapping,
all_weight_params,
nodes_to_compress,
statistics,
awq_params.subset_size,
awq_params.percent_to_apply,
awq_params.alpha_min,
awq_params.alpha_max,
awq_params.steps,
)
awq_algo.apply(model, graph)
self.awq_algo.apply(model, graph, all_weight_params, nodes_to_compress, statistics, self._backend_entity)
# After applying AWQ we need to update statistics since AWQ alters the activations
statistics = awq_algo.update_statistics(statistics)
statistics = self.awq_algo.update_statistics(statistics)
# del is used to prematurely mark non-necessary data as free for garbage collection
del awq_algo
del self.awq_algo

scales = {}
zero_points = {}
Expand Down
68 changes: 32 additions & 36 deletions nncf/quantization/algorithms/weight_compression/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@

from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, TypeVar
from typing import Dict, List, Optional, TypeVar

import nncf
from nncf import Dataset
from nncf import nncf_logger
from nncf.common.factory import ModelTransformerFactory
from nncf.common.graph.graph import NNCFGraph
Expand All @@ -29,6 +28,7 @@
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
Expand Down Expand Up @@ -61,34 +61,20 @@ class AWQ(Algorithm):

def __init__(
self,
model: TModel,
name_to_node_mapping: Dict[str, Any],
all_weight_params: List[WeightCompressionParameters],
nodes_to_compress: List[NNCFNode],
statistics: Dict[str, WCTensorStatistic],
subset_size: int = 32,
percent_to_apply=0.002,
alpha_min=0.0,
alpha_max=1.0,
steps=100,
percent_to_apply: float = 0.002,
alpha_min: float = 0.0,
alpha_max: float = 1.0,
steps: int = 100,
):
"""
:param model: Model for applying algorithm.
:param name_to_node_mapping: Name to node mapping for updating node weights.
:param all_weight_params: List of all weight parameters.
:param nodes_to_compress: List of nodes for processing.
:param statistics: Input activation statistics for each node.
:param subset_size: The number of samples for AWQ.
:param percent_to_apply: The percent of outliers for correction.
:param alpha_min: Minimum value of smoothness parameter for grid search.
:param alpha_max: Maximal value of smoothness parameter for grid search.
:param steps: The number of the steps in grid search.
"""
super().__init__()
self.name_to_node_mapping = name_to_node_mapping
self._all_weight_params = all_weight_params
self._nodes_to_compress = nodes_to_compress
self._statistics = statistics
self._subset_size = subset_size
self._percent_to_apply = percent_to_apply
self._alpha_min = alpha_min
Expand All @@ -98,45 +84,55 @@ def __init__(
self._patterns = None
self._scale_per_target_node = {}

self._set_backend_entity(model)

@property
def available_backends(self) -> List[BackendType]:
return [BackendType.OPENVINO]
return [BackendType.OPENVINO, BackendType.TORCH]

def _set_backend_entity(self, model: TModel) -> None:
def _set_backend_entity(
self, model: TModel, wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None
) -> None:
"""
Creates a helper class with a backed-specific logic of the algorithm.

:param model: Backend-specific input model.
:param wc_backend_entity: Weight compression algorithm backend.
"""

model_backend = get_backend(model)
if model_backend == BackendType.OPENVINO:
from nncf.quantization.algorithms.weight_compression.openvino_backend import OVAWQAlgoAlgoBackend

self._backend_entity = OVAWQAlgoAlgoBackend(model, self.name_to_node_mapping)
self._patterns = self._backend_entity.get_awq_patterns()
self._backend_entity = OVAWQAlgoAlgoBackend(model, wc_backend_entity.name_to_node_mapping)
elif model_backend == BackendType.TORCH:
from nncf.quantization.algorithms.weight_compression.torch_backend import PTAWQAlgoAlgoBackend

self._backend_entity = PTAWQAlgoAlgoBackend()

else:
msg = f"Cannot return backend-specific AWQ entity because {model_backend.value} is not supported!"
raise nncf.UnsupportedBackendError(msg)
self._patterns = self._backend_entity.get_awq_patterns()

def apply(
self,
model: TModel,
graph: NNCFGraph,
statistic_points: Optional[StatisticPointsContainer] = None,
dataset: Optional[Dataset] = None,
all_weight_params: List[WeightCompressionParameters],
nodes_to_compress: List[NNCFNode],
statistics: Dict[str, WCTensorStatistic],
wc_backend_entity: Optional[WeightCompressionAlgoBackend] = None,
) -> TModel:
"""
Applies the algorithm to the model.

:param model: Model for applying algorithm.
:param graph: Model graph.
:param statistic_points: Statistic points with collected statistics values.
:param dataset: A representative dataset for the calibration process.
:param all_weight_params: List of all weight parameters.
:param nodes_to_compress: List of nodes for processing.
:param statistics: Input activation statistics for each node.
:param wc_backend_entity: Weight compression algorithm backend.
:return: A resulting model.
"""
self._set_backend_entity(model, wc_backend_entity)
matches = []

inference_nncf_graph = transform_to_inference_graph(deepcopy(graph), [], [], [], [])
Expand All @@ -152,7 +148,7 @@ def apply(
model_transformer = ModelTransformerFactory.create(model, inplace=True)

awq_data = {}
name_mapping = {wp.weight_name: idx for idx, wp in enumerate(self._all_weight_params)}
name_mapping = {wp.weight_name: idx for idx, wp in enumerate(all_weight_params)}

for match in matches:
nncf_node = graph.get_node_by_key(match[-1])
Expand All @@ -167,11 +163,11 @@ def apply(
if target_node_names[-1] not in name_mapping:
continue

weight_params = self._all_weight_params[name_mapping[target_node_names[-1]]]
weight_params = all_weight_params[name_mapping[target_node_names[-1]]]

if weight_params.compression_config.num_bits != 4:
continue
target_node = self._nodes_to_compress[name_mapping[target_node_names[-1]]]
target_node = nodes_to_compress[name_mapping[target_node_names[-1]]]

# avoid matching different patterns for the same node
if target_node.node_name in awq_data:
Expand All @@ -183,7 +179,7 @@ def apply(
merge_node_names = []
for weight_op_friendly_name, _ in self._backend_entity.get_weight_names_and_port_ids(nncf_node, graph):
merge_node_names.append(weight_op_friendly_name)
merge_node = self._nodes_to_compress[name_mapping[merge_node_names[-1]]]
merge_node = nodes_to_compress[name_mapping[merge_node_names[-1]]]
else: # pattern Act->MatMul or Act->Multiply->MatMul
merge_node = nncf_node

Expand All @@ -205,7 +201,7 @@ def apply(

config = wp.compression_config

s, X = process_stats(self._statistics[k], self._subset_size)
s, X = process_stats(statistics[k], self._subset_size)

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.patterns import GraphPattern
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
Expand All @@ -35,6 +36,9 @@
from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns
from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend
from nncf.quantization.algorithms.weight_compression.backend import MixedPrecisionAlgoBackend
from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
Expand All @@ -44,6 +48,8 @@
from nncf.tensor.definitions import TensorDataType
from nncf.torch.dynamic_graph.scope import Scope
from nncf.torch.graph import operator_metatypes as om
from nncf.torch.graph.operator_metatypes import PTMulMetatype
from nncf.torch.graph.pattern_operations import ATOMIC_ACTIVATIONS_OPERATIONS
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import find_const_node_in_constant_subgraph
Expand All @@ -52,6 +58,7 @@
from nncf.torch.model_graph_manager import get_module_by_name
from nncf.torch.model_graph_manager import split_const_name
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.model_transformer import update_parameter
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor
from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor
Expand Down Expand Up @@ -202,12 +209,12 @@ def get_weight_shape(node_with_weight: NNCFNode, weight_port_id: int, graph: NNC
def set_weight(
self, node_with_weight: NNCFNode, weight_port_id: int, model: torch.nn.Module, graph: NNCFGraph, weight: Tensor
):
pass
update_parameter(node_with_weight.node_name, "weight", weight.data, model)

def insert_adapters(
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
) -> None:
pass
raise NotImplementedError()

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int, algorithm_key: str) -> Callable[[StatisticPoint], bool]:
Expand Down Expand Up @@ -320,6 +327,39 @@ def transform_model(
return transformed_model


class PTAWQAlgoAlgoBackend(AWQAlgoBackend, PTWeightCompressionAlgoBackend):
@staticmethod
def get_awq_patterns():
return get_awq_patterns(
PTWeightCompressionAlgoBackend.MATMUL_METATYPES,
PTMulMetatype,
ATOMIC_ACTIVATIONS_OPERATIONS[GraphPattern.METATYPE_ATTR],
)

@staticmethod
def scale_insertion_command(
source_node: NNCFNode,
next_nodes,
source_output_port_id: int,
scale: torch.Tensor,
) -> PTSharedFnInsertionCommand:
input_port_id = 0
target_points = []
for node in next_nodes:
target_points.append(
PTTargetPoint(
PTWeightCompressionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[TargetType.PRE_LAYER_OPERATION],
node.node_name,
input_port_id=input_port_id,
)
)

sq_multiply = SQMultiply(scale.shape)
sq_multiply.scale = scale
scale_node_name = f"{source_node.node_name}/awq_mul"
return PTSharedFnInsertionCommand(target_points, sq_multiply, scale_node_name)


class PTMixedPrecisionAlgoBackend(MixedPrecisionAlgoBackend, PTWeightCompressionAlgoBackend):
@staticmethod
def mean_variance_statistic_collector(
Expand Down
6 changes: 1 addition & 5 deletions nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,11 +518,7 @@ def compress_weights(
msg = "Torch backend does not support NF4 and E2M1 modes for weight compression."
raise nncf.ParameterNotSupportedError(msg)

options = {
"awq": awq,
"gptq": gptq,
"lora_correction": lora_correction,
}
options = {"gptq": gptq, "lora_correction": lora_correction}
unsupported_options = [name for name, value in options.items() if value is not None]
if unsupported_options:
msg = f"Torch backend does not support {', '.join(unsupported_options)} option(s). Set them to None."
Expand Down
3 changes: 1 addition & 2 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch import nn
from torch.nn.parameter import Parameter

Expand Down Expand Up @@ -242,7 +241,7 @@ def _apply_weights_update_transformations(
return model


def update_parameter(target_node_name: str, parameter_name: str, new_value: Tensor, model: NNCFNetwork) -> None:
def update_parameter(target_node_name: str, parameter_name: str, new_value: torch.Tensor, model: NNCFNetwork) -> None:
"""
Update parameter for target module.

Expand Down
Loading