
Commit 9d62109

speedup quant and evaluation, fix recompile issue (#897)
* rewrite the implementation for ease of maintenance
* fix bug
* fix quant performance
* Update auto_round/compressors/base.py

Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent 1b804be commit 9d62109

File tree: 2 files changed, +46 -20 lines changed

  auto_round/compressors/base.py
  auto_round/wrapper.py
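In short, the change stops wrapping the large per-layer and per-block quantization drivers in compile_func and instead compiles only the small, shape-stable helpers (block_forward plus the weight/activation quant functions), binding them once and reusing them. A minimal, hypothetical sketch of that idea, using torch.compile directly rather than auto_round's compile_func:

import torch

def block_forward(block, hidden_states):
    # Small, tensor-only helper: a good compile target because its graph is
    # stable across calls, so the compiled artifact is reused for every block.
    return block(hidden_states)

enable_torch_compile = True

# Sketch of the new approach: compile the helper once and call it from an eager
# driver, instead of compiling the whole quantization driver with all of its
# Python-side control flow.
forward_fn = torch.compile(block_forward) if enable_torch_compile else block_forward

def quantize_block(block, hidden_states):
    # The driver stays eager; only the hot forward path runs compiled.
    return forward_fn(block, hidden_states)

block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
out = quantize_block(block, torch.randn(4, 16))

The compiled helper is created once and reused across blocks, which is what avoids repeated compilation; the drivers keep their control flow in eager mode.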

auto_round/compressors/base.py

Lines changed: 21 additions & 17 deletions
@@ -111,7 +111,7 @@ class BaseCompressor(object):
         sym (bool): Whether to use symmetric weight quantization.
         layer_config (dict): Per-layer quantization configuration.
         nsamples (int): Number of calibration samples.
-        enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
+        enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers.
     """

     bits: int | None
@@ -361,6 +361,7 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
+        self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

@@ -1428,6 +1429,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
             m.to("cpu")
@@ -1443,6 +1445,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
                 enable_round_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
             )
             m = m.unwrapper({})
         except Exception as e:
@@ -1882,6 +1885,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 enable_round_tuning=False,
                 enable_minmax_tuning=False,
                 enable_norm_bias_tuning=False,
+                enable_torch_compile=self.enable_torch_compile,
                 device=self.device,
             )
             new_layer = wrapper_layer.unwrapper({})
@@ -1911,10 +1915,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        if self.enable_torch_compile:
-            quant_layer = compile_func(self._quantize_layer, self.device)
-        else:
-            quant_layer = self._quantize_layer
+        quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)
@@ -2093,9 +2094,9 @@ def _get_block_outputs(
             tmp_input_ids, tmp_input_others = self._sampling_inputs(
                 input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
             )
-            tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-                cache_device
-            )
+            tmp_output = self.block_forward(
+                block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+            ).to(cache_device)
             if save_output:
                 if self.batch_size == 1:
                     output.append(tmp_output)
@@ -2518,7 +2519,12 @@ def _quantize_layer(
             if q_inputs is not None:
                 q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

-        wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(device)
+        wrapper_linear = WrapperLinear(
+            layer,
+            enable_minmax_tuning=self.enable_minmax_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=device,
+        ).to(device)
         round_params = []
         minmax_params = []
         for key in wrapper_linear.params.keys():
@@ -2696,7 +2702,7 @@ def _get_current_q_output(
             batch_dim=self.batch_dim,
             share_cache_keys=self.shared_cache_keys,
         )
-        output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
+        output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
         return output_q

     def _get_current_num_elm(
@@ -2781,7 +2787,11 @@ def _quantize_block(
             input_ids = q_input

         quantized_layer_names, unquantized_layer_names = wrapper_block(
-            block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+            block,
+            self.enable_minmax_tuning,
+            self.enable_norm_bias_tuning,
+            enable_torch_compile=self.enable_torch_compile,
+            device=self.device,
         )
         if is_nv_fp(self.data_type):  # enable qkv and moe structure global_scale fuse
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -3008,14 +3018,8 @@ def _quantize_blocks(
                 logger.info("using algorithm extension for quantization.")
             except (ImportError, ModuleNotFoundError):
                 quantize_block = self._quantize_block
-                if self.enable_torch_compile:
-                    quantize_block = compile_func(quantize_block, device)
-                else:
-                    quantize_block = quantize_block
         else:
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)

         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
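Besides switching _get_block_outputs and _get_current_q_output to the pre-bound self.block_forward, the base.py changes are mostly plumbing: enable_torch_compile is now threaded into WrapperLinear and wrapper_block, so the compile decision is made next to the quant kernels rather than around the drivers. A small self-contained sketch of that flag-threading pattern (TinyWrapper, wrap_block, and TinyBlock are hypothetical names, not the auto_round API):

import torch

class TinyWrapper(torch.nn.Module):
    # Hypothetical stand-in for WrapperLinear: it records the flag and would
    # compile its quant kernels internally when the flag is set.
    def __init__(self, layer, enable_torch_compile=False):
        super().__init__()
        self.layer = layer
        self.enable_torch_compile = enable_torch_compile

    def forward(self, x):
        return self.layer(x)

def wrap_block(block, enable_torch_compile=False):
    # Hypothetical stand-in for wrapper_block: replace each Linear with a wrapper,
    # passing the flag down so each wrapper decides about compilation itself.
    for name, child in list(block.named_children()):
        if isinstance(child, torch.nn.Linear):
            setattr(block, name, TinyWrapper(child, enable_torch_compile=enable_torch_compile))
        else:
            wrap_block(child, enable_torch_compile=enable_torch_compile)
    return block

class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(8, 8)
        self.act = torch.nn.GELU()
        self.fc2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

block = wrap_block(TinyBlock(), enable_torch_compile=True)
out = block(torch.randn(2, 8))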

auto_round/wrapper.py

Lines changed: 25 additions & 3 deletions
@@ -22,6 +22,7 @@
 from .utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
+    compile_func,
     deepspeed_exists,
     get_scale_shape,
     is_mx_fp,
@@ -67,6 +68,7 @@ class WrapperLinear(torch.nn.Module):
         orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d).
         enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
         enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term.
+        enable_torch_compile (bool): Whether to enable torch compilation.
         device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
     """

@@ -77,6 +79,7 @@ def __init__(
         enable_norm_bias_tuning=False,
         device="cpu",
         enable_round_tuning=True,
+        enable_torch_compile=False,
         **kwargs,
     ):
         """Initializes the WrapperLinear module.
@@ -93,6 +96,7 @@ def __init__(
         self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
         self.enable_minmax_tuning = enable_minmax_tuning
         self.enable_round_tuning = enable_round_tuning
+        self.enable_torch_compile = enable_torch_compile
         self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None)
         self.enable_act_quant = self.orig_layer.act_bits <= 8
         self.weight_global_scale = getattr(self.orig_layer, "weight_global_scale", None)
@@ -143,11 +147,15 @@ def _init_tuning_params_and_quant_func(self):
         self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

         self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+        if self.enable_torch_compile:
+            self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

         if self.enable_act_quant:
             self.act_quant_func, self.act_data_type = get_quant_func(
                 orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
             )
+            if self.enable_torch_compile:
+                self.act_quant_func = compile_func(self.act_quant_func, self.device)
             self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic)

         ## bias tuning
@@ -372,7 +380,11 @@ def _set_dict_attr(attr_dict, attr_name):

             self.orig_layer.act_data_type = self.act_data_type
             self.orig_layer.act_quant_func = self.act_quant_func
-            wrapper_layer = WrapperWALayer(self.orig_layer)
+            wrapper_layer = WrapperWALayer(
+                self.orig_layer,
+                enable_torch_compile=self.enable_torch_compile,
+                device=self.device,
+            )
             return wrapper_layer

         return self.orig_layer
@@ -452,12 +464,16 @@ def forward(self, x):


 class WrapperWALayer(torch.nn.Module):
-    def __init__(self, orig_layer):
+    def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"):
         super(WrapperWALayer, self).__init__()
         self.orig_layer = orig_layer
+        self.enable_torch_compile = enable_torch_compile
+        self.device = device
         self.data_type = orig_layer.data_type if hasattr(orig_layer, "data_type") else None
         self.act_data_type = orig_layer.act_data_type if hasattr(orig_layer, "act_data_type") else None
         self.act_quant_func = self.orig_layer.act_quant_func
+        if self.enable_torch_compile:
+            self.act_quant_func = compile_func(self.act_quant_func, self.device)
         self.extra_repr_org = orig_layer.extra_repr

     def forward(self, x):
@@ -609,12 +625,17 @@ def forward(self, x, **kwargs):
         return hidden_states


-def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
+def wrapper_block(
+    block, enable_minmax_tuning, enable_norm_bias_tuning, enable_torch_compile=False, device="cpu", **kwargs
+):
     """Wraps the layers in the given block with a custom Wrapper module.

     Args:
         block: The input block containing linear and conv1d layers to be wrapped.
         enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
+        enable_norm_bias_tuning: A boolean indicating whether normalization and bias tuning is enabled.
+        enable_torch_compile: A boolean indicating whether to enable torch compilation.
+        device: The device to which the wrapped layers should be moved.

     Returns:
         list: A list of names of the wrapped layers and unwrapped layers.
@@ -630,6 +651,7 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
             m,
             enable_minmax_tuning=enable_minmax_tuning,
             enable_norm_bias_tuning=enable_norm_bias_tuning,
+            enable_torch_compile=enable_torch_compile,
             device=device,
             **kwargs,
         )
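The wrapper-side counterpart: when the flag is set, WrapperLinear and WrapperWALayer compile the per-tensor quant functions themselves via compile_func. A minimal sketch of the same idea with a toy round-to-nearest quant-dequant kernel, using torch.compile directly in place of compile_func (all names here are hypothetical, not the auto_round implementation):

import torch

def rtn_qdq(w, bits=4):
    # Toy symmetric round-to-nearest quant-dequant, for illustration only.
    maxq = 2 ** (bits - 1) - 1
    scale = w.abs().amax().clamp(min=1e-9) / maxq
    q = torch.clamp(torch.round(w / scale), -maxq - 1, maxq)
    return q * scale

class TinyWrapperLinear(torch.nn.Module):
    def __init__(self, orig_layer, enable_torch_compile=False):
        super().__init__()
        self.orig_layer = orig_layer
        # Compile only the small numeric kernel; the module's Python logic stays eager.
        self.weight_quant_func = torch.compile(rtn_qdq) if enable_torch_compile else rtn_qdq

    def forward(self, x):
        qdq_weight = self.weight_quant_func(self.orig_layer.weight)
        return torch.nn.functional.linear(x, qdq_weight, self.orig_layer.bias)

layer = torch.nn.Linear(8, 8)
wrapped = TinyWrapperLinear(layer, enable_torch_compile=False)
out = wrapped(torch.randn(2, 8))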

0 commit comments
