@@ -111,7 +111,7 @@ class BaseCompressor(object):
         sym (bool): Whether to use symmetric weight quantization.
         layer_config (dict): Per-layer quantization configuration.
         nsamples (int): Number of calibration samples.
-        enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
+        enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers.
     """
 
     bits: int | None
@@ -173,7 +173,7 @@ def __init__(
             act_sym (bool, optional): Symmetric activation quantization. Defaults to None.
             act_data_type (str, optional): Activation data type; inherits weight dtype if None and act_bits < 16.
             act_dynamic (bool, optional): Dynamic activation quantization. Defaults to True.
-            enable_torch_compile (bool, optional): Enable torch.compile for quant blocks/layers. Defaults to False.
+            enable_torch_compile (bool, optional): Enable compile_func for quant blocks/layers. Defaults to False.
             device_map (str | dict, optional): Device placement map. Defaults to None.
             disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to False.
             enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False.
@@ -361,6 +361,9 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
+        self.block_forward = block_forward
+        if self.enable_torch_compile:
+            self.block_forward = compile_func(self.block_forward, self.device)
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)
 
@@ -1428,6 +1431,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
             m.to("cpu")
@@ -1443,6 +1447,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
         except Exception as e:
@@ -1882,6 +1887,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 enable_round_tuning=False,
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
                 device=self.device,
             )
             new_layer = wrapper_layer.unwrapper({})
@@ -1911,10 +1917,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
 
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        if self.enable_torch_compile:
-            quant_layer = compile_func(self._quantize_layer, self.device)
-        else:
-            quant_layer = self._quantize_layer
+        quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
@@ -2093,9 +2096,9 @@ def _get_block_outputs(
             tmp_input_ids, tmp_input_others = self._sampling_inputs(
                 input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
             )
-            tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-                cache_device
-            )
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+            ).to(cache_device)
             if save_output:
                 if self.batch_size == 1:
                     output.append(tmp_output)
@@ -2518,7 +2521,12 @@ def _quantize_layer(
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)
 
-        wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(device)
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=device,
+        ).to(device)
         round_params = []
         minmax_params = []
         for key in wrapper_linear.params.keys():
@@ -2696,7 +2704,7 @@ def _get_current_q_output(
             batch_dim=self.batch_dim,
             share_cache_keys=self.shared_cache_keys,
         )
-        output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
+        output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q
 
     def _get_current_num_elm(
27022710 def _get_current_num_elm (
@@ -2781,7 +2789,11 @@ def _quantize_block(
             input_ids = q_input
 
         quantized_layer_names, unquantized_layer_names = wrapper_block(
-            block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+            block,
+            self.enable_minmax_tuning,
+            self.enable_norm_bias_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=self.device,
         )
         if is_nv_fp(self.data_type):  # enable qkv and moe structure global_scale fuse
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -3008,14 +3020,8 @@ def _quantize_blocks(
                 logger.info("using algorithm extension for quantization.")
             except (ImportError, ModuleNotFoundError):
                 quantize_block = self._quantize_block
-                if self.enable_torch_compile:
-                    quantize_block = compile_func(quantize_block, device)
-                else:
-                    quantize_block = quantize_block
         else:
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)
 
         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
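For context, a rough standalone sketch of the pattern this diff adopts: rather than wrapping the whole _quantize_block / _quantize_layer routines in compile_func, the shared block_forward helper is compiled once in __init__ and stored on the instance, so every caller (_get_block_outputs, _get_current_q_output, ...) transparently uses the compiled variant. The Compressor class, the simplified helpers, and the torch.compile fallback below are illustrative assumptions, not the library's actual implementation.

import torch


def compile_func(fn, device):
    # Hypothetical stand-in for auto_round's compile_func: use torch.compile
    # where it is expected to help, otherwise return the function unchanged.
    return torch.compile(fn) if str(device) != "cpu" else fn


def block_forward(block, input_ids, input_others, amp, amp_dtype, device):
    # Simplified forward helper; the real one also handles batching and caching.
    with torch.autocast(device_type=str(device).split(":")[0], dtype=amp_dtype, enabled=amp):
        return block(input_ids.to(device), **input_others)


class Compressor:
    def __init__(self, enable_torch_compile=False, device="cuda"):
        self.device = device
        self.enable_torch_compile = enable_torch_compile
        # Compile the helper once; callers invoke self.block_forward and pick up
        # the compiled version without any per-call-site branching.
        self.block_forward = block_forward
        if self.enable_torch_compile:
            self.block_forward = compile_func(self.block_forward, self.device)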