@@ -60,7 +60,10 @@ def forward_wrapper(model, input, device=torch.device("cpu")):
6060 if isinstance (input , dict ) or isinstance (input , UserDict ):
6161 output = model (** input )
6262 elif isinstance (input , list ) or isinstance (input , tuple ):
63- output = model (* input )
63+ try :
64+ output = model (* input )
65+ except :
66+ output = model (input )
6467 else :
6568 output = model (input )
6669 return output
@@ -295,8 +298,6 @@ def __init__(self, model, dataloader, example_inputs=None, q_func=None, traced_m
295298 self .dataloader = dataloader
296299 self .example_inputs = example_inputs
297300 self .q_func = q_func
298- self .input_values = {}
299- self .output_values = {}
300301 self .input_maxes = {}
301302 self .input_mins = {}
302303 self .input_maxes_abs = {}
@@ -325,10 +326,6 @@ def _save_input_pc_hook(self, name, percentile=100):
325326 :return: A hook function."""
326327
327328 def save_input_hook (module , inputs , outputs ):
328- if name not in self .input_maxes .keys ():
329- self .input_maxes [name ] = []
330- self .input_mins [name ] = []
331- self .input_maxes_abs [name ] = []
332329 input = inputs [0 ]
333330 ##TODO check input channel is correct
334331 if len (module .weight .shape ) == 4 : ##conv3d or conv1d not supported now, need better way
@@ -339,43 +336,16 @@ def save_input_hook(module, inputs, outputs):
339336 k_index = int (input .shape [0 ] * percentile / 100 )
340337 res , _ = torch .kthvalue (torch .abs (input ), k_index , dim = 0 )
341338 ##res = torch.max(torch.abs(input),dim=0)[0]
342- self .input_maxes_abs [name ].append (res )
343- self .input_maxes [name ].append (max_tensor )
344- self .input_mins [name ].append (min_tensor )
339+ if name not in self .input_maxes .keys ():
340+ self .input_mins [name ], self .input_maxes [name ] = min_tensor , max_tensor
341+ self .input_maxes_abs [name ] = res
342+ else :
343+ self .input_mins [name ] = torch .min (self .input_mins [name ], min_tensor )
344+ self .input_maxes [name ] = torch .max (self .input_maxes [name ], max_tensor )
345+ self .input_maxes_abs [name ] = torch .max (self .input_maxes_abs [name ], res )
345346
346347 return save_input_hook
347348
348- def _save_input_output_hook (self , name ):
349- """
350- A forward hook to save input and output values of a module
351- param name: the module name
352- return: A hook function
353- """
354-
355- def save_input_output_hook (module , inputs , outputs ):
356- input = inputs [0 ]
357- cnt = 32
358- if name in self .input_values .keys () and len (self .input_values [name ]) < cnt :
359- self .input_values [name ].append (input )
360- self .output_values [name ].append (outputs )
361- if name not in self .input_values .keys ():
362- self .input_values [name ] = [input ] ##TODO save more,like 8
363- self .output_values [name ] = [outputs ] ##TODO do not save output
364-
365- return save_input_output_hook
366-
367- def _add_input_output_observer (self ):
368- input_output_modules = {}
369- hook_layer_names = []
370- for key in self .absorb_to_layer :
371- hook_layer_names += self .absorb_to_layer [key ]
372- for name in hook_layer_names :
373- input_output_modules [name ] = get_module (self .model , name )
374- for key in input_output_modules .keys ():
375- hook_func = self ._save_input_output_hook (key )
376- hook_handle = input_output_modules [key ].register_forward_hook (hook_func )
377- self .hook_handles .append (hook_handle )
378-
379349 def _add_min_max_observer (self , modules , percentile = 100 ):
380350 """
381351 :param modules: the modules which the observer will insert to
@@ -393,7 +363,7 @@ def _remove_observer(self):
393363 for hook_handle in self .hook_handles :
394364 hook_handle .remove ()
395365
396- def _calibrate (self , absorb_to_layer , calib_iter , percentile , save_input_output = False ):
366+ def _calibrate (self , absorb_to_layer , calib_iter , percentile ):
397367 """
398368 :param absorb_to_layer: A dict,key is the absorb layer, val is a list of the to be smoothed layer
399369 :param calib_iter: Data size for calibration
@@ -406,8 +376,6 @@ def _calibrate(self, absorb_to_layer, calib_iter, percentile, save_input_output=
406376 hook_modules [n ] = module
407377
408378 self ._add_min_max_observer (hook_modules , percentile )
409- if save_input_output :
410- self ._add_input_output_observer ()
411379
412380 self ._dump_min_max (calib_iter = calib_iter )
413381 self ._remove_observer ()
@@ -423,15 +391,6 @@ def _dump_min_max(self, calib_iter=100):
423391 else :
424392 assert self .dataloader , "Please set dataloader for calibration."
425393 model_forward (self .model , self .dataloader , calib_iter , self .device )
426- ##stack
427- for key in self .input_maxes .keys ():
428- max_val = self .input_maxes [key ]
429- max_val = torch .stack (max_val , dim = 0 )
430- min_val = self .input_mins [key ]
431- min_val = torch .stack (min_val , dim = 0 )
432- self .input_maxes [key ] = torch .max (max_val , dim = 0 )[0 ]
433- self .input_mins [key ] = torch .min (min_val , dim = 0 )[0 ]
434- self .input_maxes_abs [key ] = torch .max (torch .stack (self .input_maxes_abs [key ], dim = 0 ), dim = 0 )[0 ]
435394
436395 def _reshape_in_channel_to_last (self , layer_name ):
437396 """Move the input channel to the last dim
@@ -877,43 +836,76 @@ def _auto_tune_alpha_new(
877836 if not self .dataloader :
878837 self ._qdq_model_unwrapper_for_auto ()
879838 return best_alphas
880-
881- for idx , input in enumerate (self .dataloader ):
882- if isinstance (input , (tuple , list )):
883- input = input [0 ]
884- best_alphas_per_module = best_alphas
885- if isinstance (best_alphas , dict ):
886- for key in self .absorb_to_layer .keys ():
887- layer_names = self .absorb_to_layer [key ]
888- for layer_name in layer_names :
889- best_alphas_per_module [layer_name ] = best_alphas_per_module [key ]
890-
891- loss_tmp = self ._get_one_sample_auto_loss (input , alpha_space , best_alphas_per_module , input_maxes )
892- if loss_alphas == {}:
893- loss_alphas = loss_tmp
894- else :
895- for key in loss_alphas .keys ():
896- cur_loss = loss_alphas [key ]
897- for alpha_key in cur_loss .keys ():
898- cur_loss [alpha_key ] += loss_tmp [key ][alpha_key ]
899- if isinstance (input , list ):
900- input = move_input_to_device (input , self .device )
901- for inp in input :
902- cnt += inp .shape [0 ]
903- else :
904- cnt += input .shape [0 ]
905-
906- if cnt % multiply_factor == 0 and (auto_calib_iter - cnt ) >= multiply_factor :
907- best_alphas = self ._get_best_alpha (self .absorb_to_layer , loss_alphas , shared_criterion )
908- for key in best_alphas .keys ():
909- logger .info (f"{ cnt // multiply_factor } ,{ key } :{ best_alphas [key ]} " )
910- absorb_input_scales , weight_scales = self ._cal_scales (
911- self .absorb_to_layer , input_maxes , best_alphas , tuning = True
912- )
913- self ._update_scales_for_auto (absorb_input_scales , weight_scales )
914- loss_alphas = {} ##TODO check need to remove this one
915- if cnt >= auto_calib_iter :
916- break
839+ try :
840+ for input , label in self .dataloader :
841+ best_alphas_per_module = best_alphas
842+ if isinstance (best_alphas , dict ):
843+ for key in self .absorb_to_layer .keys ():
844+ layer_names = self .absorb_to_layer [key ]
845+ for layer_name in layer_names :
846+ best_alphas_per_module [layer_name ] = best_alphas_per_module [key ]
847+
848+ loss_tmp = self ._get_one_sample_auto_loss (input , alpha_space , best_alphas_per_module , input_maxes )
849+ if loss_alphas == {}:
850+ loss_alphas = loss_tmp
851+ else :
852+ for key in loss_alphas .keys ():
853+ cur_loss = loss_alphas [key ]
854+ for alpha_key in cur_loss .keys ():
855+ cur_loss [alpha_key ] += loss_tmp [key ][alpha_key ]
856+ if isinstance (input , list ):
857+ input = move_input_to_device (input , self .device )
858+ for inp in input :
859+ cnt += inp .shape [0 ]
860+ else :
861+ cnt += input .shape [0 ]
862+
863+ if cnt % multiply_factor == 0 and (auto_calib_iter - cnt ) >= multiply_factor :
864+ best_alphas = self ._get_best_alpha (self .absorb_to_layer , loss_alphas , shared_criterion )
865+ for key in best_alphas .keys ():
866+ logger .info (f"{ cnt // multiply_factor } ,{ key } :{ best_alphas [key ]} " )
867+ absorb_input_scales , weight_scales = self ._cal_scales (
868+ self .absorb_to_layer , input_maxes , best_alphas , tuning = True
869+ )
870+ self ._update_scales_for_auto (absorb_input_scales , weight_scales )
871+ loss_alphas = {} ##TODO check need to remove this one
872+ if cnt >= auto_calib_iter :
873+ break
874+ except :
875+ for input in self .dataloader :
876+ best_alphas_per_module = best_alphas
877+ if isinstance (best_alphas , dict ):
878+ for key in self .absorb_to_layer .keys ():
879+ layer_names = self .absorb_to_layer [key ]
880+ for layer_name in layer_names :
881+ best_alphas_per_module [layer_name ] = best_alphas_per_module [key ]
882+
883+ loss_tmp = self ._get_one_sample_auto_loss (input , alpha_space , best_alphas_per_module , input_maxes )
884+ if loss_alphas == {}:
885+ loss_alphas = loss_tmp
886+ else :
887+ for key in loss_alphas .keys ():
888+ cur_loss = loss_alphas [key ]
889+ for alpha_key in cur_loss .keys ():
890+ cur_loss [alpha_key ] += loss_tmp [key ][alpha_key ]
891+ if isinstance (input , list ):
892+ input = move_input_to_device (input , self .device )
893+ for inp in input :
894+ cnt += inp .shape [0 ]
895+ else :
896+ cnt += input .shape [0 ]
897+
898+ if cnt % multiply_factor == 0 and (auto_calib_iter - cnt ) >= multiply_factor :
899+ best_alphas = self ._get_best_alpha (self .absorb_to_layer , loss_alphas , shared_criterion )
900+ for key in best_alphas .keys ():
901+ logger .info (f"{ cnt // multiply_factor } ,{ key } :{ best_alphas [key ]} " )
902+ absorb_input_scales , weight_scales = self ._cal_scales (
903+ self .absorb_to_layer , input_maxes , best_alphas , tuning = True
904+ )
905+ self ._update_scales_for_auto (absorb_input_scales , weight_scales )
906+ loss_alphas = {} ##TODO check need to remove this one
907+ if cnt >= auto_calib_iter :
908+ break
917909
918910 best_alphas = self ._get_best_alpha (self .absorb_to_layer , loss_alphas , shared_criterion )
919911 for key in best_alphas .keys ():
@@ -995,11 +987,8 @@ def transform(
995987 "you could set torchscript to True "
996988 )
997989 return self .model
998- save_input_output = False if alpha == "auto" else True
999- # if alpha == "auto":
1000- # save_input_output = True
1001990
1002- input_maxes_abs = self ._calibrate (self .absorb_to_layer , calib_iter , percentile , save_input_output )
991+ input_maxes_abs = self ._calibrate (self .absorb_to_layer , calib_iter , percentile )
1003992
1004993 # Check if input_maxes match self.absorb_to_layer
1005994 # (due to self._get_all_layer_names use layer tree instead of forward_path)
@@ -1042,7 +1031,6 @@ def transform(
10421031 else :
10431032 logger .warning (" Could not get example input, equivelancy check is skipped" )
10441033
1045- self .input_values , self .output_values = {}, {}
10461034 return self .model
10471035
10481036 def output_is_equal (self , out1 , out2 , atol = 1e-04 ):
0 commit comments