
Commit c295a7f

Enhance memory usage for PyTorch quantization (#541)

Signed-off-by: Cheng, Penghui <penghui.cheng@intel.com>
Signed-off-by: Xin He <xin3.he@intel.com>
Signed-off-by: Lv, Liang1 <liang1.lv@intel.com>
1 parent 6e10efd commit c295a7f

File tree

11 files changed: +384 −187 lines changed


neural_compressor/adaptor/pytorch.py

Lines changed: 275 additions & 118 deletions
Large diffs are not rendered by default.

neural_compressor/conf/config.py

Lines changed: 12 additions & 2 deletions

@@ -1401,6 +1401,7 @@ def map_pyconfig_to_cfg(self, pythonic_config):
                 'model.domain': pythonic_config.quantization.domain,
                 'quantization.recipes': pythonic_config.quantization.recipes,
                 'quantization.approach': pythonic_config.quantization.approach,
+                'quantization.example_inputs': pythonic_config.quantization.example_inputs,
                 'quantization.calibration.sampling_size':
                     pythonic_config.quantization.calibration_sampling_size,
                 'quantization.optype_wise': pythonic_config.quantization.op_type_list,

@@ -1429,7 +1430,7 @@ def map_pyconfig_to_cfg(self, pythonic_config):
                 if st_key in st_kwargs:
                     st_val = st_kwargs[st_key]
                     mapping.update({'tuning.strategy.' + st_key: st_val})
-
+
         if pythonic_config.distillation is not None:
             mapping.update({
                 'distillation.train.criterion': pythonic_config.distillation.criterion,

@@ -1458,7 +1459,6 @@ def map_pyconfig_to_cfg(self, pythonic_config):
             if pythonic_config.benchmark.outputs != []:
                 mapping.update({'model.outputs': pythonic_config.benchmark.outputs})
             mapping.update({
-                'model.backend': pythonic_config.benchmark.backend,
                 'evaluation.performance.warmup': pythonic_config.benchmark.warmup,
                 'evaluation.performance.iteration': pythonic_config.benchmark.iteration,
                 'evaluation.performance.configs.cores_per_instance':

@@ -1478,6 +1478,16 @@ def map_pyconfig_to_cfg(self, pythonic_config):
                 'evaluation.accuracy.configs.intra_num_of_threads':
                     pythonic_config.benchmark.intra_num_of_threads,
             })
+            if "model.backend" not in mapping:
+                mapping.update({
+                    'model.backend': pythonic_config.benchmark.backend,
+                })
+            else:
+                if mapping['model.backend'] == 'default' and \
+                    pythonic_config.benchmark.backend != 'default':
+                    mapping.update({
+                        'model.backend': pythonic_config.benchmark.backend,
+                    })

         if "model.backend" not in mapping:
             mapping.update({
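
Note: the new block lets a backend already present in `mapping` win over the benchmark backend, unless the existing value is 'default' while the benchmark one is not. A minimal, self-contained sketch of that precedence rule (the dict and values below are stand-ins for the real `mapping` and `pythonic_config` objects):

    # Stand-in for the mapping logic added above; not the real config objects.
    def resolve_backend(mapping, benchmark_backend):
        if "model.backend" not in mapping:
            mapping["model.backend"] = benchmark_backend
        elif mapping["model.backend"] == "default" and benchmark_backend != "default":
            # An explicitly-set benchmark backend overrides a leftover default.
            mapping["model.backend"] = benchmark_backend
        return mapping

    assert resolve_backend({}, "ipex")["model.backend"] == "ipex"
    assert resolve_backend({"model.backend": "default"}, "ipex")["model.backend"] == "ipex"
    assert resolve_backend({"model.backend": "onnxrt"}, "ipex")["model.backend"] == "onnxrt"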

neural_compressor/config.py

Lines changed: 13 additions & 0 deletions

@@ -388,6 +388,7 @@ def __init__(self,
                  max_trials=100,
                  performance_only=False,
                  reduce_range=None,
+                 example_inputs=None,
                  excluded_precisions=[],
                  quant_level=1,
                  accuracy_criterion=accuracy_criterion,

@@ -428,6 +429,7 @@ def __init__(self,
            max_trials: max tune times. default value is 100. Combine with timeout field to decide when to exit
            performance_only: whether do evaluation
            reduce_range: whether use 7 bit
+           example_inputs: used to trace PyTorch model with torch.jit/torch.fx
            excluded_precisions: precisions to be excluded, support 'bf16'
            quant_level: support 0 and 1, 0 is conservative strategy, 1 is basic(default) or user-specified strategy
            accuracy_criterion: accuracy constraint settings

@@ -455,6 +457,7 @@ def __init__(self,
        self.calibration_sampling_size = calibration_sampling_size
        self.quant_level = quant_level
        self.use_distributed_tuning=use_distributed_tuning
+       self._example_inputs = example_inputs

    @property
    def domain(self):

@@ -766,6 +769,16 @@ def inputs(self, inputs):
        if check_value('inputs', inputs, str):
            self._inputs = inputs

+   @property
+   def example_inputs(self):
+       """Get example_inputs."""
+       return self._example_inputs
+
+   @example_inputs.setter
+   def example_inputs(self, example_inputs):
+       """Set example_inputs."""
+       self._example_inputs = example_inputs
+

class TuningCriterion:
    """Class for Tuning Criterion.

neural_compressor/model/torch_model.py

Lines changed: 2 additions & 2 deletions (whitespace-only cleanup)

@@ -793,7 +793,7 @@ def save(self, root=None):
            logger.info("Save config file of quantized model to {}.".format(root))
        except IOError as e:
            logger.error("Fail to save configure file and weights due to {}.".format(e))
-
+
        if isinstance(self.model, torch.jit._script.RecursiveScriptModule):
            self.model.save(os.path.join(root, "best_model.pt"))
-
+

neural_compressor/strategy/strategy.py

Lines changed: 3 additions & 2 deletions

@@ -995,7 +995,8 @@ def _create_path(self, custom_path, filename):
    def _set_framework_info(self, q_dataloader, q_func=None):
        framework_specific_info = {'device': self.cfg.device,
                                   'approach': self.cfg.quantization.approach,
-                                  'random_seed': self.cfg.tuning.random_seed}
+                                  'random_seed': self.cfg.tuning.random_seed,
+                                  'performance_only': self.cfg.tuning.exit_policy.performance_only,}
        framework = self.cfg.model.framework.lower()
        framework_specific_info.update({'backend': self.cfg.model.get('backend', 'default')})
        framework_specific_info.update({'format': self.cfg.model.get('quant_format', 'default')})

@@ -1010,7 +1011,6 @@ def _set_framework_info(self, q_dataloader, q_func=None):
                "outputs": self.cfg.model.outputs,
                'workspace_path': self.cfg.tuning.workspace.path,
                'recipes': self.cfg.quantization.recipes,
-               'performance_only': self.cfg.tuning.exit_policy.performance_only,
                'use_bf16': self.cfg.use_bf16 if self.cfg.use_bf16 is not None else False})
            for item in ['scale_propagation_max_pooling', 'scale_propagation_concat']:
                if item not in framework_specific_info['recipes']:

@@ -1054,6 +1054,7 @@ def _set_framework_info(self, q_dataloader, q_func=None):
            framework_specific_info.update(
                {"default_qconfig": self.cfg['quantization']['op_wise']['default_qconfig']})
            framework_specific_info.update({"q_func": q_func})
+           framework_specific_info.update({"example_inputs": self.cfg.quantization.example_inputs})
        return framework, framework_specific_info

    def _set_objectives(self):
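
Note: `performance_only` now travels in the base `framework_specific_info` dict rather than in a framework-specific update, and the branch touched by the last hunk (apparently the PyTorch one, judging by the neighboring `default_qconfig` and `q_func` keys) also forwards the user's `example_inputs` to the adaptor. A rough sketch of the resulting dict; the keys come from the hunks above, the values are illustrative only:

    import torch

    # Illustrative values only; keys mirror the hunks above.
    framework_specific_info = {
        'device': 'cpu',
        'approach': 'post_training_static_quant',
        'random_seed': 9527,
        'performance_only': False,   # moved into the common dict
        'backend': 'default',
        'format': 'default',
        'q_func': None,
        'example_inputs': torch.randn([1, 3, 224, 224]),  # new: enables fx tracing
    }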

neural_compressor/utils/pytorch.py

Lines changed: 35 additions & 42 deletions

@@ -201,7 +201,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
        stat_dict['best_configure'] = tune_cfg
    else:
        logger.error("Unexpected checkpoint type:{}. \
-                    Only file dir/path or state_dict is acceptable")
+                     Only file dir/path or state_dict is acceptable")

    if not isinstance(stat_dict, torch.jit._script.RecursiveScriptModule):
        assert 'best_configure' in stat_dict, \

@@ -223,17 +223,10 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
        logger.info("Finish load the model quantized by INC IPEX backend.")
        return q_model

-   try:
-       q_model = copy.deepcopy(model)
-   except Exception as e:  # pragma: no cover
-       logger.warning("Fail to deep copy the model due to {}, inplace is used now.".
-                      format(repr(e)))
-       q_model = model
-
    if 'is_oneshot' in tune_cfg and tune_cfg['is_oneshot']:
-       return _load_int8_orchestration(q_model, tune_cfg, stat_dict, example_inputs, **kwargs)
+       return _load_int8_orchestration(model, tune_cfg, stat_dict, example_inputs, **kwargs)

-   q_model.eval()
+   model.eval()
    approach_quant_mode = None
    if tune_cfg['approach'] == "post_training_dynamic_quant":
        approach_quant_mode = 'dynamic'

@@ -279,79 +272,79 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
        op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
        fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg['approach'])
        if not tune_cfg['fx_sub_module_list']:
-           tmp_model = q_model
+           tmp_model = model
            if tune_cfg['approach'] == "quant_aware_training":
-               q_model.train()
+               model.train()
                if version > Version("1.12.1"):  # pragma: no cover
                    # pylint: disable=E1123
-                   q_model = prepare_qat_fx(q_model,
-                                            fx_op_cfgs,
-                                            prepare_custom_config=prepare_custom_config_dict,
-                                            example_inputs=example_inputs)
+                   model = prepare_qat_fx(model,
+                                          fx_op_cfgs,
+                                          prepare_custom_config=prepare_custom_config_dict,
+                                          example_inputs=example_inputs)
                else:
-                   q_model = prepare_qat_fx(q_model,
-                                            fx_op_cfgs,
-                                            prepare_custom_config_dict=prepare_custom_config_dict)
+                   model = prepare_qat_fx(model,
+                                          fx_op_cfgs,
+                                          prepare_custom_config_dict=prepare_custom_config_dict)
            else:
                if version > Version("1.12.1"):  # pragma: no cover
                    # pylint: disable=E1123
-                   q_model = prepare_fx(q_model,
-                                        fx_op_cfgs,
-                                        prepare_custom_config=prepare_custom_config_dict,
-                                        example_inputs=example_inputs)
+                   model = prepare_fx(model,
+                                      fx_op_cfgs,
+                                      prepare_custom_config=prepare_custom_config_dict,
+                                      example_inputs=example_inputs)
                else:
-                   q_model = prepare_fx(q_model,
-                                        fx_op_cfgs,
-                                        prepare_custom_config_dict=prepare_custom_config_dict)
+                   model = prepare_fx(model,
+                                      fx_op_cfgs,
+                                      prepare_custom_config_dict=prepare_custom_config_dict)
            if version > Version("1.12.1"):  # pragma: no cover
                # pylint: disable=E1123
-               q_model = convert_fx(q_model,
+               model = convert_fx(model,
                                    convert_custom_config=convert_custom_config_dict)
            else:
-               q_model = convert_fx(q_model,
+               model = convert_fx(model,
                                    convert_custom_config_dict=convert_custom_config_dict)
-           util.append_attr(q_model, tmp_model)
+           util.append_attr(model, tmp_model)
            del tmp_model
        else:
            sub_module_list = tune_cfg['fx_sub_module_list']
            if tune_cfg['approach'] == "quant_aware_training":
-               q_model.train()
+               model.train()
                PyTorch_FXAdaptor.prepare_sub_graph(sub_module_list,
                                                    fx_op_cfgs,
-                                                   q_model,
+                                                   model,
                                                    prefix='',
                                                    is_qat=True,
                                                    example_inputs=example_inputs)
            else:
                PyTorch_FXAdaptor.prepare_sub_graph(sub_module_list,
                                                    fx_op_cfgs,
-                                                   q_model,
+                                                   model,
                                                    prefix='',
                                                    example_inputs=example_inputs)
-           PyTorch_FXAdaptor.convert_sub_graph(sub_module_list, q_model, prefix='')
+           PyTorch_FXAdaptor.convert_sub_graph(sub_module_list, model, prefix='')
    else:
        if tune_cfg['approach'] == "post_training_dynamic_quant":
            op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
        else:
            op_cfgs = _cfg_to_qconfig(tune_cfg)

-       _propagate_qconfig(q_model, op_cfgs, approach=tune_cfg['approach'])
+       _propagate_qconfig(model, op_cfgs, approach=tune_cfg['approach'])
        # sanity check common API misusage
-       if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules()):
+       if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
            logger.warn("None of the submodule got qconfig applied. Make sure you "
                        "passed correct configuration through `qconfig_dict` or "
                        "by assigning the `.qconfig` attribute directly on submodules")
        if tune_cfg['approach'] != "post_training_dynamic_quant":
-           add_observer_(q_model)
-       q_model = convert(q_model, mapping=q_mapping, inplace=True)
+           add_observer_(model)
+       model = convert(model, mapping=q_mapping, inplace=True)

    bf16_ops_list = tune_cfg['bf16_ops_list'] if 'bf16_ops_list' in tune_cfg.keys() else []
    if len(bf16_ops_list) > 0 and (version >= Version("1.11.0-rc1")):
        from ..adaptor.torch_utils.bf16_convert import Convert
-       q_model = Convert(q_model, tune_cfg)
+       model = Convert(model, tune_cfg)
    if checkpoint_dir is None and history_cfg is not None:
-       _set_activation_scale_zeropoint(q_model, history_cfg)
+       _set_activation_scale_zeropoint(model, history_cfg)
    else:
-       q_model.load_state_dict(stat_dict)
-   util.get_embedding_contiguous(q_model)
-   return q_model
+       model.load_state_dict(stat_dict)
+   util.get_embedding_contiguous(model)
+   return model
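
Note: this is the heart of the memory saving. `load()` previously deep-copied the incoming model before quantizing, so two full copies were alive at once; it now prepares and converts the caller's model in place. The trade-off is that the passed-in model gets mutated, which is why the updated tests below hand `copy.deepcopy(...)` to `load` whenever the FP32 original is still needed. A minimal sketch of the new calling convention, assuming the utils-level `load` shown above (`build_model` is a hypothetical helper returning an nn.Module):

    import copy
    from neural_compressor.utils.pytorch import load

    fp32_model = build_model()  # hypothetical FP32 nn.Module factory

    # `load` may now modify its model argument in place, so pass a copy
    # if the FP32 model must survive (exactly what the updated tests do).
    int8_model = load("./saved", copy.deepcopy(fp32_model))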

test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py

Lines changed: 10 additions & 10 deletions

@@ -264,9 +264,9 @@ def build_pytorch_yaml():

 def build_pytorch_fx_yaml():
     if PT_VERSION >= Version("1.9.0").release:
-      fake_fx_ptq_yaml = fake_ptq_yaml_for_fx
+        fake_fx_ptq_yaml = fake_ptq_yaml_for_fx
     else:
-      fake_fx_ptq_yaml = fake_ptq_yaml.replace('pytorch', 'pytorch_fx')
+        fake_fx_ptq_yaml = fake_ptq_yaml.replace('pytorch', 'pytorch_fx')
     with open('fx_ptq_yaml.yaml', 'w', encoding="utf-8") as f:
         f.write(fake_fx_ptq_yaml)

@@ -712,11 +712,11 @@ def test_tensor_dump_and_set(self):
         a = load_array('saved/dump_tensor/activation_iter1.npz')
         w = load_array('saved/dump_tensor/weight.npz')
         if PT_VERSION >= Version("1.8.0").release:
-          self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
-                          a['conv1.0'].item()['conv1.0.output0'].shape[1])
+            self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
+                            a['conv1.0'].item()['conv1.0.output0'].shape[1])
         else:
-          self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
-                          a['conv1.0'].item()['conv1.1.output0'].shape[1])
+            self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
+                            a['conv1.0'].item()['conv1.1.output0'].shape[1])
         data = np.random.random(w['conv1.0'].item()['conv1.0.weight'].shape).astype(np.float32)
         quantizer.strategy.adaptor.set_tensor(q_model, {'conv1.0.weight': data})
         changed_tensor = q_model.get_weight('conv1.weight')

@@ -789,7 +789,7 @@ def forward(self, x):
         q_capability = self.adaptor.query_fw_capability(model)
         for k, v in q_capability["opwise"].items():
             if k[0] != "quant" and k[0] != "dequant":
-              fallback_ops.append(k[0])
+                fallback_ops.append(k[0])
         model.model.qconfig = torch.quantization.default_qconfig
         model.model.quant.qconfig = torch.quantization.default_qconfig
         if PT_VERSION >= Version("1.8.0").release:

@@ -903,7 +903,7 @@ def test_fx_dynamic_quant(self):
         # run fx_quant in neural_compressor and save the quantized GraphModule
         model.eval()
         quantizer = Quantization('fx_dynamic_yaml.yaml')
-        quantizer.model = common.Model(model,
+        quantizer.model = common.Model(copy.deepcopy(model),
                                        **{'prepare_custom_config_dict': \
                                           {'non_traceable_module_name': ['a']},
                                           'convert_custom_config_dict': \

@@ -913,7 +913,7 @@ def test_fx_dynamic_quant(self):
         q_model.save('./saved')

         # Load configure and weights by neural_compressor.utils
-        model_fx = load("./saved", model,
+        model_fx = load("./saved", copy.deepcopy(model),
                         **{'prepare_custom_config_dict': \
                            {'non_traceable_module_name': ['a']},
                            'convert_custom_config_dict': \

@@ -929,7 +929,7 @@ def test_fx_dynamic_quant(self):
             yaml.dump(tune_cfg, f, default_flow_style=False)
         torch.save(state_dict, "./saved/best_model_weights.pt")
         os.remove('./saved/best_model.pt')
-        model_fx = load("./saved", model,
+        model_fx = load("./saved", copy.deepcopy(model),
                         **{'prepare_custom_config_dict': \
                            {'non_traceable_module_name': ['a']},
                            'convert_custom_config_dict': \

test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py

Lines changed: 6 additions & 6 deletions

@@ -322,6 +322,7 @@ def test_fx_quant(self):
         else:
             conf = PostTrainingQuantConfig(
                 op_name_list=ptq_fx_op_name_list)
+            conf.example_inputs = torch.randn([1, 3, 224, 224])
         set_workspace("./saved")
         q_model = quantization.fit(model_origin,
                                    conf,

@@ -381,11 +382,11 @@ def test_fx_dynamic_quant(self):
         origin_model.eval()
         conf = PostTrainingQuantConfig(approach="dynamic", op_name_list=ptq_fx_op_name_list)
         set_workspace("./saved")
-        q_model = quantization.fit(origin_model, conf)
+        q_model = quantization.fit(copy.deepcopy(origin_model), conf)
         q_model.save("./saved")

         # Load configure and weights by neural_compressor.utils
-        model_fx = load("./saved", origin_model)
+        model_fx = load("./saved", copy.deepcopy(origin_model))
         self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule))

         # Test the functionality of older model saving type

@@ -396,7 +397,7 @@ def test_fx_dynamic_quant(self):
             yaml.dump(tune_cfg, f, default_flow_style=False)
         torch.save(state_dict, "./saved/best_model_weights.pt")
         os.remove("./saved/best_model.pt")
-        model_fx = load("./saved", origin_model)
+        model_fx = load("./saved", copy.deepcopy(origin_model))
         self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule))

         # recover int8 model with only tune_cfg

@@ -472,8 +473,7 @@ def test_fx_sub_module_quant(self):
         # recover int8 model with only tune_cfg
         history_file = "./saved/history.snapshot"
         model_fx_recover = recover(model_origin, history_file, 0,
-                                   **{"dataloader": torch.utils.data.DataLoader(dataset)
-                                   })
+                                   **{"dataloader": torch.utils.data.DataLoader(dataset)})
         self.assertEqual(model_fx.sub.code, model_fx_recover.sub.code)
         shutil.rmtree("./saved", ignore_errors=True)

@@ -489,7 +489,7 @@ def test_mix_precision(self):
         q_model = quantization.fit(model_origin,
                                    conf,
                                    calib_dataloader=dataloader,
-                                   calib_func = eval_func)
+                                   calib_func=eval_func)
         tune_cfg = q_model.q_config
         tune_cfg["op"][("conv.module", "Conv2d")].clear()
         tune_cfg["op"][("conv.module", "Conv2d")] = \
