1616import json
1717import os
1818import re
19+ from collections import OrderedDict
1920from typing import Dict , List , Union
2021
2122import torch
def cfg_to_qconfig(tune_cfg, cfgs, op_infos_from_cfgs, output_tensor_id_op_name):  # pragma: no cover
    """Reconcile the tuning config with the IPEX-collected configs and persist the result.

    Aligns ``tune_cfg["op"]`` against the configs IPEX gathered during calibration
    (via ``check_cfg_and_qconfig``), writes the updated configs to the module-level
    ``ipex_config_path`` JSON file, and hands back the remapped per-op user config.

    Args:
        tune_cfg (dict): quantization tuning configuration; only the "op" entry is used.
        cfgs (dict): op configs collected from the IPEX calibration run.
        op_infos_from_cfgs (dict): per-op info extracted from ``cfgs``; deep-copied so
            the caller's structure is never mutated.
        output_tensor_id_op_name (dict): map from output tensor id to op name.

    Returns:
        dict: user config keyed by PyTorch op names (as produced by ``check_cfg_and_qconfig``).
    """
    assert cfgs is not None, "No configure for IPEX int8 model..."
    # Work on a copy: check_cfg_and_qconfig may rewrite op info in place.
    infos_copy = copy.deepcopy(op_infos_from_cfgs)
    cfgs, user_cfg = check_cfg_and_qconfig(tune_cfg["op"], cfgs, infos_copy, output_tensor_id_op_name)
    # Persist the reconciled configs where IPEX expects to load them.
    # NOTE(review): ipex_config_path is a module-level path — confirm it is set before this runs.
    with open(ipex_config_path, "w") as fp:
        fp.write(json.dumps(cfgs, indent=4))
    return user_cfg
7274
7375
7476def check_cfg_and_qconfig (user_cfg , cfgs , op_infos_from_cfgs , output_tensor_ids_op_name ): # pragma: no cover
@@ -83,6 +85,15 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
8385 Returns:
8486 cfgs (dict): updated configs.
8587 """
88+ tmp_user_cfg = OrderedDict ()
89+ for op in user_cfg : # map ipex op_name to pt op_name
90+ for i , op_name in enumerate (op ):
91+ for ops , _ in op_infos_from_cfgs .items ():
92+ if "fqn" in op_infos_from_cfgs [ops ].keys () and op_infos_from_cfgs [ops ]["fqn" ] == op_name :
93+ ori_op = (tuple (ops ), unify_op_type_mapping_ipex [op_infos_from_cfgs [ops ]["op_type" ]])
94+ tmp_user_cfg [((ori_op [0 ],), ori_op [1 ])] = user_cfg [op ]
95+ break
96+ user_cfg = tmp_user_cfg
8697 for op_name in user_cfg :
8798 inc_op_cfg = user_cfg [op_name ]
8899 for i , name in enumerate (op_name [0 ]):
@@ -142,7 +153,7 @@ def check_cfg_and_qconfig(user_cfg, cfgs, op_infos_from_cfgs, output_tensor_ids_
142153 else :
143154 pass
144155 cfgs [name [0 ]][name [1 ]][name [2 ]] = ipex_op_cfg
145- return cfgs
156+ return cfgs , user_cfg
146157
147158
148159def generate_activation_observer (scheme , algorithm , smooth_quant = False , smooth_quant_enable = False ): # pragma: no cover
@@ -212,6 +223,7 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
212223 cfgs (dict): dict of configuration
213224 """
214225 quantizable_ops = []
226+ op_name_info = []
215227 # group ops by position for transform-based model
216228 detector = TransformerBasedModelBlockPatternDetector (model )
217229 detect_result = detector .detect_block ()
@@ -277,17 +289,30 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
277289 if ipex_op_type in unify_op_type_mapping_ipex :
278290 quantizable_ops .append ((tuple (name ), unify_op_type_mapping_ipex [ipex_op_type ]))
279291 map_op_name_to_fqn [(tuple (name ), ipex_op_type )] = module_fqn
292+ if "class" in ipex_op_type : # "<class 'torch.nn.modules.activation.ReLU'>"
293+ op_type = ipex_op_type .split ("'" )[1 ]
294+ op_name_info .append ((module_fqn , eval (op_type )))
295+ elif "method" in ipex_op_type : # "<method 'add' of 'torch._C._TensorBase' objects>"
296+ method = ipex_op_type .split ("'" )[1 ]
297+ op_type = getattr (
298+ torch ._C ._TensorBase if ipex_ver .release < Version ("2.2" ) else torch ._C .TensorBase , method
299+ )
300+ op_name_info .append ((module_fqn , op_type ))
301+ else :
302+ op_name_info .append ((module_fqn , op_type ))
280303 else :
281304 re_flag = False
282305 for pattern , unify_op_type in unify_op_type_mapping_ipex ["re" ].items ():
283306 if re .match (pattern , ipex_op_type ):
284307 re_flag = True
285308 quantizable_ops .append ((tuple (name ), unify_op_type ))
286309 map_op_name_to_fqn [(tuple (name ), unify_op_type )] = module_fqn
310+ op_name_info .append ((module_fqn , ipex_op_type ))
287311 break
288312 if not re_flag :
289313 quantizable_ops .append ((tuple (name ), ipex_op_type ))
290314 map_op_name_to_fqn [(tuple (name ), ipex_op_type )] = module_fqn
315+ op_name_info .append ((module_fqn , ipex_op_type ))
291316 else :
292317 op_type = ""
293318 for op_name in name :
@@ -302,14 +327,15 @@ def get_quantizable_ops_recursively(model, example_inputs): # pragma: no cover
302327 _op_cfg_id = name [0 ][2 ]
303328 module_fqn = cfgs [_module_key ]["q_op_infos" ][_op_cfg_id ]["fqn" ]
304329 map_op_name_to_fqn [(tuple (name ), op_type )] = module_fqn
330+ op_name_info .append ((module_fqn , op_type ))
305331
306332 logger .debug ("Map op name to fqn: " )
307333 logger .debug (map_op_name_to_fqn )
308334 logger .info ("Attention Blocks : " )
309335 logger .info (attention_block )
310336 logger .info ("FFN Blocks : " )
311337 logger .info (ffn_blocks )
312- return quantizable_ops , cfgs , op_infos_from_cfgs , output_tensor_id_op_name
338+ return quantizable_ops , cfgs , op_infos_from_cfgs , output_tensor_id_op_name , op_name_info
313339
314340
315341def simple_inference (q_model , example_inputs , iterations = 1 ):
@@ -323,16 +349,16 @@ def simple_inference(q_model, example_inputs, iterations=1):
323349 q_model (example_inputs )
324350
325351
326- def dump_model_op_stats (tune_cfg ):
352+ def dump_model_op_stats (user_cfg ):
327353 """This is a function to dump quantizable ops of model to user.
328354
329355 Args:
330- tune_cfg (dict): quantization config
356+ user_cfg (dict): quantization config
331357 Returns:
332358 None
333359 """
334360 res = dict ()
335- for k , v in tune_cfg [ "op" ] .items ():
361+ for k , v in user_cfg .items ():
336362 op_type_list = k [- 1 ].split ("><" )
337363 op_type = ""
338364 for op in op_type_list :
0 commit comments