Skip to content

Commit

Permalink
[experimental][collectors] Redundant code is removed (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#3006)

### Changes

* `get_inplace_fn` is not abstract method anymore
* Redundant code is removed

### Reason for changes

* To remove redundant code

### Related tickets



### Tests
  • Loading branch information
daniil-lyakhov authored Oct 14, 2024
1 parent 989f60b commit 8ef38ec
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 189 deletions.
36 changes: 1 addition & 35 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ def _reduce_out_of_place(self, x: List[TensorType]) -> List[TensorType]:
:param x: Tensor to register.
"""

@abstractmethod
def get_inplace_fn(self) -> Optional[InplaceInsertionFNType]:
"""
Returns correspondent inplace operation builder if inplace operations are available in backend.
:return: Inplace operation builder if possible else None.
"""
return None

def __call__(self, x: List[Tensor]):
if any(t.isempty() for t in x):
Expand Down Expand Up @@ -260,19 +260,6 @@ def register_statistic_branch(
self._aggregators[key] = aggregator
self._stat_container_kwargs_map[container_key] = key

def get_output_info(self, target_node_name: str, port_id: int) -> List[Tuple[int, List[str]]]:
"""
Returns list of pairs of reducers names and correspondent output names.
:param target_node_name: Target node name to assemble output name.
:param port_id: Target node specific port id to assemble output name.
:returns: List of pairs of reducers hashes and correspondent output names.
"""
retval = []
for reducer in self._reducers:
retval.append((hash(reducer), reducer.get_output_names(target_node_name, port_id)))
return retval

def register_inputs(self, inputs: Dict[int, List[Tensor]]) -> None:
"""
Registers given input in TensorCollector.
Expand Down Expand Up @@ -333,27 +320,6 @@ def get_statistics(self) -> Union[TensorStatistic, Dict[str, Any]]:
return kwargs
return self._build_statistic_container(self._stat_container, kwargs)

def get_inplace_fn_info(self) -> List[Tuple[Any, int]]:
"""
Returns necessary information to insert inplace operation into graph.
:returns: necessary information to insert inplace operation into graph
in format of pair of reducer builder and correspondent reducer output port id.
"""
retval = []
for reducer in self._reducers:
if reducer.inplace:
retval.append((reducer.get_inplace_fn(), reducer.output_port_id))
return retval

def any_stat_out_of_place(self) -> bool:
"""
Returns True if any reducer is calculated out of place.
:returns: True if any reducer is calculated out of place.
"""
return any(not reducer.inplace for reducer in self._reducers)

def replace_aggregator(self, key: Tuple[int, int, int], aggregator: AggregatorBase) -> None:
"""
Friend method that replaces aggregator instance on equivalent one.
Expand Down
54 changes: 8 additions & 46 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,49 +25,11 @@
from nncf.experimental.common.tensor_statistics.collectors import RawReducer
from nncf.experimental.common.tensor_statistics.collectors import ShapeAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic
from nncf.quantization.advanced_parameters import StatisticsType


class ONNXBasicReducer(TensorReducerBase):
def get_inplace_fn(self):
raise NotImplementedError("ONNX backend has no support of inplace statistics yet.")


class ONNXMinReducer(ONNXBasicReducer, MinReducer):
pass


class ONNXMaxReducer(ONNXBasicReducer, MaxReducer):
pass


class ONNXAbsMaxReducer(ONNXBasicReducer, AbsMaxReducer):
pass


class ONNXMeanReducer(ONNXBasicReducer, MeanReducer):
pass


class ONNXQuantileReducer(ONNXBasicReducer, QuantileReducer):
pass


class ONNXAbsQuantileReducer(ONNXBasicReducer, AbsQuantileReducer):
pass


class ONNXBatchMeanReducer(ONNXBasicReducer, BatchMeanReducer):
pass


class ONNXMeanPerChanelReducer(ONNXBasicReducer, MeanPerChReducer):
pass


def get_mean_statistic_collector(
num_samples: int, channel_axis: int, window_size: Optional[int] = None, inplace: bool = True
) -> TensorCollector:
Expand All @@ -83,9 +45,9 @@ def get_mean_statistic_collector(
"""
inplace = False
if channel_axis == 0:
reducer = ONNXBatchMeanReducer(inplace)
reducer = BatchMeanReducer(inplace)
else:
reducer = ONNXMeanPerChanelReducer(channel_axis=channel_axis, inplace=inplace)
reducer = MeanPerChReducer(channel_axis=channel_axis, inplace=inplace)
noop_reducer = NoopReducer()

kwargs = {
Expand Down Expand Up @@ -118,10 +80,10 @@ def get_raw_stat_collector(num_samples: int) -> TensorCollector:


ONNX_REDUCERS_MAP = {
StatisticsType.MIN: ONNXMinReducer,
StatisticsType.MAX: ONNXMaxReducer,
StatisticsType.ABS_MAX: ONNXAbsMaxReducer,
StatisticsType.MEAN: ONNXMeanReducer,
StatisticsType.QUANTILE: ONNXQuantileReducer,
StatisticsType.ABS_QUANTILE: ONNXAbsQuantileReducer,
StatisticsType.MIN: MinReducer,
StatisticsType.MAX: MaxReducer,
StatisticsType.ABS_MAX: AbsMaxReducer,
StatisticsType.MEAN: MeanReducer,
StatisticsType.QUANTILE: QuantileReducer,
StatisticsType.ABS_QUANTILE: AbsQuantileReducer,
}
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand
Expand All @@ -37,7 +38,6 @@
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer


@COMPRESSION_MODULES.register()
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_abs_max_channel_collector(
num_samples: int, stats_reduction_axes: Tuple[int], inplace: bool, branch_key: str
) -> TensorCollector:
collector = TensorCollector()
reducer = PTAbsMaxReducer(reduction_axes=stats_reduction_axes)
reducer = AbsMaxReducer(reduction_axes=stats_reduction_axes)
aggregator = MaxAggregator(num_samples=num_samples)
collector.register_statistic_branch(branch_key, reducer, aggregator)
return collector
Expand Down
4 changes: 2 additions & 2 deletions nncf/quantization/algorithms/smooth_quant/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
Expand All @@ -34,7 +35,6 @@
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.model_graph_manager import get_const_node
from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer

PT_PRE_LAYER_TARGET_TYPE = TargetType.OPERATOR_PRE_HOOK

Expand Down Expand Up @@ -89,7 +89,7 @@ def get_abs_max_channel_collector(
num_samples: int, stats_reduction_axes: Tuple[int], inplace: bool, branch_key: str
) -> TensorCollector:
collector = TensorCollector()
reducer = PTAbsMaxReducer(reduction_axes=stats_reduction_axes)
reducer = AbsMaxReducer(reduction_axes=stats_reduction_axes)
aggregator = MaxAggregator(num_samples=num_samples)
collector.register_statistic_branch(branch_key, reducer, aggregator)
return collector
Expand Down
69 changes: 14 additions & 55 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
from typing import List, Optional, Tuple, Type
from typing import Optional, Tuple, Type

import numpy as np

Expand Down Expand Up @@ -42,47 +42,6 @@
from nncf.tensor import Tensor


class PTReducerMixIn:

def get_inplace_fn(self):
return None

def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return []


class PTMinReducer(PTReducerMixIn, MinReducer):
pass


class PTMaxReducer(PTReducerMixIn, MaxReducer):
pass


class PTAbsMaxReducer(PTReducerMixIn, AbsMaxReducer):
pass


class PTMeanReducer(PTReducerMixIn, MeanReducer):
pass


class PTQuantileReducer(PTReducerMixIn, QuantileReducer):
pass


class PTAbsQuantileReducer(PTReducerMixIn, AbsQuantileReducer):
pass


class PTBatchMeanReducer(PTReducerMixIn, BatchMeanReducer):
pass


class PTMeanPerChanelReducer(PTReducerMixIn, MeanPerChReducer):
pass


def _reshape_all(targets: Tuple[Tensor, ...], target_shape: Tuple[int, ...]):
return map(lambda stat: stat.reshape(target_shape), targets)

Expand Down Expand Up @@ -145,11 +104,11 @@ def get_min_max_statistic_collector(
"num_samples": num_samples,
"aggregation_axes": aggregation_axes,
}
min_reducer = PTMinReducer(reduction_axes)
min_reducer = MinReducer(reduction_axes)
min_aggregator = MinAggregator(**aggregator_kwargs)
tensor_collector.register_statistic_branch(MinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
max_reducer_cls = AbsMaxReducer if use_abs_max else MaxReducer
max_reducer = max_reducer_cls(reduction_axes)
max_aggregator = MaxAggregator(**aggregator_kwargs)
tensor_collector.register_statistic_branch(MinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
Expand Down Expand Up @@ -181,7 +140,7 @@ def get_mixed_min_max_statistic_collector(
:return: Mixed min max statistic collector.
"""
tensor_collector = TensorCollector(_get_wrapped_min_max_tensor_statistic(target_shape=scale_shape))
min_reducer = PTMinReducer(reduction_axes)
min_reducer = MinReducer(reduction_axes)

kwargs = {
"num_samples": num_samples,
Expand All @@ -192,7 +151,7 @@ def get_mixed_min_max_statistic_collector(
min_aggregator = min_aggregator_cls(**kwargs)
tensor_collector.register_statistic_branch(MinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
max_reducer_cls = AbsMaxReducer if use_abs_max else MaxReducer
max_reducer = max_reducer_cls(reduction_axes)
max_aggregator_cls = MeanAggregator if use_means_of_maxs else MaxAggregator
max_aggregator = max_aggregator_cls(**kwargs)
Expand Down Expand Up @@ -323,7 +282,7 @@ def get_mean_percentile_statistic_collector(
"""
tensor_collector = TensorCollector(_get_wrapped_percentile_tensor_statistic(target_shape=scale_shape))
quantiles_to_collect = np.true_divide(percentiles_to_collect, 100)
reducer = PTQuantileReducer(reduction_axes=reduction_axes, quantile=quantiles_to_collect)
reducer = QuantileReducer(reduction_axes=reduction_axes, quantile=quantiles_to_collect)
for output_port_id, p in enumerate(percentiles_to_collect):
aggregator = MeanAggregator(
aggregation_axes=aggregation_axes,
Expand All @@ -349,9 +308,9 @@ def get_mean_statistic_collector(
:return: Mean statistic collector.
"""
if channel_axis == 0:
reducer = PTBatchMeanReducer()
reducer = BatchMeanReducer()
else:
reducer = PTMeanPerChanelReducer(channel_axis=channel_axis)
reducer = MeanPerChReducer(channel_axis=channel_axis)
noop_reducer = NoopReducer()

kwargs = {
Expand Down Expand Up @@ -383,10 +342,10 @@ def get_raw_stat_collector(num_samples: Optional[int] = None) -> TensorCollector


PT_REDUCERS_MAP = {
StatisticsType.MIN: PTMinReducer,
StatisticsType.MAX: PTMaxReducer,
StatisticsType.ABS_MAX: PTAbsMaxReducer,
StatisticsType.MEAN: PTMeanReducer,
StatisticsType.QUANTILE: PTQuantileReducer,
StatisticsType.ABS_QUANTILE: PTAbsQuantileReducer,
StatisticsType.MIN: MinReducer,
StatisticsType.MAX: MaxReducer,
StatisticsType.ABS_MAX: AbsMaxReducer,
StatisticsType.MEAN: MeanReducer,
StatisticsType.QUANTILE: QuantileReducer,
StatisticsType.ABS_QUANTILE: AbsQuantileReducer,
}
Loading

0 comments on commit 8ef38ec

Please sign in to comment.