
Commit 5e21b70

Authored by xin3he, changwangss, pre-commit-ci[bot], and chensuyue
support peft model quantization with SmoothQuant (#1282)
PEFT models use the architecture below: LoRA Linears nested inside a Linear. This pull request adds SmoothQuant support for that architecture.

```
(v): Linear(
  in_features=32, out_features=32, bias=False
  (lora_dropout): ModuleDict(
    (default): Dropout(p=0.1, inplace=False)
  )
  (lora_A): ModuleDict(
    (default): Linear(in_features=32, out_features=8, bias=False)
  )
  (lora_B): ModuleDict(
    (default): Linear(in_features=8, out_features=32, bias=False)
  )
  (lora_embedding_A): ParameterDict()
  (lora_embedding_B): ParameterDict()
)
```

Additionally, when the IPEX version is <= 1.13, HistogramObserver does not support the asym scheme (the zero_point is always 0 for asym uint8), while MinMaxObserver works well.

Also, the IPEX SmoothQuant observer can only use save/load_qconf_summary once: the save_qconf_summary API freezes the scales used in the model, so further calibration no longer works, and load_qconf_summary overwrites the scales used in the model but only takes effect on the first call. Here we implement a normal observer to work around this issue.

---------

Signed-off-by: changwangss <chang1.wang@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: y <xin3.he@intel.com>
Signed-off-by: chensuyue <suyue.chen@intel.com>
Co-authored-by: changwangss <chang1.wang@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: chen, suyue <suyue.chen@intel.com>
1 parent 21668df commit 5e21b70
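For context, below is a minimal, hand-rolled mock of the "Linears in Linear" layout printed in the commit message. It only mimics the structure peft produces (the class and its zero-initialized `lora_B` are stand-ins, not the real peft types); the point is that `named_modules()` now yields Linear modules nested inside another Linear, which the SmoothQuant module lookup and scale math below have to handle.

```python
import torch


class MockLoraLinear(torch.nn.Linear):
    """Toy stand-in for a peft LoRA-wrapped Linear: a Linear that itself contains Linears."""

    def __init__(self, in_features=32, out_features=32, r=8):
        super().__init__(in_features, out_features, bias=False)
        self.lora_dropout = torch.nn.ModuleDict({"default": torch.nn.Dropout(p=0.1)})
        self.lora_A = torch.nn.ModuleDict({"default": torch.nn.Linear(in_features, r, bias=False)})
        self.lora_B = torch.nn.ModuleDict({"default": torch.nn.Linear(r, out_features, bias=False)})
        # LoRA initializes lora_B to zero, which is why a per-channel weight max can be 0 downstream.
        torch.nn.init.zeros_(self.lora_B["default"].weight)

    def forward(self, x):
        base = super().forward(x)
        update = self.lora_B["default"](self.lora_A["default"](self.lora_dropout["default"](x)))
        return base + update


model = torch.nn.ModuleDict({"v": MockLoraLinear()})
# Both the outer Linear ("v") and the inner LoRA Linears ("v.lora_A.default", "v.lora_B.default")
# show up here, so name-based module traversal must cope with Linears nested inside a Linear.
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name, "->", module)
```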

File tree

8 files changed: +15163, -165 lines


.azure-pipelines/scripts/ut/env_setup.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -92,7 +92,7 @@ pip install horovod
 pip install transformers

 if [[ $(echo "${test_case}" | grep -c "others") != 0 ]];then
-    pip install tf_slim xgboost accelerate==0.21.0
+    pip install tf_slim xgboost accelerate==0.21.0 peft
 elif [[ $(echo "${test_case}" | grep -c "nas") != 0 ]]; then
     pip install dynast==1.6.0rc1
 elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
```

neural_compressor/adaptor/pytorch.py

Lines changed: 22 additions & 17 deletions
```diff
@@ -1833,7 +1833,7 @@ def _apply_pre_optimization(self, model, tune_cfg, recover=False):
                 absorb_layer = op_name
                 absorbed_layer = info["absorbed_layer"]
                 input_minmax = info["input_minmax"]
-                weight_max = info["weight_max"]
+                weight_max = info["weight_max"].clamp(min=1e-5)
                 abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
                 input_power = torch.pow(abs_input_max, alpha)
                 weight_power = torch.pow(weight_max, 1 - alpha)
@@ -1858,11 +1858,12 @@ def qdq_quantize(self, model, tune_cfg):
         """
         q_model = model._model
         from .torch_utils.model_wrapper import QDQLinear, SQLinearWrapper
-        from .torch_utils.util import fetch_module, set_module
+        from .torch_utils.smooth_quant import get_module, set_module

         smoothquant_scale_info = {}
         fallback_op_name_list = []
         stats_result = {}
+        stats_result["Linear(failed when SQ)"] = {"INT8(QDQ)": 0, "BF16": 0, "FP32": 0}
         for (op_name, op_type), qconfig in tune_cfg["op"].items():
             if op_type == "Linear" and qconfig["weight"]["dtype"] != "int8":
                 fallback_op_name_list.append(op_name)
@@ -1876,13 +1877,16 @@ def qdq_quantize(self, model, tune_cfg):
                 alpha = info["alpha"]
                 absorbed_layer = info["absorbed_layer"]
                 input_minmax = info["input_minmax"]
-                weight_max = info["weight_max"]
+                weight_max = info["weight_max"].clamp(min=1e-5)
                 abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
                 input_power = torch.pow(abs_input_max, alpha)
                 weight_power = torch.pow(weight_max, 1 - alpha)
                 scale = torch.clip(input_power / weight_power, min=1e-5)
+                if torch.isnan(scale).any() or torch.isinf(scale).any():
+                    stats_result["Linear(failed when SQ)"]["FP32"] += 1
+                    continue  # for peft model, lora_B weights is 0.
                 for op_name in absorbed_layer:
-                    module = fetch_module(q_model, op_name)
+                    module = get_module(q_model, op_name)
                     new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha)
                     set_module(q_model, op_name, new_module)
                     logger.debug(f"Current SmoothQuant alpha of {op_name} is {alpha}")
@@ -2858,7 +2862,7 @@ def _dump_model_op_stats(self, tune_cfg):
             output_data, header="Mixed Precision Statistics", field_names=["Op Type", "Total", "INT8", "BF16", "FP32"]
         ).print_stat()

-    def _cfg_to_qconfig(self, tune_cfg):
+    def _cfg_to_qconfig(self, tune_cfg, smooth_quant=False):
         """Convert tune configure to quantization config for each op.

         Args:
@@ -2949,7 +2953,7 @@ def _cfg_to_qconfig(self, tune_cfg):
         else:
             op_infos = copy.deepcopy(self.op_infos_from_cfgs)
             self.cfgs = torch_utils.util.check_cfg_and_qconfig(
-                tune_cfg["op"], self.cfgs, op_infos, self.output_tensor_id_op_name
+                tune_cfg["op"], self.cfgs, op_infos, self.output_tensor_id_op_name, smooth_quant
             )

         with open(self.ipex_config_path, "w") as write_f:
@@ -3112,7 +3116,7 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
             smooth_quant_args = self.recipes.get("smooth_quant_args", {})
             folding = smooth_quant_args.get("folding", False)
             if not folding:
-                if self.sq_minmax_init:
+                if self.sq_minmax_init or self.version.release >= Version("2.2").release:
                     from torch.ao.quantization.observer import MinMaxObserver

                     static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
@@ -3268,19 +3272,20 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
         if sq_max_info:
             smoothquant_scale_info = {}
             from .torch_utils.model_wrapper import SQLinearWrapper
-            from .torch_utils.util import fetch_module
+            from .torch_utils.smooth_quant import get_module

             for _, info in sq_max_info.items():
                 alpha = info["alpha"]
                 absorbed_layer = info["absorbed_layer"]
                 input_minmax = info["input_minmax"]
-                weight_max = info["weight_max"]
+                # for peft model, lora_B weights is 0.
+                weight_max = info["weight_max"].clamp(min=1e-5)
                 abs_input_max = torch.max(torch.abs(input_minmax[0]), torch.abs(input_minmax[1]))
                 input_power = torch.pow(abs_input_max, alpha)
                 weight_power = torch.pow(weight_max, 1 - alpha)
                 scale = torch.clip(input_power / weight_power, min=1e-5)
                 for op_name in absorbed_layer:
-                    module = copy.deepcopy(fetch_module(q_model._model, op_name))
+                    module = copy.deepcopy(get_module(q_model._model, op_name))
                     new_module = SQLinearWrapper(module, 1.0 / scale, input_minmax, alpha)
                     weight_scale = new_module._get_weight_scale()
                     smoothquant_scale_info[op_name] = {
@@ -3296,7 +3301,7 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
         # Check save_qconf_summary part is a workaround for IPEX bug.
         # Sometimes the prepared model from get_op_capablitiy loss this attribute
         if not hasattr(model._model, "save_qconf_summary") or not hasattr(model._model, "load_qconf_summary"):
-            if self.sq_minmax_init:
+            if self.sq_minmax_init or self.version.release >= Version("2.2").release:
                 from torch.ao.quantization.observer import MinMaxObserver

                 static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(
@@ -3313,10 +3318,14 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
                 model._model, static_qconfig, example_inputs=self.example_inputs, inplace=inplace
             )

-        # TODO: update_sq_scale is used to update observer, should fuse in _cfg_to_qconfig
+        # The IPEX SmoothQuant observer can only use save/load_qconf_summary once.
+        # The save_qconf_summary API will freeze the scale used in model and calibration won't work anymore.
+        # The load_qconf_summary will overwrite the scales used in model but only work in the first call.
+        # Here, we use INC collected scale for Linear and set normal observer instead of SQObserver \
+        # to make sure calibration works for other ops, like add, bmm.
         from .torch_utils.util import update_sq_scale

-        self._cfg_to_qconfig(tune_cfg)
+        self._cfg_to_qconfig(tune_cfg, smooth_quant=True)
         update_sq_scale(self.ipex_config_path, smoothquant_scale_info)
         model._model.load_qconf_summary(qconf_summary=self.ipex_config_path)

@@ -3337,10 +3346,6 @@ def qdq_quantize(self, model, q_model, tune_cfg, dataloader, q_func):
                 + "using scale info from SmoothQuant for Linear and "
                 + "one iter calibration for other ops."
             )
-            # update ipex_config.json with smoothquant_scale_info
-            model._model.save_qconf_summary(qconf_summary=self.ipex_config_path)
-            update_sq_scale(self.ipex_config_path, smoothquant_scale_info)
-            model._model.load_qconf_summary(qconf_summary=self.ipex_config_path)

         self._ipex_post_quant_process(model, q_model, dataloader, inplace=inplace)
```
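The `clamp(min=1e-5)` and the NaN/Inf guard added above exist because LoRA initializes `lora_B` to zero, so the per-channel weight maximum can be 0 and the SmoothQuant scale `|X|_max^alpha / |W|_max^(1-alpha)` blows up. A standalone sketch of the arithmetic (toy tensors, not the adaptor code itself):

```python
import torch

alpha = 0.5
abs_input_max = torch.tensor([2.0, 4.0, 8.0])  # calibrated |X| max per channel
weight_max = torch.tensor([0.5, 0.0, 1.5])     # 0.0 mimics an untrained lora_B weight

# Naive SmoothQuant scale: |X|max^alpha / |W|max^(1-alpha) -> inf where weight_max == 0
naive_scale = torch.pow(abs_input_max, alpha) / torch.pow(weight_max, 1 - alpha)
print(naive_scale)  # tensor([2.0000,    inf, 2.3094])

# The approach in this commit: clamp weight_max away from zero, clip the scale,
# and fall back to FP32 for that Linear if anything is still NaN/Inf.
weight_max = weight_max.clamp(min=1e-5)
scale = torch.clip(torch.pow(abs_input_max, alpha) / torch.pow(weight_max, 1 - alpha), min=1e-5)
if torch.isnan(scale).any() or torch.isinf(scale).any():
    print("fall back to FP32 for this Linear")
else:
    print(scale)
```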

neural_compressor/adaptor/pytorch_ipex.yaml

Lines changed: 2 additions & 2 deletions
```diff
@@ -48,9 +48,9 @@
                     },
           'activation': {
                     'dtype': ['uint8'],
-                    'scheme': ['asym'],
+                    'scheme': ['asym', 'sym'],
                     'granularity': ['per_tensor'],
-                    'algorithm': ['minmax']
+                    'algorithm': ['minmax', 'kl']
                     }
           },
         },
```
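The widened capability above ('sym' scheme, 'kl' algorithm) ties back to the observer note in the commit message: with IPEX <= 1.13 the histogram-based observer reportedly returns zero_point = 0 for asymmetric uint8, while MinMaxObserver works well. The snippet below only inspects the plain torch.ao observers to show how the observer choice drives the computed qparams; it is illustrative and version-dependent, not a reproduction of the IPEX issue:

```python
import torch
from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver

x = torch.randn(1024) + 3.0  # shifted data, so an asymmetric zero_point should be non-trivial

for obs_cls in (MinMaxObserver, HistogramObserver):
    obs = obs_cls(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
    obs(x)  # observe the calibration data
    scale, zero_point = obs.calculate_qparams()
    print(obs_cls.__name__, scale.item(), zero_point.item())
```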

neural_compressor/adaptor/torch_utils/smooth_quant.py

Lines changed: 40 additions & 71 deletions
```diff
@@ -182,6 +182,12 @@ def get_module(model, key):
     for name in name_list:
         if hasattr(module, name):
             module = getattr(module, name)
+        elif hasattr(module, "sq_linear"):  # for peft models
+            module = getattr(module, "sq_linear")
+            module = getattr(module, name)
+        elif hasattr(module, "orig_layer"):  # for peft models and auto alpha
+            module = getattr(module, "orig_layer")
+            module = getattr(module, name)
         else:
             module = module
     return module
@@ -200,8 +206,19 @@ def set_module(model, key, new_module):
     for name in name_list[:-1]:
         if hasattr(module, name):
             module = getattr(module, name)
+        elif hasattr(module, ("sq_linear")):  # for peft models that Linears are contained in Linear
+            module = getattr(module, "sq_linear")
+            module = getattr(module, name)
+        elif hasattr(module, ("orig_layer")):  # for peft models and auto alpha
+            module = getattr(module, "orig_layer")
+            module = getattr(module, name)
         else:
             module = module
+
+    if hasattr(module, "sq_linear") and name_list[-1] != "sq_linear":  # for peft models
+        module = getattr(module, "sq_linear")
+    if hasattr(module, "orig_layer") and name_list[-1] != "orig_layer":  # for peft models and auto alpha
+        module = getattr(module, "orig_layer")
     setattr(module, name_list[-1], new_module)


@@ -222,7 +239,7 @@ def cal_scale(input_max, weights, alpha, scale_type="orig"):
 class WrapperLayer(torch.nn.Module):
     def __init__(self, layer, input_min, input_max, save_q_input=False):
         super(WrapperLayer, self).__init__()
-        self.orig_layer = layer
+        self.add_module("orig_layer", layer)  # set orig_layer in get/set_module
         self.quant = False
         self.q_input = None
         self.fp32_output = None
@@ -281,7 +298,7 @@ class TorchSmoothQuant:
     to recover the weights if needed
     """

-    def __init__(self, model, dataloader, example_inputs=None, q_func=None, traced_model=None):
+    def __init__(self, model, dataloader=None, example_inputs=None, q_func=None, traced_model=None):
         """
         :param model: Torch model :param dataloader: Calibration dataloader :param traced_model: A specific model
         shares the same architecture as the model and could be traced by torch.jit. If not supplied, we use model
@@ -372,7 +389,7 @@ def _calibrate(self, absorb_to_layer, calib_iter, percentile):
         ##hook all the module
         hook_modules = {}
         for n, module in self.model.named_modules():
-            if module.__class__.__name__.split(".")[-1] in self.op_types:
+            if isinstance(module, tuple(self.op_types)):
                 hook_modules[n] = module

         self._add_min_max_observer(hook_modules, percentile)
@@ -547,6 +564,8 @@ def _cal_scales(self, absorb_to_layer, input_maxes, alpha=0.5, tuning=False):
                 alpha_tmp = alpha
             elif isinstance(alpha, dict):
                 alpha_tmp = alpha[key]
+            else:
+                alpha_tmp = alpha
             if alpha_tmp < 0:
                 scale = torch.ones((1), device=self.device)
             else:
@@ -670,7 +689,7 @@ def _get_sq_layer_names(self):
     def _get_all_hook_module_names(self):
         module_names = []
         for n, module in self.model.named_modules():
-            if module.__class__.__name__.split(".")[-1] in self.op_types:
+            if isinstance(module, tuple(self.op_types)):
                 module_names.append(n)
         return module_names

@@ -680,25 +699,27 @@ def _qdq_model_wrapper_for_auto(self, save_q_input=False):
         module_names = self._get_all_hook_module_names()
         self.to_unwrap_module_names = module_names
         for name in module_names:
+            if name not in self.input_mins:  # skip module if it's not used in calibration
+                continue
             module = get_module(self.model, name)
-            set_module(
-                self.model,
-                name,
-                WrapperLayer(module, self.input_mins[name], self.input_maxes[name], save_q_input=save_q_input),
-            )
+            new_module = WrapperLayer(module, self.input_mins[name], self.input_maxes[name], save_q_input=save_q_input)
+            set_module(self.model, name, new_module)

     def _qdq_model_unwrapper_for_auto(self):
         module_names = self.to_unwrap_module_names
         for name in module_names:
             module = get_module(self.model, name)
-            # print(name, flush=True)
+            if not hasattr(module, "orig_layer"):  # skip module if it's not used in calibration
+                continue
             set_module(self.model, name, module.orig_layer)

     def _change_qdq_for_auto(self, enable=True):
         module_names = self._get_all_hook_module_names()
         for name in module_names:
             name = name.split(".orig_layer")[0]
             module = get_module(self.model, name)
+            if not hasattr(module, "orig_layer"):  # skip module if it's not used in calibration
+                continue
             if enable:
                 module.enable_quant()
             else:
@@ -921,7 +942,7 @@ def transform(
         alpha=0.5,
         folding=False,
         percentile=100,
-        op_types=["Linear", "Conv2d"],
+        op_types=[torch.nn.Linear, torch.nn.Conv2d],
         scales_per_op=False,
         calib_iter=100,
         auto_alpha_args={"alpha_min": 0.0, "alpha_max": 1.0, "alpha_step": 0.1, "shared_criterion": "mean"},
@@ -953,12 +974,13 @@ def transform(
             self.recover()
         need_calibration = self._check_need_calibration(alpha, percentile, op_types, scales_per_op, calib_iter)
         with torch.no_grad():
+            str_op_types = [i.__name__ for i in op_types]
             input_maxes_abs = self.input_maxes_abs
             if need_calibration:  ##avoid multiple calibaration during tuning if the only difference is alpha
                 if self.insert_mul:
-                    self.self_absorb_layers = self._get_all_layer_names()  # TODO: only support linear now.
+                    self.self_absorb_layers = self._get_all_layer_names(op_types)  # TODO: only support linear now.
                     # fetch modules with the same input
-                    group_modules = self._trace(op_types, skip_unsupported_layers=False)
+                    group_modules = self._trace(str_op_types, skip_unsupported_layers=False)
                     if group_modules is not None:
                         # use one input for qkv
                         for k, v in group_modules.items():
@@ -969,7 +991,7 @@ def transform(
                     logger.debug(f"self_absorb_layers:{self.self_absorb_layers}")
                 if self.allow_absorb:
                     self.absorb_to_layer, no_absorb_layers = self._trace(
-                        op_types
+                        str_op_types
                     )  ##TODO we need to insert mul layer for no_absorb_layers later
                     if self.absorb_to_layer is None and no_absorb_layers is None:
                         return self.model
@@ -1061,28 +1083,18 @@ def recover(self):
         self.weight_scale_info = {}  ##clear the data
         self.absorb_scales_info = {}

-    def _get_all_layer_names(self, op_types=["Linear"]):
+    def _get_all_layer_names(self, op_types=[torch.nn.Linear]):
         """Try the model to find the layers which can be smooth quantized.

         :param op_types: The op types to be smooth quantized
         :return:
         self_absorb_layer: A dict, absorb layer name (itself): layers to be smooth quantized
         """
         self_absorb_layer = {}
+        op_types = [torch.nn.Linear]  # TODO: only support SQLinearWrapper
         for name, module in self.model.named_modules():
-            for op_type in op_types:
-                if op_type == str(module.__class__.__name__):
-                    self_absorb_layer[name] = [name]
-        # remove duplicate Linear if Linear is wrapped by Linear
-        key_list = list(self_absorb_layer.keys())
-        key_list.sort()
-        duplicate_list = []
-        for i, k1 in enumerate(key_list):
-            for k2 in key_list[i + 1 :]:
-                if k1 in k2:
-                    duplicate_list.append(k1)
-        for i in duplicate_list:
-            self_absorb_layer.pop(i)
+            if isinstance(module, tuple(op_types)):
+                self_absorb_layer[name] = [name]
         return self_absorb_layer

     def _get_example_input(self):
@@ -1334,46 +1346,3 @@ def remove_unsupported_layers(self, model, absorb_to_layer, no_absorb_layers):
             if supported:
                 res[key] = absorb_to_layer[key]
         return res
-
-
-def update_sq_scale(ipex_config_path, smoothquant_scale_info):
-    """Update ipex_config.json with smoothquant scale info generated by our algorithm.
-
-    Args:
-        ipex_config_path (str): a path to temporary ipex_config.json file.
-        smoothquant_scale_info (dict): a dict contains smoothquant scale info.
-    """
-    with open(ipex_config_path, "r") as f:
-        ipex_config = json.load(f)
-        for module_name, v in ipex_config.items():
-            if "q_op_infos" in v and v["q_op_infos"]:
-                for op_num, v1 in v["q_op_infos"].items():
-                    # update alpha data instead of updating weight scale
-                    op_name = v1["fqn"]  # fqn always exists even it's empty.
-                    if op_name in smoothquant_scale_info:
-                        # observers were overridden by the fallback step, setting it back.
-                        v1["activation_observer"] = {
-                            "name": "SmoothQuantActivationObserver",
-                            "smooth_quant_enabled": False,
-                            "dtype": "torch.quint8",
-                            "qscheme": "torch.per_tensor_affine",
-                            "reduce_range": False,
-                            "quant_min": 0,
-                            "quant_max": 255,
-                            "alpha": smoothquant_scale_info[op_name]["alpha"],
-                        }
-                        v1["weight_observer"] = {
-                            "name": "SmoothQuantWeightObserver",
-                            "smooth_quant_enabled": False,
-                            "dtype": "torch.qint8",
-                            "qscheme": "torch.per_channel_symmetric",
-                            "reduce_range": False,
-                            "quant_min": -128,
-                            "quant_max": 127,
-                            "alpha": smoothquant_scale_info[op_name]["alpha"],  # only update alpha
-                        }
-        f.close()
-    # overwrite ipex_config_path
-    with open(ipex_config_path, "w") as f1:
-        json.dump(ipex_config, f1, indent=4)
-        f1.close()
```
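One motivation for switching the hooks above from comparing `module.__class__.__name__` against strings to `isinstance(module, tuple(op_types))` is that the class-based check also matches subclasses and wrapped layers whose class name differs, while `transform` still derives `str_op_types` for the jit-trace path, which works on names. A small illustration with a made-up subclass (not a real peft class):

```python
import torch


class MyLoraLinear(torch.nn.Linear):
    """Hypothetical Linear subclass, standing in for wrapped or extended Linears."""
    pass


op_types = [torch.nn.Linear, torch.nn.Conv2d]
layer = MyLoraLinear(32, 32, bias=False)

# The class-name comparison misses the subclass...
print(layer.__class__.__name__ in [t.__name__ for t in op_types])  # False ("MyLoraLinear" != "Linear")
# ...while isinstance still recognizes it as a Linear and lets it be hooked.
print(isinstance(layer, tuple(op_types)))  # True
```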
