@@ -1632,25 +1632,36 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         Returns:
             (dict): quantized model
         """
+        if self.performance_only:
+            tmp_model = model
+        else:
+            try:
+                tmp_model = copy.deepcopy(model)
+            except Exception as e:  # pragma: no cover
+                logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(repr(e)))
+                tmp_model = model
+
         assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME"
         for precision in self.query_handler.get_precisions():
             if precision == "weight_only_integer":
                 self.quantizable_op_types += self.query_handler.get_op_types_by_precision(precision=precision)
-        self.quantizable_ops = self._query_quantizable_ops(model.model)
+        self.quantizable_ops = self._query_quantizable_ops(tmp_model.model)
 
+        self._update_tune_cfg(tune_cfg, tmp_model.model)
         quant_config = self._cfg_to_quantize_config(tune_cfg)
         algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
         if "GPTQ" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
 
+            assert data_loader is not None, "GPTQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
             percdamp = self.recipes.get("gptq_args", {}).get("percdamp", 0.01)
             blocksize = self.recipes.get("gptq_args", {}).get("blocksize", 128)
             actorder = self.recipes.get("gptq_args", {}).get("actorder", False)
             mse = self.recipes.get("gptq_args", {}).get("mse", False)
             perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = gptq_quantize(
-                model,
+            tmp_model = gptq_quantize(
+                tmp_model,
                 data_loader,
                 quant_config,
                 n_samples=calib_sampling_size,
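The asserts added in the GPTQ and AWQ branches make the calibration dataloader mandatory for those data-driven algorithms. As a hedged usage sketch (the `PostTrainingQuantConfig` spelling and `approach="weight_only"` are assumptions; only the `gptq_args` keys and the `calib_dataloader` argument are taken from the code above), the recipe knobs read from `self.recipes` could be supplied like this:

```python
# Hedged sketch, not taken from this commit: how the gptq_args recipe keys read
# above might be supplied from user code.
from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach="weight_only",  # assumed setting for the weight-only adaptor
    recipes={
        "gptq_args": {
            "percdamp": 0.01,
            "blocksize": 128,
            "actorder": False,
            "mse": False,
            "perchannel": True,
        }
    },
)

# onnx_model and dataloader are placeholders for an ONNX model and an
# INC-compatible dataloader; calib_dataloader is required for GPTQ/AWQ,
# as enforced by the new asserts.
q_model = quantization.fit(onnx_model, conf, calib_dataloader=dataloader)
```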
@@ -1663,11 +1674,12 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         if "AWQ" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize
 
+            assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
             enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
             enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = awq_quantize(
-                model,
+            tmp_model = awq_quantize(
+                tmp_model,
                 data_loader,
                 quant_config,
                 n_samples=calib_sampling_size,
@@ -1677,11 +1689,11 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         elif "RTN" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
 
-            model = rtn_quantize(model, quant_config)
-        model.q_config = copy.deepcopy(quant_config)
-        self._dump_model_op_stats(model, tune_cfg)
-        model.topological_sort()
-        return model
+            tmp_model = rtn_quantize(tmp_model, quant_config)
+        tmp_model.q_config = copy.deepcopy(quant_config)
+        self._dump_model_op_stats(tmp_model, tune_cfg)
+        tmp_model.topological_sort()
+        return tmp_model
 
     def _dump_model_op_stats(self, model, tune_cfg):
         import re
@@ -1747,6 +1759,31 @@ def _cfg_to_quantize_config(self, tune_cfg):
 
         return quantize_config
 
+    def _update_tune_cfg(self, tune_cfg, model):
+        """Update tune cfg according to woq_tuning_cfg."""
+        if tune_cfg.get("woq_tuning_cfg") is None:
+            return tune_cfg
+
+        from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS
+
+        woq_tuning_cfg = tune_cfg.get("woq_tuning_cfg")
+        new_woq_cfg = WOQ_TUNING_ALGOS.get(woq_tuning_cfg)
+
+        for node_cfg in tune_cfg["op"].values():
+            node_cfg["weight"].update(
+                {cfg_name: cfg_value for cfg_name, cfg_value in new_woq_cfg.items() if cfg_name in node_cfg["weight"]}
+            )
+
+        # find last matmul and set to fp32
+        if "DISABLE_LAST_MATMUL" in woq_tuning_cfg:
+            last_matmul = None
+            fp32_op_cfg = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}}
+            for node in model.graph.node:
+                if node.op_type in ["MatMul"]:
+                    last_matmul = (node.name, node.op_type)
+            if last_matmul in tune_cfg["op"]:
+                tune_cfg["op"][last_matmul].update(fp32_op_cfg)
+
     def query_fw_capability(self, model):
         """The function is used to query framework capability.
         TODO: will be replaced by framework query API
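For clarity, a minimal illustration of what the new `_update_tune_cfg` helper does to a tuning config. The preset name and its fields below are hypothetical stand-ins for the real entries of `WOQ_TUNING_ALGOS` in `neural_compressor.strategy.utils.constant`:

```python
# Illustrative only: the preset name and its fields are hypothetical stand-ins.
woq_tuning_cfg = "RTN_G32ASYM"                                      # tune_cfg["woq_tuning_cfg"]
preset = {"algorithm": "RTN", "group_size": 32, "scheme": "asym"}   # WOQ_TUNING_ALGOS[woq_tuning_cfg]

# Weight config of one op before the update.
op_weight_cfg = {"dtype": "int4", "algorithm": "GPTQ", "group_size": 128, "scheme": "sym"}

# Only keys already present in the op's weight config are overwritten by the preset:
op_weight_cfg.update({k: v for k, v in preset.items() if k in op_weight_cfg})
print(op_weight_cfg)
# {"dtype": "int4", "algorithm": "RTN", "group_size": 32, "scheme": "asym"}

# If the preset name contains "DISABLE_LAST_MATMUL", the last MatMul node found in
# model.graph.node is additionally forced back to fp32 for both weight and activation.
```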