@@ -111,7 +111,7 @@ class BaseCompressor(object):
         sym (bool): Whether to use symmetric weight quantization.
         layer_config (dict): Per-layer quantization configuration.
         nsamples (int): Number of calibration samples.
-        enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
+        enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers.
     """

     bits: int | None
@@ -361,6 +361,7 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
+        self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

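The hunk above wires `compile_func` in once at construction time, so `self.block_forward` is either the compiled or the plain helper everywhere it is called. As a rough illustration only, a conditional compile helper with that call shape could look like the sketch below; the name `compile_func_sketch`, the fallback-on-failure behavior, and the unused `device` parameter are assumptions for illustration, not auto_round's actual implementation.

```python
# Hedged sketch (not auto_round's compile_func): conditionally wrap a callable
# with torch.compile, keeping the same call shape as compile_func(func, device).
import torch


def compile_func_sketch(func, device):
    """Try to torch.compile `func`; return it unchanged if compilation is unavailable."""
    try:
        # torch.compile exists in PyTorch >= 2.0; `device` is accepted only to
        # mirror the call site above and is not used in this sketch.
        return torch.compile(func)
    except Exception:
        return func


# Usage mirroring the diff: compile the block forward helper once in __init__.
# self.block_forward = compile_func_sketch(block_forward, self.device) if enable_torch_compile else block_forward
```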
@@ -1428,6 +1429,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             enable_minmax_tuning=False,
             enable_norm_bias_tuning=False,
             enable_round_tuning=False,
+            enable_torch_compile=self.enable_torch_compile,
         )
         m = m.unwrapper({})
         m.to("cpu")
@@ -1443,6 +1445,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
         except Exception as e:
@@ -1882,6 +1885,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 enable_round_tuning=False,
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
                 device=self.device,
             )
             new_layer = wrapper_layer.unwrapper({})
@@ -1911,10 +1915,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        if self.enable_torch_compile:
-            quant_layer = compile_func(self._quantize_layer, self.device)
-        else:
-            quant_layer = self._quantize_layer
+        quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
@@ -2093,9 +2094,9 @@ def _get_block_outputs(
             tmp_input_ids, tmp_input_others = self._sampling_inputs(
                 input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
             )
-            tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-                cache_device
-            )
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+            ).to(cache_device)
             if save_output:
                 if self.batch_size == 1:
                     output.append(tmp_output)
@@ -2518,7 +2519,12 @@ def _quantize_layer(
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

-        wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(device)
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=device,
+        ).to(device)
         round_params = []
         minmax_params = []
         for key in wrapper_linear.params.keys():
@@ -2696,7 +2702,7 @@ def _get_current_q_output(
             batch_dim=self.batch_dim,
             share_cache_keys=self.shared_cache_keys,
         )
-        output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
+        output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q

     def _get_current_num_elm(
@@ -2781,7 +2787,11 @@ def _quantize_block(
             input_ids = q_input

         quantized_layer_names, unquantized_layer_names = wrapper_block(
-            block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+            block,
+            self.enable_minmax_tuning,
+            self.enable_norm_bias_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=self.device,
         )
         if is_nv_fp(self.data_type):  # enable qkv and moe structure global_scale fuse
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -3008,14 +3018,8 @@ def _quantize_blocks(
                 logger.info("using algorithm extension for quantization.")
             except (ImportError, ModuleNotFoundError):
                 quantize_block = self._quantize_block
-                if self.enable_torch_compile:
-                    quantize_block = compile_func(quantize_block, device)
-                else:
-                    quantize_block = quantize_block
         else:
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)

         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))