
[Compression] Add bias correction feature for PTQ quantizer #5603

Merged 17 commits on Jun 29, 2023
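
This PR lets the PTQ quantizer compare each wrapped module's full-precision outputs with its quantized outputs on the evaluator's data and fold the mean output difference (per output feature, i.e. along the last dimension) into the module's bias; see `ModuleWrapper.forward` and `update_bias` in the wrapper diff below. A minimal, self-contained sketch of that idea, using a plain `nn.Linear` and an illustrative `fake_quantize` helper instead of the NNI wrappers (random data stands in for a calibration set):

import math
import torch
import torch.nn as nn

def fake_quantize(w: torch.Tensor, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # Symmetric per-tensor fake quantization, mirroring compute_scale_zp in this PR.
    scale = w.abs().max() / ((qmax - qmin) / 2)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    return torch.clamp(torch.round(w / scale), qmin, qmax) * scale

layer = nn.Linear(16, 8)
q_layer = nn.Linear(16, 8)
q_layer.load_state_dict(layer.state_dict())

calib = torch.randn(512, 16)  # stand-in calibration batch

with torch.no_grad():
    q_layer.weight.copy_(fake_quantize(layer.weight))
    fp_out = layer(calib)   # original (float) outputs
    q_out = q_layer(calib)  # outputs with fake-quantized weights

    # Sum the output difference over all leading (non-feature) dims and average it:
    # the same statistics ModuleWrapper.forward accumulates and update_bias applies.
    dim_sum = tuple(range(fp_out.dim() - 1))
    element_num = math.prod(fp_out.shape[:-1])
    bias_correction = torch.sum(fp_out - q_out, dim=dim_sum) / element_num

    q_layer.bias += bias_correction  # shift the bias to cancel the systematic error
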
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -116,6 +116,7 @@
r'https://docs\.nvidia\.com/deeplearning/',
r'https://cla\.opensource\.microsoft\.com',
r'https://www\.docker\.com/',
r'https://nlp.stanford.edu/projects/glove/',

# remove after 3.0 release
r'https://nni\.readthedocs\.io/en/v2\.10/compression/overview\.html',
46 changes: 46 additions & 0 deletions nni/contrib/compression/base/wrapper.py
@@ -6,6 +6,7 @@
from collections import defaultdict
import logging
import inspect
import functools
from typing import Any, Callable, Dict, List, Tuple, Type, Union, Literal

import torch
@@ -399,6 +400,14 @@ def forward(self, *args, **kwargs):
if len(self.fused_modules) > 0:
params_dict, activation_func_lis = fuse_modules(self, params_dict, *args, **kwargs)

# obtain the original (unquantized) outputs needed for bias correction
original_outputs = None
if getattr(self, "is_bias_correction", False) and check_bias(self.module) == 'Tensor':
for target_name, original_param in params_dict.items():
setattr(self.module, target_name, original_param * 1.0)

original_outputs = self.module_forward(*args, **kwargs)

params_dict = self.patch_params(params_dict)
for target_name, patched_param in params_dict.items():
# NOTE: using copy_ here causes a `backward through the graph a second time` error; the reason is unclear.
@@ -413,12 +422,49 @@ def forward(self, *args, **kwargs):
# fuse activation functions
for activation_module in activation_func_lis:
outputs = activation_module._nni_wrapper.module_forward(outputs)
if original_outputs is not None:
original_outputs = activation_module._nni_wrapper.module_forward(original_outputs)

outputs = self.patch_outputs(outputs)

if getattr(self, "is_bias_correction", False) and check_bias(self.module) == 'Tensor':
assert isinstance(original_outputs, torch.Tensor) and isinstance(outputs, torch.Tensor), \
f"Bias correction is only applied to variables with tensor output types, but got {(type(original_outputs), type(outputs))}"
element_num = functools.reduce(lambda x, y: x * y, list(original_outputs.shape[:-1]))
dim_sum = tuple(range(len(original_outputs.shape[:-1])))
bias_correction = torch.sum(original_outputs - outputs, dim=dim_sum)
if not hasattr(self, 'bias_correction'):
setattr(self, 'bias_correction', bias_correction)
else:
self.bias_correction += bias_correction
if not hasattr(self, 'bias_element_num'):
setattr(self, 'bias_element_num', element_num)
else:
self.bias_element_num: int = self.bias_element_num + element_num

torch.cuda.empty_cache()

return outputs

def update_bias(self):
assert hasattr(self, "bias_element_num")
assert check_bias(self.module) == 'Tensor'

bias_correction = getattr(self, "bias_correction", None)
element_num = getattr(self, "bias_element_num", 0)
assert bias_correction is not None
assert element_num > 0

bias_correction /= element_num  # compute the mean correction

if 'bias' in self.quantization_target_spaces:
target_space = self.quantization_target_spaces['bias']
assert target_space.target is not None and \
list(target_space.target.size()) == list(bias_correction.size())
target_space.target.data += bias_correction.detach().clone()
else:
self.module.bias.data += bias_correction.detach().clone()


class IdentityModuleWrapper(ModuleWrapper):  # only available for batchnorm
'''
2 changes: 1 addition & 1 deletion nni/contrib/compression/quantization/lsq_quantizer.py
@@ -35,7 +35,7 @@ class LsqQuantizer(Quantizer):
A list of dicts; each dict configures which modules need to be quantized and how to quantize them.
Please refer to :doc:`Compression Config Specification </compression/config_list>` for more information.
evaluator
TODO: {evaluator_docstring}
{evaluator_docstring}

Examples
--------
2 changes: 1 addition & 1 deletion nni/contrib/compression/quantization/lsqplus_quantizer.py
@@ -34,7 +34,7 @@ class LsqPlusQuantizer(Quantizer):
A list of dicts; each dict configures which modules need to be quantized and how to quantize them.
Please refer to :doc:`Compression Config Specification </compression/config_list>` for more information.
evaluator
TODO: {evaluator_docstring}
{evaluator_docstring}

Examples
--------
71 changes: 49 additions & 22 deletions nni/contrib/compression/quantization/ptq_quantizer.py
@@ -9,6 +9,7 @@

from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..base.target_space import QuantizationTargetSpace
from ..utils import Evaluator, _EVALUATOR_DOCSTRING


@@ -48,10 +49,11 @@ def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: E
...

def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, \
existed_wrappers: Dict[str, ModuleWrapper] | None = None):
existed_wrappers: Dict[str, ModuleWrapper] | None = None, is_bias_correction: bool = False):
super().__init__(model, config_list, evaluator, existed_wrappers)
self.evaluator: Evaluator
self.is_compressed = False
self.is_bias_correction = is_bias_correction
self.register_ptq_apply_method()
self.register_track_func()

@@ -68,10 +70,9 @@ def register_track_func(self):
for module_name, _ in self._target_spaces.items():
wrapper = self._module_wrappers[module_name]
wrapper.register_track_func(self.track_min_max_val)
# wrapper.register_track_func(self.update_scale_zp_for_bias_correction)

def track_min_max_val(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
if self.is_compressed:
return
def amin_reduce_func(converted_target: Tensor):
return converted_target.detach().amin(dim=-1)

@@ -97,25 +98,7 @@ def amax_reduce_func(converted_target: Tensor):
def update_scale_zp(self):
for _, ts in self._target_spaces.items():
for _, target_space in ts.items():
if target_space.tracked_max is None or target_space.tracked_min is None:
return
tracked_min = torch.min(target_space.tracked_min, torch.zeros_like(target_space.tracked_min))
tracked_max = torch.max(target_space.tracked_max, torch.zeros_like(target_space.tracked_max))
zero_point = torch.zeros_like(tracked_min)
if target_space.quant_scheme in ['symmetric', None]:
abs_max = torch.max(torch.abs(tracked_min), torch.abs(tracked_max))
scale = abs_max / (float(target_space.qmax - target_space.qmin) / 2)
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
# NOTE: here need to check, +1 because in pytorch, symmetric qint8 zp is 0, quint8 zp is 128.
zero_point_val = (target_space.qmax + target_space.qmin + 1) // 2
zero_point = torch.full_like(zero_point, zero_point_val)
elif target_space.quant_scheme == 'affine':
scale = (tracked_max - tracked_min) / float(target_space.qmax - target_space.qmin)
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
zero_point = target_space.qmin - torch.round(tracked_min / scale)
else:
raise RuntimeError(f'Unknown quant_scheme {target_space.quant_scheme}')
zero_point = torch.clamp(zero_point, target_space.qmin, target_space.qmax)
scale, zero_point = compute_scale_zp(target_space) # type: ignore
target_space.scale, target_space.zero_point = scale, zero_point

def _single_compress(self, max_steps: int | None, max_epochs: int | None):
@@ -132,6 +115,50 @@ def _fuse_postprocess(self, evaluator: Evaluator) -> None:
self.update_scale_zp()
self.is_compressed = True
self.register_ptq_apply_method()
# bias correction
if self.is_bias_correction:
self.bias_correction()

def bias_correction(self):
assert self.is_bias_correction, \
f"is_bias_correction should be True, but got {self.is_bias_correction}"
for module_name, _ in self._target_spaces.items():
wrapper = self._module_wrappers[module_name]
setattr(wrapper, "is_bias_correction", self.is_bias_correction)
# run the bias correction process
# TODO: add a warning for the user to change the evaluation dataset
self.evaluator.evaluate()
for module_name, _ in self._target_spaces.items():
wrapper = self._module_wrappers[module_name]
wrapper.update_bias()
delattr(wrapper, "is_bias_correction")
delattr(wrapper, "bias_correction")
delattr(wrapper, "bias_element_num")
self.evaluator.evaluate()
self.update_scale_zp()
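# Usage sketch (illustrative; the public `compress(...)` call and its arguments are an
# assumption based on `_single_compress` above, not something shown in this diff):
#
#     quantizer = PtqQuantizer(model, config_list, evaluator, is_bias_correction=True)
#     quantizer.compress(max_steps=None, max_epochs=None)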


def compute_scale_zp(target_space: QuantizationTargetSpace):
if target_space.tracked_max is None or target_space.tracked_min is None:
return
tracked_min = torch.min(target_space.tracked_min, torch.zeros_like(target_space.tracked_min))
tracked_max = torch.max(target_space.tracked_max, torch.zeros_like(target_space.tracked_max))
zero_point = torch.zeros_like(tracked_min)
if target_space.quant_scheme in ['symmetric', None]:
abs_max = torch.max(torch.abs(tracked_min), torch.abs(tracked_max))
scale = abs_max / (float(target_space.qmax - target_space.qmin) / 2)
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
# NOTE: needs double-checking; the +1 is because in PyTorch the symmetric qint8 zero point is 0 and the quint8 zero point is 128.
zero_point_val = (target_space.qmax + target_space.qmin + 1) // 2
zero_point = torch.full_like(zero_point, zero_point_val)
elif target_space.quant_scheme == 'affine':
scale = (tracked_max - tracked_min) / float(target_space.qmax - target_space.qmin)
scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
zero_point = target_space.qmin - torch.round(tracked_min / scale)
else:
raise RuntimeError(f'Unknown quant_scheme {target_space.quant_scheme}')
zero_point = torch.clamp(zero_point, target_space.qmin, target_space.qmax)
return scale, zero_point
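# Worked example of the symmetric zero-point rule above, following the PyTorch conventions
# mentioned in the NOTE (illustrative comment only):
#   qint8 : qmin=-128, qmax=127 -> zero_point = (-128 + 127 + 1) // 2 = 0
#   quint8: qmin=0,    qmax=255 -> zero_point = (0 + 255 + 1) // 2 = 128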


def update_tracked_value(original_val: Union[Tensor, None], current_val: Tensor, mode: str="max"):