
Commit 49e950e

yintong-lu, pre-commit-ci[bot], and xin3he authored and committed
[regression fix] sq enhance calibration part (#1276)
Signed-off-by: Lu, Yintong <yintong.lu@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: xinhe <xin3.he@intel.com>
(cherry picked from commit e6eda31)
1 parent bd9f093 commit 49e950e
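
In short, the regression fix replaces the old calibration bookkeeping, which appended per-batch min/max tensors to lists (and cached raw input/output activations), with running element-wise reductions, so calibration memory no longer grows with the number of batches. A minimal standalone sketch of that accumulation pattern, with illustrative names and shapes (not the library's API):

    import torch

    input_mins, input_maxes = {}, {}

    def update_stats(name, batch):  # batch: (N, C) activations for one layer
        # Reduce over the batch dimension first, then fold the result into the
        # running per-channel stats, so only one (C,)-shaped tensor per layer
        # is ever kept alive.
        min_t = torch.min(batch, dim=0)[0]
        max_t = torch.max(batch, dim=0)[0]
        if name not in input_maxes:
            input_mins[name], input_maxes[name] = min_t, max_t
        else:
            input_mins[name] = torch.min(input_mins[name], min_t)
            input_maxes[name] = torch.max(input_maxes[name], max_t)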

File tree

2 files changed: +102 −95 lines changed


neural_compressor/adaptor/torch_utils/smooth_quant.py

Lines changed: 83 additions & 95 deletions
@@ -60,7 +60,10 @@ def forward_wrapper(model, input, device=torch.device("cpu")):
     if isinstance(input, dict) or isinstance(input, UserDict):
         output = model(**input)
     elif isinstance(input, list) or isinstance(input, tuple):
-        output = model(*input)
+        try:
+            output = model(*input)
+        except:
+            output = model(input)
     else:
         output = model(input)
     return output
@@ -295,8 +298,6 @@ def __init__(self, model, dataloader, example_inputs=None, q_func=None, traced_m
         self.dataloader = dataloader
         self.example_inputs = example_inputs
         self.q_func = q_func
-        self.input_values = {}
-        self.output_values = {}
         self.input_maxes = {}
         self.input_mins = {}
         self.input_maxes_abs = {}
@@ -325,10 +326,6 @@ def _save_input_pc_hook(self, name, percentile=100):
         :return: A hook function."""

         def save_input_hook(module, inputs, outputs):
-            if name not in self.input_maxes.keys():
-                self.input_maxes[name] = []
-                self.input_mins[name] = []
-                self.input_maxes_abs[name] = []
             input = inputs[0]
             ##TODO check input channel is correct
             if len(module.weight.shape) == 4:  ##conv3d or conv1d not supported now, need better way
@@ -339,43 +336,16 @@ def save_input_hook(module, inputs, outputs):
             k_index = int(input.shape[0] * percentile / 100)
             res, _ = torch.kthvalue(torch.abs(input), k_index, dim=0)
             ##res = torch.max(torch.abs(input),dim=0)[0]
-            self.input_maxes_abs[name].append(res)
-            self.input_maxes[name].append(max_tensor)
-            self.input_mins[name].append(min_tensor)
+            if name not in self.input_maxes.keys():
+                self.input_mins[name], self.input_maxes[name] = min_tensor, max_tensor
+                self.input_maxes_abs[name] = res
+            else:
+                self.input_mins[name] = torch.min(self.input_mins[name], min_tensor)
+                self.input_maxes[name] = torch.max(self.input_maxes[name], max_tensor)
+                self.input_maxes_abs[name] = torch.max(self.input_maxes_abs[name], res)

         return save_input_hook

-    def _save_input_output_hook(self, name):
-        """
-        A forward hook to save input and output values of a module
-        param name: the module name
-        return: A hook function
-        """
-
-        def save_input_output_hook(module, inputs, outputs):
-            input = inputs[0]
-            cnt = 32
-            if name in self.input_values.keys() and len(self.input_values[name]) < cnt:
-                self.input_values[name].append(input)
-                self.output_values[name].append(outputs)
-            if name not in self.input_values.keys():
-                self.input_values[name] = [input]  ##TODO save more,like 8
-                self.output_values[name] = [outputs]  ##TODO do not save output
-
-        return save_input_output_hook
-
-    def _add_input_output_observer(self):
-        input_output_modules = {}
-        hook_layer_names = []
-        for key in self.absorb_to_layer:
-            hook_layer_names += self.absorb_to_layer[key]
-        for name in hook_layer_names:
-            input_output_modules[name] = get_module(self.model, name)
-        for key in input_output_modules.keys():
-            hook_func = self._save_input_output_hook(key)
-            hook_handle = input_output_modules[key].register_forward_hook(hook_func)
-            self.hook_handles.append(hook_handle)
-
     def _add_min_max_observer(self, modules, percentile=100):
         """
         :param modules: the modules which the observer will insert to
@@ -393,7 +363,7 @@ def _remove_observer(self):
         for hook_handle in self.hook_handles:
             hook_handle.remove()

-    def _calibrate(self, absorb_to_layer, calib_iter, percentile, save_input_output=False):
+    def _calibrate(self, absorb_to_layer, calib_iter, percentile):
         """
         :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer
         :param calib_iter: Data size for calibration
@@ -406,8 +376,6 @@ def _calibrate(self, absorb_to_layer, calib_iter, percentile, save_input_output=
             hook_modules[n] = module

         self._add_min_max_observer(hook_modules, percentile)
-        if save_input_output:
-            self._add_input_output_observer()

         self._dump_min_max(calib_iter=calib_iter)
         self._remove_observer()
@@ -423,15 +391,6 @@ def _dump_min_max(self, calib_iter=100):
         else:
             assert self.dataloader, "Please set dataloader for calibration."
             model_forward(self.model, self.dataloader, calib_iter, self.device)
-        ##stack
-        for key in self.input_maxes.keys():
-            max_val = self.input_maxes[key]
-            max_val = torch.stack(max_val, dim=0)
-            min_val = self.input_mins[key]
-            min_val = torch.stack(min_val, dim=0)
-            self.input_maxes[key] = torch.max(max_val, dim=0)[0]
-            self.input_mins[key] = torch.min(min_val, dim=0)[0]
-            self.input_maxes_abs[key] = torch.max(torch.stack(self.input_maxes_abs[key], dim=0), dim=0)[0]

     def _reshape_in_channel_to_last(self, layer_name):
         """Move the input channel to the last dim
@@ -877,43 +836,76 @@ def _auto_tune_alpha_new(
         if not self.dataloader:
             self._qdq_model_unwrapper_for_auto()
             return best_alphas
-
-        for idx, input in enumerate(self.dataloader):
-            if isinstance(input, (tuple, list)):
-                input = input[0]
-            best_alphas_per_module = best_alphas
-            if isinstance(best_alphas, dict):
-                for key in self.absorb_to_layer.keys():
-                    layer_names = self.absorb_to_layer[key]
-                    for layer_name in layer_names:
-                        best_alphas_per_module[layer_name] = best_alphas_per_module[key]
-
-            loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
-            if loss_alphas == {}:
-                loss_alphas = loss_tmp
-            else:
-                for key in loss_alphas.keys():
-                    cur_loss = loss_alphas[key]
-                    for alpha_key in cur_loss.keys():
-                        cur_loss[alpha_key] += loss_tmp[key][alpha_key]
-            if isinstance(input, list):
-                input = move_input_to_device(input, self.device)
-                for inp in input:
-                    cnt += inp.shape[0]
-            else:
-                cnt += input.shape[0]
-
-            if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
-                best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
-                for key in best_alphas.keys():
-                    logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
-                absorb_input_scales, weight_scales = self._cal_scales(
-                    self.absorb_to_layer, input_maxes, best_alphas, tuning=True
-                )
-                self._update_scales_for_auto(absorb_input_scales, weight_scales)
-                loss_alphas = {}  ##TODO check need to remove this one
-            if cnt >= auto_calib_iter:
-                break
+        try:
+            for input, label in self.dataloader:
+                best_alphas_per_module = best_alphas
+                if isinstance(best_alphas, dict):
+                    for key in self.absorb_to_layer.keys():
+                        layer_names = self.absorb_to_layer[key]
+                        for layer_name in layer_names:
+                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]
+
+                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                if loss_alphas == {}:
+                    loss_alphas = loss_tmp
+                else:
+                    for key in loss_alphas.keys():
+                        cur_loss = loss_alphas[key]
+                        for alpha_key in cur_loss.keys():
+                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
+                if isinstance(input, list):
+                    input = move_input_to_device(input, self.device)
+                    for inp in input:
+                        cnt += inp.shape[0]
+                else:
+                    cnt += input.shape[0]
+
+                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
+                    for key in best_alphas.keys():
+                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                    absorb_input_scales, weight_scales = self._cal_scales(
+                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
+                    )
+                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                    loss_alphas = {}  ##TODO check need to remove this one
+                if cnt >= auto_calib_iter:
+                    break
+        except:
+            for input in self.dataloader:
+                best_alphas_per_module = best_alphas
+                if isinstance(best_alphas, dict):
+                    for key in self.absorb_to_layer.keys():
+                        layer_names = self.absorb_to_layer[key]
+                        for layer_name in layer_names:
+                            best_alphas_per_module[layer_name] = best_alphas_per_module[key]
+
+                loss_tmp = self._get_one_sample_auto_loss(input, alpha_space, best_alphas_per_module, input_maxes)
+                if loss_alphas == {}:
+                    loss_alphas = loss_tmp
+                else:
+                    for key in loss_alphas.keys():
+                        cur_loss = loss_alphas[key]
+                        for alpha_key in cur_loss.keys():
+                            cur_loss[alpha_key] += loss_tmp[key][alpha_key]
+                if isinstance(input, list):
+                    input = move_input_to_device(input, self.device)
+                    for inp in input:
+                        cnt += inp.shape[0]
+                else:
+                    cnt += input.shape[0]
+
+                if cnt % multiply_factor == 0 and (auto_calib_iter - cnt) >= multiply_factor:
+                    best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
+                    for key in best_alphas.keys():
+                        logger.info(f"{cnt // multiply_factor},{key}:{best_alphas[key]}")
+                    absorb_input_scales, weight_scales = self._cal_scales(
+                        self.absorb_to_layer, input_maxes, best_alphas, tuning=True
+                    )
+                    self._update_scales_for_auto(absorb_input_scales, weight_scales)
+                    loss_alphas = {}  ##TODO check need to remove this one
+                if cnt >= auto_calib_iter:
+                    break

         best_alphas = self._get_best_alpha(self.absorb_to_layer, loss_alphas, shared_criterion)
         for key in best_alphas.keys():
@@ -995,11 +987,8 @@ def transform(
                 "you could set torchscript to True "
             )
            return self.model
-        save_input_output = False if alpha == "auto" else True
-        # if alpha == "auto":
-        #     save_input_output = True

-        input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile, save_input_output)
+        input_maxes_abs = self._calibrate(self.absorb_to_layer, calib_iter, percentile)

         # Check if input_maxes match self.absorb_to_layer
         # (due to self._get_all_layer_names use layer tree instead of forward_path)
@@ -1042,7 +1031,6 @@ def transform(
         else:
             logger.warning(" Could not get example input, equivelancy check is skipped")

-        self.input_values, self.output_values = {}, {}
         return self.model

     def output_is_equal(self, out1, out2, atol=1e-04):
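
The same fallback idiom appears twice above: forward_wrapper now tries model(*input) before model(input), and _auto_tune_alpha_new first assumes the dataloader yields (input, label) pairs and retries with bare inputs when unpacking fails. A hedged sketch of that pattern outside the library (the dataloader contents here are illustrative):

    def run_calib_batches(model, dataloader):
        # Try the (input, label) convention first; if the dataloader yields
        # unlabeled batches, unpacking raises and we restart with bare inputs,
        # mirroring the try/except around the calibration loop in the diff
        # (which uses a bare except).
        try:
            for input, label in dataloader:
                model(input)
        except (TypeError, ValueError):
            for input in dataloader:
                model(input)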

test/algorithm/test_smooth_quant.py

Lines changed: 19 additions & 0 deletions
@@ -1265,5 +1265,24 @@ def calib_func(prepared_model):
         self.assertEqual(indices[2], torch.tensor([504]))


+class TestMemoryUsage(unittest.TestCase):
+    def test_sq_auto_mem_usage(self):
+        import psutil
+
+        data = psutil.virtual_memory()
+        cpu_process = psutil.Process()
+        p = psutil.Process(cpu_process.pid)
+        mem_use0 = p.memory_info().rss / (1024**3)
+        model = transformers.AutoModelForCausalLM.from_pretrained(
+            "facebook/opt-125m",
+            torchscript=True,
+        )
+        sq = TorchSmoothQuant(model, LLMCalibDataloader())
+        sq.transform(alpha="auto", calib_iter=0, folding=False)
+        mem_use1 = p.memory_info().rss / (1024**3)
+        logger.info(f"The memory usage of this ut is {mem_use1 - mem_use0} GBs.")
+        assert (mem_use1 - mem_use0) <= 2.0
+
+
 if __name__ == "__main__":
     unittest.main()
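
The new unit test pins the fix down as a memory bound rather than a functional check: it samples the process RSS via psutil before and after running auto-alpha SmoothQuant on facebook/opt-125m and asserts the delta stays under 2 GB, which would fail again if per-batch activations were ever cached during calibration.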
