
Commit ae2f425

rewrite the implementation for ease of maintenance
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent f2b9aef commit ae2f425

File tree

2 files changed: +47, -21 lines changed


auto_round/compressors/base.py

Lines changed: 24 additions & 18 deletions
@@ -111,7 +111,7 @@ class BaseCompressor(object):
         sym (bool): Whether to use symmetric weight quantization.
         layer_config (dict): Per-layer quantization configuration.
         nsamples (int): Number of calibration samples.
-        enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
+        enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers.
     """

     bits: int | None
@@ -173,7 +173,7 @@ def __init__(
             act_sym (bool, optional): Symmetric activation quantization. Defaults to None.
             act_data_type (str, optional): Activation data type; inherits weight dtype if None and act_bits < 16.
             act_dynamic (bool, optional): Dynamic activation quantization. Defaults to True.
-            enable_torch_compile (bool, optional): Enable torch.compile for quant blocks/layers. Defaults to False.
+            enable_torch_compile (bool, optional): Enable compile_func for quant blocks/layers. Defaults to False.
             device_map (str | dict, optional): Device placement map. Defaults to None.
             disable_opt_rtn (bool, optional): Disable RTN-mode optimization (iters=0). Defaults to False.
             enable_alg_ext (bool, optional): Enable algorithm extension (primarily for INT2). Defaults to False.
@@ -361,6 +361,9 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
+        self.block_forward = block_forward
+        if self.enable_torch_compile:
+            self.block_forward = compile_func(self.block_forward, self.device)
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

@@ -1428,6 +1431,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
             m.to("cpu")
@@ -1443,6 +1447,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
         except Exception as e:
@@ -1882,6 +1887,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 enable_round_tuning=False,
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
                 device=self.device,
             )
             new_layer = wrapper_layer.unwrapper({})
@@ -1911,10 +1917,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        if self.enable_torch_compile:
-            quant_layer = compile_func(self._quantize_layer, self.device)
-        else:
-            quant_layer = self._quantize_layer
+        quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
@@ -2093,9 +2096,9 @@ def _get_block_outputs(
             tmp_input_ids, tmp_input_others = self._sampling_inputs(
                 input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
             )
-            tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-                cache_device
-            )
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+            ).to(cache_device)
             if save_output:
                 if self.batch_size == 1:
                     output.append(tmp_output)
@@ -2518,7 +2521,12 @@ def _quantize_layer(
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

-        wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(device)
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=device,
+        ).to(device)
         round_params = []
         minmax_params = []
         for key in wrapper_linear.params.keys():
@@ -2696,7 +2704,7 @@ def _get_current_q_output(
             batch_dim=self.batch_dim,
             share_cache_keys=self.shared_cache_keys,
         )
-        output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
+        output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q

     def _get_current_num_elm(
@@ -2781,7 +2789,11 @@ def _quantize_block(
             input_ids = q_input

         quantized_layer_names, unquantized_layer_names = wrapper_block(
-            block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+            block,
+            self.enable_minmax_tuning,
+            self.enable_norm_bias_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=self.device,
         )
         if is_nv_fp(self.data_type):  # enable qkv and moe structure global_scale fuse
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -3008,14 +3020,8 @@ def _quantize_blocks(
                 logger.info("using algorithm extension for quantization.")
             except (ImportError, ModuleNotFoundError):
                 quantize_block = self._quantize_block
-                if self.enable_torch_compile:
-                    quantize_block = compile_func(quantize_block, device)
-                else:
-                    quantize_block = quantize_block
         else:
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)

         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
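For context, a minimal sketch of the compressor-side pattern these hunks move to; this is not the auto_round code itself, and the CompressorSketch class plus the simplified block_forward/compile_func bodies are illustrative assumptions. The idea: bind the hot block_forward helper on the instance and compile it once in __init__, then just thread enable_torch_compile into the wrappers, instead of wrapping whole _quantize_layer/_quantize_block methods in compile_func at their call sites.

import torch


def compile_func(func, device):
    # Assumption: auto_round's compile_func is roughly a guarded torch.compile;
    # fall back to the original function if compilation is unavailable.
    try:
        return torch.compile(func)
    except Exception:
        return func


def block_forward(block, hidden_states, amp, amp_dtype, device):
    # Simplified stand-in for auto_round's block_forward helper.
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=amp):
        return block(hidden_states)


class CompressorSketch:
    def __init__(self, enable_torch_compile=False, device="cpu"):
        self.device = device
        self.enable_torch_compile = enable_torch_compile
        # New pattern: compile the forward helper once here, rather than wrapping
        # _quantize_block / _quantize_layer with compile_func at their call sites.
        self.block_forward = block_forward
        if self.enable_torch_compile:
            self.block_forward = compile_func(self.block_forward, self.device)

    def get_block_outputs(self, block, hidden_states, amp=False, amp_dtype=torch.bfloat16):
        # Call sites go through the (possibly compiled) bound helper,
        # mirroring the self.block_forward(...) calls in the diff above.
        return self.block_forward(block, hidden_states, amp, amp_dtype, self.device)

Compiling only the small forward helper keeps the compiled graph confined to tensor work, which is what lets the later hunks delete the per-call-site compile_func branches.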

auto_round/wrapper.py

Lines changed: 23 additions & 3 deletions
@@ -22,6 +22,7 @@
 from .utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
+    compile_func,
     deepspeed_exists,
     get_scale_shape,
     is_mx_fp,
@@ -67,6 +68,7 @@ class WrapperLinear(torch.nn.Module):
         orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d).
         enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
         enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term.
+        enable_torch_compile (bool): Whether to enable torch compilation.
         device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
     """

@@ -77,6 +79,7 @@ def __init__(
         enable_norm_bias_tuning=False,
         device="cpu",
         enable_round_tuning=True,
+        enable_torch_compile=False,
         **kwargs,
     ):
         """Initializes the WrapperLinear module.
@@ -93,6 +96,7 @@ def __init__(
         self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
         self.enable_minmax_tuning = enable_minmax_tuning
         self.enable_round_tuning = enable_round_tuning
+        self.enable_torch_compile = enable_torch_compile
         self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None)
         self.enable_act_quant = self.orig_layer.act_bits <= 8
         self.weight_global_scale = getattr(self.orig_layer, "weight_global_scale", None)
@@ -143,11 +147,15 @@ def _init_tuning_params_and_quant_func(self):
         self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

         self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+        if self.enable_torch_compile:
+            self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

         if self.enable_act_quant:
             self.act_quant_func, self.act_data_type = get_quant_func(
                 orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
             )
+            if self.enable_torch_compile:
+                self.act_quant_func = compile_func(self.act_quant_func, self.device)
             self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic)

         ## bias tuning
@@ -372,7 +380,10 @@ def _set_dict_attr(attr_dict, attr_name):

             self.orig_layer.act_data_type = self.act_data_type
             self.orig_layer.act_quant_func = self.act_quant_func
-            wrapper_layer = WrapperWALayer(self.orig_layer)
+            wrapper_layer = WrapperWALayer(
+                self.orig_layer,
+                enable_torch_compile=self.enable_torch_compile,
+            )
             return wrapper_layer

         return self.orig_layer
@@ -452,12 +463,15 @@ def forward(self, x):


 class WrapperWALayer(torch.nn.Module):
-    def __init__(self, orig_layer):
+    def __init__(self, orig_layer, enable_torch_compile=False):
         super(WrapperWALayer, self).__init__()
         self.orig_layer = orig_layer
         self.data_type = orig_layer.data_type if hasattr(orig_layer, "data_type") else None
         self.act_data_type = orig_layer.act_data_type if hasattr(orig_layer, "act_data_type") else None
         self.act_quant_func = self.orig_layer.act_quant_func
+        self.enable_torch_compile = enable_torch_compile
+        if self.enable_torch_compile:
+            self.act_quant_func = compile_func(self.act_quant_func, self.device)
         self.extra_repr_org = orig_layer.extra_repr

     def forward(self, x):
@@ -609,12 +623,17 @@ def forward(self, x, **kwargs):
         return hidden_states


-def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
+def wrapper_block(
+    block, enable_minmax_tuning, enable_norm_bias_tuning, enable_torch_compile=False, device="cpu", **kwargs
+):
     """Wraps the layers in the given block with a custom Wrapper module.

     Args:
         block: The input block containing linear and conv1d layers to be wrapped.
         enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
+        enable_norm_bias_tuning: A boolean indicating whether normalization and bias tuning is enabled.
+        enable_torch_compile: A boolean indicating whether to enable torch compilation.
+        device: The device to which the wrapped layers should be moved.

     Returns:
         list: A list of names of the wrapped layers and unwrapped layers.
@@ -630,6 +649,7 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="
                 m,
                 enable_minmax_tuning=enable_minmax_tuning,
                 enable_norm_bias_tuning=enable_norm_bias_tuning,
+                enable_torch_compile=enable_torch_compile,
                 device=device,
                 **kwargs,
             )
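The wrapper-side counterpart, again as a hedged sketch rather than the real WrapperLinear: the WrapperSketch class and the rtn_quant stand-in for a get_quant_func result are assumptions for illustration. The wrapper receives enable_torch_compile and compiles its own quantization function at init time, so callers no longer need to know about compilation at all.

import torch


def rtn_quant(weight, bits=4):
    # Hypothetical stand-in for a get_quant_func() result: per-row round-to-nearest.
    qmax = 2 ** (bits - 1) - 1
    scale = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / qmax
    return torch.round(weight / scale).clamp(-qmax - 1, qmax) * scale


class WrapperSketch(torch.nn.Module):
    def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"):
        super().__init__()
        self.orig_layer = orig_layer
        self.device = device
        self.enable_torch_compile = enable_torch_compile
        # The quant function is a leaf callable, so compiling it here keeps the
        # compiled graph small; torch.compile stands in for compile_func(fn, device).
        self.weight_quant_func = rtn_quant
        if self.enable_torch_compile:
            self.weight_quant_func = torch.compile(self.weight_quant_func)

    def forward(self, x):
        qweight = self.weight_quant_func(self.orig_layer.weight)
        return torch.nn.functional.linear(x, qweight, self.orig_layer.bias)


# Usage: the flag is threaded down from the caller, mirroring
# wrapper_block(..., enable_torch_compile=...) in this commit.
layer = torch.nn.Linear(8, 8)
wrapped = WrapperSketch(layer, enable_torch_compile=False)
out = wrapped(torch.randn(2, 8))

Because the flag is threaded into WrapperLinear, WrapperWALayer, and wrapper_block alike, the RTN path, the tuned path, and activation quantization all pick up compilation from one place.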
