
Commit ea309f5

Add distributed fallback by blockwise (#1179)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 04884ed commit ea309f5

File tree

  neural_compressor/strategy/basic.py
  neural_compressor/strategy/strategy.py

2 files changed: +35 -6 lines


neural_compressor/strategy/basic.py

Lines changed: 30 additions & 5 deletions
@@ -43,8 +43,9 @@ def distributed_next_tune_cfg_lst(self, comm):
         """Generate and yield the next tuning config list with below order.

         1. OP Type Wise Tuning
-        2. Fallback OP One by One
-        3. Fallback Multiple OPs Accumulated
+        2. Fallback OPs Block by Block
+        3. Fallback OP One by One
+        4. Fallback Multiple OPs Accumulated

         Yields:
             tuning_config_list (list): A list containing dicts of the tuning configuration for quantization.
@@ -62,6 +63,18 @@ def distributed_next_tune_cfg_lst(self, comm):
         quant_ops = quant_mode_wise_items['static'] if 'static' in quant_mode_wise_items else []
         quant_ops += quant_mode_wise_items['dynamic'] if 'dynamic' in quant_mode_wise_items else []
         stage1_max = 1e9  # TODO set a more appropriate value
+        if not self.cur_best_tuning_cfg:
+            self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg)
+
+        # try to tune sq alpha
+        op_tuning_cfg_lst_stage_sq = []
+        if self._should_tuning_sq_alpha(self.config.recipes):
+            for tune_cfg in self.tuning_sq_alpha(tuning_space, \
+                                                 deepcopy(self.cur_best_tuning_cfg), self.config.recipes):
+                op_tuning_cfg_lst_stage_sq.append(tune_cfg)
+            yield op_tuning_cfg_lst_stage_sq
+
+        # op type-wise tuning
         op_type_wise_tuning_sampler = OpTypeWiseTuningSampler(tuning_space, [], [],
                                                               op_item_dtype_dict, initial_op_tuning_cfg)
         # stage 1: yield op_tune_cfg_lst
@@ -83,6 +96,7 @@ def distributed_next_tune_cfg_lst(self, comm):
             else:
                 self.cur_best_tuning_cfg = comm.bcast(cur_best_tuning_cfg, root=0)

+
         # stage 2: yield new_op_tuning_cfg_lst (length of stage 1)
         # Fallback the ops supported both static and dynamic from static to dynamic
         # Tuning items: None
@@ -113,12 +127,25 @@ def distributed_next_tune_cfg_lst(self, comm):
         best_op_tuning_cfg_stage1 = deepcopy(self.cur_best_tuning_cfg)

         # Fallback
+        # Fallback block after stage (1, 2) and before stage (3, 4)
         # stage 3, 4: yield op_tuning_cfg_lst
+        op_tuning_cfg_lst_stage_block = []
         op_tuning_cfg_lst_stage_3 = []
         op_tuning_cfg_lst_stage_4 = []
-        for target_dtype in ['bf16', 'fp32']:
+        for target_dtype in PRECISION_LIST:
             target_type_lst = set(tuning_space.query_items_by_quant_mode(target_dtype))
             fallback_items_lst = [item for item in quant_ops if item in target_type_lst]
+
+            # Fallback block by block
+            for op_tuning_cfg in self.fallback_by_block(fallback_items_lst, best_op_tuning_cfg_stage1,
+                                                        target_dtype,
+                                                        tuning_space,
+                                                        calib_sampling_size):
+                op_tuning_cfg_lst_stage_block.append(deepcopy(op_tuning_cfg))
+            logger.info("yield op_tuning_cfg_lst_stage_block with length {}"\
+                        .format(len(op_tuning_cfg_lst_stage_block)))
+            yield op_tuning_cfg_lst_stage_block
+
             if fallback_items_lst:
                 logger.info(f"Start to fallback op to {target_dtype} one by one.")
                 self._fallback_started()
@@ -273,8 +300,6 @@ def next_tune_cfg(self):
                                                               op_item_dtype_dict, initial_op_tuning_cfg)

         for index, op_tuning_cfg in enumerate(op_type_wise_tuning_sampler):
-            if not self.cur_best_tuning_cfg:
-                self.cur_best_tuning_cfg = deepcopy(initial_op_tuning_cfg)
             op_tuning_cfg['calib_sampling_size'] = calib_sampling_size
             # try to quantizing ops into lower bits, such as int4,
             # if accuracy meets the requirements after first trial and max_trials > 1
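
For readers unfamiliar with the staged pattern this diff extends, the sketch below is a minimal, self-contained illustration of yielding one list of candidate configs per stage, with a coarse block-by-block fallback stage placed ahead of the op-by-op stage, mirroring the new stage order in the docstring. It is not Neural Compressor code: the helper names and the flat dict-based config are hypothetical, and the real `fallback_by_block` works on tuning-space items rather than plain dicts.

from copy import deepcopy

def block_fallback_cfgs(base_cfg, blocks, target_dtype):
    # One candidate config per block: every op in the block falls back together.
    for block in blocks:
        cfg = deepcopy(base_cfg)
        for op_name in block:
            cfg[op_name] = target_dtype
        yield cfg

def one_by_one_fallback_cfgs(base_cfg, ops, target_dtype):
    # One candidate config per op: a single op falls back at a time.
    for op_name in ops:
        cfg = deepcopy(base_cfg)
        cfg[op_name] = target_dtype
        yield cfg

def next_tune_cfg_lst(base_cfg, blocks, target_dtype="fp32"):
    # Yield stage-sized lists, matching the new order: block-by-block first,
    # then op-by-op, so coarse fallbacks are evaluated before fine-grained ones.
    yield list(block_fallback_cfgs(base_cfg, blocks, target_dtype))
    yield list(one_by_one_fallback_cfgs(base_cfg, list(base_cfg), target_dtype))

if __name__ == "__main__":
    base = {"conv1": "int8", "conv2": "int8", "fc": "int8"}
    blocks = [["conv1", "conv2"], ["fc"]]
    for stage, cfg_lst in enumerate(next_tune_cfg_lst(base, blocks), start=1):
        print(f"stage {stage}: {len(cfg_lst)} candidate configs")
        for cfg in cfg_lst:
            print("  ", cfg)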

neural_compressor/strategy/strategy.py

Lines changed: 5 additions & 1 deletion
@@ -446,13 +446,17 @@ def traverse(self):
             from mpi4py import MPI
             if MPI.COMM_WORLD.Get_size() > 2:
                 logger.info("Use distributed tuning on {} nodes".format(MPI.COMM_WORLD.Get_size()))
-                return self.distributed_traverse()
             elif MPI.COMM_WORLD.Get_size() == 2:
                 logger.info("Use distributed tuning on {} nodes, will be fallback to normal tuning."\
                             .format(MPI.COMM_WORLD.Get_size()))
+            MPI_INSTALLED=True
         except (ImportError, AttributeError) as e:
             logger.warning("[Strategy] Please install `mpi4py` correctly if using distributed tuning;" + \
                            " otherwise, ignore this warning.")
+            MPI_INSTALLED=False
+        if MPI_INSTALLED:
+            if MPI.COMM_WORLD.Get_size() > 2:
+                return self.distributed_traverse()
         self._setup_pre_tuning_algo_scheduler()
         self._prepare_tuning()
         # import pdb;pdb.set_trace()
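
The strategy.py change moves the distributed dispatch out of the try block, so it only runs when mpi4py imported successfully; a 2-rank run (or a missing mpi4py) falls through to normal tuning. Below is a rough standalone sketch of that control flow, assuming mpi4py's standard `MPI.COMM_WORLD.Get_size()` API; `distributed_traverse` and `normal_traverse` are hypothetical stand-ins for the real methods.

def distributed_traverse():
    # Hypothetical stand-in for the distributed tuning loop.
    print("running distributed tuning")

def normal_traverse():
    # Hypothetical stand-in for the normal single-node tuning loop.
    print("running normal tuning")

def traverse():
    mpi_installed = False
    world_size = 1
    try:
        from mpi4py import MPI  # optional dependency; may be absent
        world_size = MPI.COMM_WORLD.Get_size()
        mpi_installed = True
    except (ImportError, AttributeError):
        print("mpi4py not installed; ignore this if not using distributed tuning")

    # Dispatch only after the probe: 3+ ranks go distributed,
    # 2 ranks (or no MPI) fall back to normal tuning.
    if mpi_installed and world_size > 2:
        return distributed_traverse()
    return normal_traverse()

if __name__ == "__main__":
    traverse()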
