@@ -43,8 +43,9 @@ def distributed_next_tune_cfg_lst(self, comm):
4343 """Generate and yield the next tuning config list with below order.
4444
4545 1. OP Type Wise Tuning
46- 2. Fallback OP One by One
47- 3. Fallback Multiple OPs Accumulated
46+ 2. Fallback OPs Block by Block
47+ 3. Fallback OP One by One
48+ 4. Fallback Multiple OPs Accumulated
4849
4950 Yields:
5051 tuning_config_list (list): A list containing dicts of the tuning configuration for quantization.
@@ -62,6 +63,18 @@ def distributed_next_tune_cfg_lst(self, comm):
6263 quant_ops = quant_mode_wise_items ['static' ] if 'static' in quant_mode_wise_items else []
6364 quant_ops += quant_mode_wise_items ['dynamic' ] if 'dynamic' in quant_mode_wise_items else []
6465 stage1_max = 1e9 # TODO set a more appropriate value
66+ if not self .cur_best_tuning_cfg :
67+ self .cur_best_tuning_cfg = deepcopy (initial_op_tuning_cfg )
68+
69+ # try to tune sq alpha
70+ op_tuning_cfg_lst_stage_sq = []
71+ if self ._should_tuning_sq_alpha (self .config .recipes ):
72+ for tune_cfg in self .tuning_sq_alpha (tuning_space , \
73+ deepcopy (self .cur_best_tuning_cfg ), self .config .recipes ):
74+ op_tuning_cfg_lst_stage_sq .append (tune_cfg )
75+ yield op_tuning_cfg_lst_stage_sq
76+
77+ # op type-wise tuning
6578 op_type_wise_tuning_sampler = OpTypeWiseTuningSampler (tuning_space , [], [],
6679 op_item_dtype_dict , initial_op_tuning_cfg )
6780 # stage 1: yield op_tune_cfg_lst
@@ -83,6 +96,7 @@ def distributed_next_tune_cfg_lst(self, comm):
8396 else :
8497 self .cur_best_tuning_cfg = comm .bcast (cur_best_tuning_cfg , root = 0 )
8598
99+
86100 # stage 2: yield new_op_tuning_cfg_lst (length of stage 1)
87101 # Fallback the ops supported both static and dynamic from static to dynamic
88102 # Tuning items: None
@@ -113,12 +127,25 @@ def distributed_next_tune_cfg_lst(self, comm):
113127 best_op_tuning_cfg_stage1 = deepcopy (self .cur_best_tuning_cfg )
114128
115129 # Fallback
130+ # Fallback block after stage (1, 2) and before stage (3, 4)
116131 # stage 3, 4: yield op_tuning_cfg_lst
132+ op_tuning_cfg_lst_stage_block = []
117133 op_tuning_cfg_lst_stage_3 = []
118134 op_tuning_cfg_lst_stage_4 = []
119- for target_dtype in [ 'bf16' , 'fp32' ] :
135+ for target_dtype in PRECISION_LIST :
120136 target_type_lst = set (tuning_space .query_items_by_quant_mode (target_dtype ))
121137 fallback_items_lst = [item for item in quant_ops if item in target_type_lst ]
138+
139+ # Fallback block by block
140+ for op_tuning_cfg in self .fallback_by_block (fallback_items_lst , best_op_tuning_cfg_stage1 ,
141+ target_dtype ,
142+ tuning_space ,
143+ calib_sampling_size ):
144+ op_tuning_cfg_lst_stage_block .append (deepcopy (op_tuning_cfg ))
145+ logger .info ("yield op_tuning_cfg_lst_stage_block with length {}" \
146+ .format (len (op_tuning_cfg_lst_stage_block )))
147+ yield op_tuning_cfg_lst_stage_block
148+
122149 if fallback_items_lst :
123150 logger .info (f"Start to fallback op to { target_dtype } one by one." )
124151 self ._fallback_started ()
@@ -273,8 +300,6 @@ def next_tune_cfg(self):
273300 op_item_dtype_dict , initial_op_tuning_cfg )
274301
275302 for index , op_tuning_cfg in enumerate (op_type_wise_tuning_sampler ):
276- if not self .cur_best_tuning_cfg :
277- self .cur_best_tuning_cfg = deepcopy (initial_op_tuning_cfg )
278303 op_tuning_cfg ['calib_sampling_size' ] = calib_sampling_size
279304 # try to quantizing ops into lower bits, such as int4,
280305 # if accuracy meets the requirements after first trial and max_trials > 1
0 commit comments