
[PTQ] Add support of arbitrary batch size for PTQ #2197

Merged · 131 commits · Mar 22, 2024
1ce23c5
draft
kshpv Oct 13, 2023
8184df6
check on Nones
kshpv Oct 13, 2023
2e3f507
update aggregator with keep_dims=True
kshpv Oct 18, 2023
b5d15cd
typhints
kshpv Oct 18, 2023
8cb2391
Merge remote-tracking branch 'remote/develop' into torch_batch_size
kshpv Oct 18, 2023
1034acd
fix OV tests; update collectors
kshpv Oct 19, 2023
8b526c5
fix tests
kshpv Oct 20, 2023
e51bdb8
Merge remote-tracking branch 'remote/develop' into torch_batch_size
kshpv Nov 6, 2023
37684bd
add aggregation axes for OV; comment input check
kshpv Nov 7, 2023
18d931d
add test for OV and Torch
kshpv Nov 8, 2023
605a325
add batch_size param to conformance test
kshpv Nov 9, 2023
fb16b99
hardcode for CI run
kshpv Nov 9, 2023
cd60fa3
hardcode batch size = 10 for calibrate.py
kshpv Nov 10, 2023
f3bda28
Merge remote-tracking branch 'remote/develop' into torch_batch_size
kshpv Dec 18, 2023
cc621ab
merge
kshpv Dec 18, 2023
d2a9b00
update aggregator
kshpv Dec 20, 2023
5ffdf10
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Dec 20, 2023
d95be5d
revert unneseccary changes
kshpv Dec 20, 2023
cd68684
add logging; add torch data for OVEngine
kshpv Dec 20, 2023
4a009f3
refactor method get axes
kshpv Dec 21, 2023
c2659b3
fix OV tests
kshpv Dec 21, 2023
3a13f00
fix Torch tests
kshpv Jan 4, 2024
2347170
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 15, 2024
880073b
logic of warning message inside StatisticsAggregator
kshpv Jan 15, 2024
e9062a5
remove _check_input_data_format in OVEngine
kshpv Jan 15, 2024
8770ca4
get_channel_agnostic_reduction_axes to common
kshpv Jan 15, 2024
9556c49
use get_channel_agnostic_reduction_axes for Torch
kshpv Jan 15, 2024
cb90e77
use get_channel_agnostic_reduction_axes for ONNX
kshpv Jan 15, 2024
d9167e5
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 17, 2024
cd10c57
draft
kshpv Jan 17, 2024
16cc9db
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 17, 2024
11b538a
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 17, 2024
426ec04
fix test
kshpv Jan 18, 2024
21b0963
align reduction shape and aggregation shape
kshpv Jan 18, 2024
e90ca32
get_channel_agnostic_reduction_axes -> get_reduction_axes
kshpv Jan 18, 2024
f078a78
upd get_reduction_aggregation_axes
kshpv Jan 18, 2024
e4c57cd
upd aggregator
kshpv Jan 18, 2024
d226074
fix OV test
kshpv Jan 18, 2024
7d8ecd4
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 18, 2024
f502de5
fix ONNX test
kshpv Jan 18, 2024
83d03cb
tests
kshpv Jan 18, 2024
fbfe587
fix torch tests
kshpv Jan 18, 2024
0ae6ac4
fix tests
kshpv Jan 18, 2024
496339f
common tests
kshpv Jan 18, 2024
bcce584
add docs
kshpv Jan 18, 2024
e5950e0
comment
kshpv Jan 18, 2024
1d9ac7a
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 19, 2024
41f27b5
rollback changes for torch possible impact qat
kshpv Jan 19, 2024
51f3dd9
upd conformance
kshpv Jan 19, 2024
3a8de2f
upd calibrate.py
kshpv Jan 19, 2024
946523d
add get_reduction_aggregation_axes for PTRangeInitCollectorParams
kshpv Jan 19, 2024
1732d70
non returning None for get_reduction_aggregation_axes
kshpv Jan 19, 2024
1e96318
comments
kshpv Jan 19, 2024
03afe91
comments
kshpv Jan 19, 2024
bf792fb
describe comment
kshpv Jan 19, 2024
f98aea2
description x2
kshpv Jan 19, 2024
fbd05f9
description x3
kshpv Jan 19, 2024
e80bab1
apply suggestion
kshpv Jan 23, 2024
9c1648d
comments
kshpv Jan 24, 2024
df8ad03
add default scenario when batch_size=1 or None
kshpv Jan 25, 2024
f4db2bb
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 25, 2024
f4dfd1c
rollback scales changes
kshpv Jan 26, 2024
4a44a1c
fix tests
kshpv Jan 26, 2024
d4bfaca
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 26, 2024
f77f59b
fix OV test
kshpv Jan 26, 2024
43fd729
add warning for model_type=transformer
kshpv Jan 29, 2024
c20f7d3
fix torch test
kshpv Jan 29, 2024
52203f0
fix torch tests
kshpv Jan 29, 2024
9dd02b9
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 29, 2024
48c8426
final fix torch test
kshpv Jan 30, 2024
3fe8a37
comments
kshpv Jan 30, 2024
d228589
comments x2
kshpv Jan 30, 2024
b7de564
comments x3
kshpv Jan 30, 2024
67e4c7d
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 30, 2024
489d603
fix tests after merge
kshpv Jan 30, 2024
3b9fb6f
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 30, 2024
120ee1a
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Jan 31, 2024
1f0cb94
improve test
kshpv Jan 31, 2024
532e8eb
fix test
kshpv Feb 6, 2024
38d71b8
upd fbs method calculations
kshpv Feb 6, 2024
c490362
revert changes with statistics collection
kshpv Feb 7, 2024
b778c0c
updates aggregators, reducers for BC and FBC
kshpv Feb 13, 2024
1a96012
upd torch mean_per_channel
kshpv Feb 14, 2024
2f89913
fix BC
kshpv Feb 14, 2024
f69acbd
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Feb 14, 2024
74594c7
fixes after merge
kshpv Feb 14, 2024
d760caf
Fix BC calculations
kshpv Feb 15, 2024
50ac6b4
revert FBC and BC changes
kshpv Feb 20, 2024
00f7979
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Feb 20, 2024
532ca55
fix merge
kshpv Feb 20, 2024
54f8ca3
fix revert typo
kshpv Feb 20, 2024
7637d4b
fix export of torch model
kshpv Feb 21, 2024
976255f
comments
kshpv Feb 23, 2024
8951e3c
more comments
kshpv Feb 26, 2024
0d72557
make bs=128 for Torch sample
kshpv Feb 26, 2024
0f8a438
fix channel alighnment + comments
kshpv Feb 27, 2024
78d4d6c
comments
kshpv Feb 28, 2024
34c9960
update typehints; revert changes in OV sample and apply to Torch
kshpv Feb 28, 2024
354505a
typo
kshpv Feb 28, 2024
97cb07f
some code improvements
kshpv Feb 28, 2024
2cc8b81
logging
kshpv Feb 28, 2024
e3a3291
remove iterations_number calculation in Aggregator
kshpv Mar 1, 2024
8cb7c60
update tests
kshpv Mar 1, 2024
ae772aa
reaname parameter
kshpv Mar 1, 2024
41c76fe
apply comments
kshpv Mar 1, 2024
321c65a
polishing
kshpv Mar 1, 2024
f19fd71
add test
kshpv Mar 4, 2024
4996333
small fixes
kshpv Mar 4, 2024
5e8bce7
polishing
kshpv Mar 5, 2024
9ba5700
conformance adoption for any batch_size; better logging
kshpv Mar 6, 2024
a0f5fe9
add dynamic_batch_shape option to conformance
kshpv Mar 6, 2024
7562d19
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 6, 2024
ee6c14b
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 7, 2024
deb0b51
polishing test
kshpv Mar 7, 2024
64cbd99
fix calibrate.py
kshpv Mar 7, 2024
6119511
new polishing
kshpv Mar 7, 2024
e010087
remove warnings about bathc_size>1 in aggregator
kshpv Mar 8, 2024
54319cb
add baatch_size logging in quantize_impl()
kshpv Mar 8, 2024
4e90c65
add IF op to batch_size warning metatypes list
kshpv Mar 8, 2024
d04ba75
put logs from minmax to quantize_impl
kshpv Mar 8, 2024
6e54d07
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 8, 2024
6048155
rm typos
kshpv Mar 8, 2024
09783c4
typehints
kshpv Mar 8, 2024
da05b93
revert debug message minmax
kshpv Mar 8, 2024
291110e
typo
kshpv Mar 8, 2024
2676fbd
add model_param is_batch_size_supported to conformance; make all mode…
kshpv Mar 18, 2024
3ae9d28
add example in Readme
kshpv Mar 18, 2024
5efcdb5
comments
kshpv Mar 20, 2024
d8ea324
iterations_number -> stat_subset_size
kshpv Mar 20, 2024
7288924
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 21, 2024
88653aa
Merge remote-tracking branch 'remote/develop' into HEAD
kshpv Mar 21, 2024
31 changes: 22 additions & 9 deletions nncf/common/tensor_statistics/aggregator.py
@@ -45,28 +45,36 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
"""
if not self.statistic_points:
return

collected_statistics_num = 0
model_transformer = factory.ModelTransformerFactory.create(model)

merged_statistics = self._get_merged_statistic_points(self.statistic_points, model, graph)
transformation_layout = self._get_transformation_layout_extra_outputs(merged_statistics)
model_with_outputs = model_transformer.transform(transformation_layout)
engine = factory.EngineFactory.create(model_with_outputs)

batch_size = self.dataset.get_batch_size()
batch_size = 1 if batch_size is None else batch_size
dataset_length = self.dataset.get_length()
dataset_length = dataset_length * batch_size if dataset_length is not None else dataset_length
total = (
min(dataset_length or self.stat_subset_size, self.stat_subset_size)
if self.stat_subset_size is not None
else None
)
for input_data in track(
islice(self.dataset.get_inference_data(), self.stat_subset_size),
total=total,
description="Statistics collection",
):
outputs = engine.infer(input_data)
processed_outputs = self._process_outputs(outputs)
self._register_statistics(processed_outputs, merged_statistics)
with track(total=total, description="Statistics collection") as pbar:
for input_data in islice(self.dataset.get_inference_data(), self.stat_subset_size):
batch_size_to_collect = (
min(total - collected_statistics_num, batch_size) if total is not None else batch_size
)
sliced_input = self._get_sliced_data(input_data, batch_size_to_collect)
outputs = engine.infer(sliced_input)
processed_outputs = self._process_outputs(outputs)
self._register_statistics(processed_outputs, merged_statistics)
collected_statistics_num += batch_size_to_collect
pbar.progress.update(pbar.task, advance=batch_size_to_collect)
if total and collected_statistics_num == total:
break
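The sample-budget arithmetic above can be traced with plain numbers. A minimal standalone sketch (the dataset and engine are stubbed out; all sizes are hypothetical values, not taken from the PR):

```python
from itertools import islice

batch_size = 4            # from dataset.get_batch_size()
dataset_length = 10       # number of batches reported by dataset.get_length()
stat_subset_size = 30     # user-requested number of samples

# Convert batches to samples, then cap by the requested subset size.
total_samples = dataset_length * batch_size                       # 40
total = min(total_samples or stat_subset_size, stat_subset_size)  # 30

collected = 0
for _ in islice(range(dataset_length), stat_subset_size):
    # The last batch is sliced down so that exactly `total` samples are used.
    to_collect = min(total - collected, batch_size)
    collected += to_collect
    if collected == total:
        break

print(collected)  # 30 -- the final batch contributes only 2 of its 4 samples
```

This mirrors why `_get_sliced_data` receives `batch_size_to_collect` rather than the raw batch: without slicing, the loop would overshoot the requested subset size by up to `batch_size - 1` samples.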

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
@@ -134,3 +142,8 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
:param outputs: raw model outputs
:return: processed model outputs in Dict[str, NNCFTensor] format
"""

@staticmethod
@abstractmethod
def _get_sliced_data(inputs: Any, end: int) -> Any:
""" """
6 changes: 6 additions & 0 deletions nncf/data/dataset.py
@@ -81,6 +81,12 @@ def get_length(self) -> Optional[int]:
return self._data_source.__len__()
return None

def get_batch_size(self) -> Optional[int]:
""" """
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
if hasattr(self._data_source, "batch_size"):
return self._data_source.batch_size
return None


class DataProvider(Generic[DataItem, ModelInput]):
def __init__(
61 changes: 14 additions & 47 deletions nncf/experimental/common/tensor_statistics/collectors.py
@@ -12,7 +12,6 @@
from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from collections import deque
from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

from nncf.common.tensor import TensorType
@@ -129,16 +128,14 @@ def __init__(
"""
:param tensor_processor: Backend-specific tensor processor.
:param aggregation_axes: Axes along which to operate.
Registered statistics are stacked along zero axis,
axes >=1 correspond to received statistic axes shifted left by 1.
:param num_samples: Maximum number of samples to collect. Aggregator
skips tensor registration if tensor registration was called num_samples times before.
Aggregator never skips registration if num_samples is None.
"""

self._tensor_processor = tensor_processor
self._aggregation_axes = (0,) if aggregation_axes is None else aggregation_axes
self._keepdims = False
self._aggregation_axes = (0,) if aggregation_axes is None else (0, *map(lambda x: x + 1, aggregation_axes))
self._keepdims = True
self._num_samples = num_samples
self._collected_samples = 0
self._container = []
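The changed constructor line encodes the new axis convention: registered statistics are stacked along a fresh zeroth (sample) axis, so user-facing aggregation axes are shifted right by one, and axis 0 is always aggregated. A sketch of just that transformation, with concrete tuples:

```python
from typing import Optional, Tuple


def shift_aggregation_axes(aggregation_axes: Optional[Tuple[int, ...]]) -> Tuple[int, ...]:
    # Mirrors the constructor logic: always aggregate over the stacked
    # sample axis (0), and shift every user-supplied axis by one to
    # account for it.
    if aggregation_axes is None:
        return (0,)
    return (0, *map(lambda x: x + 1, aggregation_axes))


print(shift_aggregation_axes(None))    # (0,)
print(shift_aggregation_axes((0, 2)))  # (0, 1, 3)
```

With `keepdims=True` now the default, the aggregated sample axis survives as a singleton dimension until the final squeeze in `_aggregate_impl`.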
@@ -594,20 +591,7 @@ def _aggregate_impl(self):
return self._container.shape


class TensorAggregatorBase(AggregatorBase, ABC):
def __init__(
self,
tensor_processor: NNCFCollectorTensorProcessor,
aggregation_axes: Optional[AggregationAxes] = None,
num_samples: Optional[int] = None,
window_size=None,
):
super().__init__(tensor_processor, aggregation_axes=aggregation_axes, num_samples=num_samples)
self._window_size = window_size
self._container = deque(maxlen=window_size)


class OnlineAggregatorBase(TensorAggregatorBase, ABC):
class OnlineAggregatorBase(AggregatorBase, ABC):
"""
Base class for aggregators which are using aggregation function fn with following property:
fn([x1, x2, x3]) == fn([fn([x1, x2]), x3]) where x1, x2, x3 are samples to aggregate.
@@ -616,26 +600,14 @@ class OnlineAggregatorBase(TensorAggregatorBase, ABC):
"""

def _register_reduced_input_impl(self, x: NNCFTensor) -> None:
online_aggregation_axes = tuple(dim - 1 for dim in self._aggregation_axes if dim != 0)
if online_aggregation_axes:
reduced = self._aggregation_fn(x, axis=online_aggregation_axes, keepdims=self._keepdims)
else:
reduced = x
if 0 in self._aggregation_axes:
if self._container:
reduced = self._aggregation_fn(
self._tensor_processor.stack([reduced, self._container]), axis=0, keepdims=False
)
self._container = reduced
else:
self._container.append(reduced)
stacked_tensors = self._tensor_processor.stack([x, *self._container])
aggregated = self._aggregation_fn(stacked_tensors, axis=self._aggregation_axes, keepdims=self._keepdims)
squeezed = self._tensor_processor.squeeze(aggregated, 0)
self._container = [squeezed]

def _aggregate_impl(self) -> NNCFTensor:
if 0 in self._aggregation_axes:
if self._keepdims:
return self._tensor_processor.stack([self._container]).tensor
return self._container.tensor
return self._tensor_processor.stack(self._container).tensor
assert len(self._container) == 1
return self._container[0].tensor
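The online rewrite works only because the aggregation function satisfies `fn([x1, x2, x3]) == fn([fn([x1, x2]), x3])`, so the running result can be folded back into a one-element container after every sample. A NumPy demonstration of the stack–aggregate–squeeze loop with `max` (shapes chosen arbitrarily for the sketch):

```python
import numpy as np

# max satisfies fn([x1, x2, x3]) == fn([fn([x1, x2]), x3]), so folding the
# running result into the container sample-by-sample matches one-shot
# aggregation over all samples.
samples = [np.array([1.0, 5.0]), np.array([4.0, 2.0]), np.array([3.0, 6.0])]

container = []
for x in samples:
    stacked = np.stack([x, *container])                  # shape (len+1, 2)
    aggregated = np.max(stacked, axis=0, keepdims=True)  # shape (1, 2)
    container = [np.squeeze(aggregated, 0)]              # back to shape (2,)

one_shot = np.max(np.stack(samples), axis=0)
print(container[0])                            # [4. 6.]
print(np.array_equal(container[0], one_shot))  # True
```

Mean, by contrast, does not satisfy this property, which is why `MeanAggregator` descends from the offline base class below.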

@abstractmethod
def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keepdims: bool) -> NNCFTensor:
@@ -652,7 +624,7 @@ def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keepdims: bool) -> NNCFTensor:
return self._tensor_processor.reduce_max(stacked_value, axis=axis, keepdims=keepdims)


class OfflineAggregatorBase(TensorAggregatorBase, ABC):
class OfflineAggregatorBase(AggregatorBase, ABC):
"""
Base class for aggregators which are using aggregation function fn which
does not fulfill property fn([x1, x2, x3]) == fn([fn([x1, x2]), x3])
@@ -665,7 +637,8 @@ def _register_reduced_input_impl(self, x: TensorType) -> None:

def _aggregate_impl(self) -> NNCFTensor:
stacked_val = self._tensor_processor.stack(self._container)
return self._aggregation_fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims).tensor
aggregated = self._aggregation_fn(stacked_val, axis=self._aggregation_axes, keepdims=self._keepdims)
return self._tensor_processor.squeeze(aggregated, 0).tensor
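Offline aggregators must keep every registered sample and aggregate once at the end; with `keepdims=True` the stacked sample axis survives as a singleton dimension, hence the new final `squeeze(aggregated, 0)`. A NumPy sketch of that path with `mean` (values are illustrative):

```python
import numpy as np

# Offline aggregation (e.g. mean) stacks every registered sample first;
# keepdims=True keeps the stacked sample axis as a singleton, so the
# result is squeezed back to the per-sample shape at the end.
container = [np.array([1.0, 3.0]), np.array([3.0, 5.0])]

stacked = np.stack(container)                         # shape (2, 2)
aggregated = np.mean(stacked, axis=0, keepdims=True)  # shape (1, 2)
result = np.squeeze(aggregated, 0)                    # shape (2,)
print(result)  # [2. 4.]
```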

@abstractmethod
def _aggregation_fn(self, stacked_value: NNCFTensor, axis: AggregationAxes, keepdims: bool) -> NNCFTensor:
@@ -688,12 +661,9 @@ def __init__(
tensor_processor: NNCFCollectorTensorProcessor,
aggregation_axes: Optional[AggregationAxes] = None,
num_samples: Optional[int] = None,
window_size=None,
quantile: float = 0.01,
):
super().__init__(tensor_processor, aggregation_axes=aggregation_axes, num_samples=num_samples)
self._window_size = window_size
self._container = deque(maxlen=window_size)
self._quantile = quantile

def _aggregate_impl(self) -> NNCFTensor:
@@ -734,7 +704,7 @@ def _aggregation_fn(
return self._tensor_processor.masked_median(stacked_samples, axis=axis, mask=mask, keepdims=keepdims)


class MedianAbsoluteDeviationAggregator(TensorAggregatorBase):
class MedianAbsoluteDeviationAggregator(AggregatorBase):
def _register_reduced_input_impl(self, x: TensorType) -> None:
return self._container.append(x)

@@ -759,19 +729,16 @@ def _aggregate_impl(self) -> Dict[str, NNCFTensor]:
}


class PercentileAggregator(TensorAggregatorBase):
class PercentileAggregator(AggregatorBase):
def __init__(
self,
tensor_processor: NNCFCollectorTensorProcessor,
percentiles_to_collect: List[float],
aggregation_axes: Optional[AggregationAxes] = None,
num_samples: Optional[int] = None,
window_size=None,
):
super().__init__(tensor_processor, aggregation_axes=aggregation_axes, num_samples=num_samples)
self._percentiles_to_collect = percentiles_to_collect
self._window_size = window_size
self._container = deque(maxlen=window_size)

def _register_reduced_input_impl(self, x: TensorType) -> None:
return self._container.append(x)
4 changes: 4 additions & 0 deletions nncf/onnx/statistics/aggregator.py
@@ -80,3 +80,7 @@ def _get_merged_statistic_points(
@staticmethod
def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, ONNXNNCFTensor]:
return {n: ONNXNNCFTensor(v) for n, v in outputs.items()}

@staticmethod
def _get_sliced_data(inputs: Dict[str, np.ndarray], end: int) -> Dict[str, ONNXNNCFTensor]:
return inputs
4 changes: 4 additions & 0 deletions nncf/openvino/statistics/aggregator.py
@@ -111,3 +111,7 @@ def _get_merged_statistic_points(
@staticmethod
def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, OVNNCFTensor]:
return {n: OVNNCFTensor(v) for n, v in outputs.items()}

@staticmethod
def _get_sliced_data(inputs: Dict[str, np.ndarray], end: int) -> Dict[str, OVNNCFTensor]:
return inputs
7 changes: 1 addition & 6 deletions nncf/openvino/statistics/collectors.py
@@ -271,16 +271,12 @@ def get_output_names(self, target_node_name: str, port_id: int) -> List[str]:
return get_reducer_output_node_names(self.name, target_node_name, port_id, self.output_port_id, self.inplace)


def get_mean_statistic_collector(
num_samples: int, channel_axis: int, window_size: Optional[int] = None, inplace: bool = True
) -> TensorCollector:
def get_mean_statistic_collector(num_samples: int, channel_axis: int, inplace: bool = True) -> TensorCollector:
"""
Mean statistic collector builder.

:param num_samples: Maximum number of samples to collect.
:param channel_axis: Channel axis to use during reduction phase.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
Aggregates all available collected statistics in case parameter is None.
:param inplace: Whether the mean reducer should be calculated inplace or out of place.
:return: Mean statistic collector.
"""
@@ -296,7 +292,6 @@ def get_mean_statistic_collector(
kwargs = {
"tensor_processor": OVNNCFCollectorTensorProcessor,
"num_samples": num_samples,
"window_size": window_size,
}
aggregate_mean = MeanAggregator(**kwargs)
aggregate_shape = ShapeAggregator()
@@ -67,9 +67,8 @@ def mean_statistic_collector(
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_statistic_collector(num_samples, channel_axis, window_size, inplace)
return get_mean_statistic_collector(num_samples, channel_axis, inplace)

@staticmethod
def raw_statistic_collector(inplace: bool, num_samples: int = None) -> TensorCollector:
@@ -53,9 +53,8 @@ def mean_statistic_collector(
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_statistic_collector(num_samples, channel_axis, window_size, inplace)
return get_mean_statistic_collector(num_samples, channel_axis, inplace)

@staticmethod
def get_sub_input_output_names(subgraph: ov.Model) -> Tuple[str, str]:
@@ -65,9 +65,8 @@ def mean_statistic_collector(
channel_axis: int,
inplace: bool,
num_samples: Optional[int] = None,
window_size: Optional[int] = None,
) -> TensorCollector:
return get_mean_statistic_collector(num_samples, channel_axis, window_size)
return get_mean_statistic_collector(num_samples, channel_axis)

@staticmethod
def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: