
Commit d9907aa

speedup quant and evaluation, fix recompile issue
Signed-off-by: He, Xin3 <xin3.he@intel.com>
1 parent f2b9aef commit d9907aa

4 files changed: +17 -10 lines

auto_round/compressors/base.py

Lines changed: 2 additions & 10 deletions
@@ -361,6 +361,7 @@ def __init__(
         self.infer_bs_coeff = 1
         self.enable_torch_compile = enable_torch_compile
         self._adjust_torch_compile(enable_torch_compile)
+        os.environ["AR_TORCH_COMPILE"] = "1" if self.enable_torch_compile else "0"
         self._check_configs()
         torch.set_printoptions(precision=3, sci_mode=True)

@@ -1911,10 +1912,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
-        if self.enable_torch_compile:
-            quant_layer = compile_func(self._quantize_layer, self.device)
-        else:
-            quant_layer = self._quantize_layer
+        quant_layer = self._quantize_layer
         for layer_name in layer_names:
             layer_input = layer_inputs[layer_name]
             layer_input = to_device(layer_input, self.cache_device)

@@ -3008,14 +3006,8 @@ def _quantize_blocks(
             logger.info("using algorithm extension for quantization.")
         except (ImportError, ModuleNotFoundError):
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)
-            else:
-                quantize_block = quantize_block
         else:
             quantize_block = self._quantize_block
-            if self.enable_torch_compile:
-                quantize_block = compile_func(quantize_block, device)

         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
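
The net effect in this file is that the outer `_quantize_layer` / `_quantize_block` drivers are no longer wrapped with `compile_func`; instead, `__init__` records the choice in the `AR_TORCH_COMPILE` environment variable so that the smaller inner functions (see `utils.py` and `wrapper.py` below) can be compiled once where they are defined. A minimal standalone sketch of that pattern, using hypothetical names (`hot_kernel`, `driver`) and assuming PyTorch 2.x:

```python
import os

import torch


def hot_kernel(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical inner function: small and shape-stable, a good compile target.
    return torch.round(x * 16) / 16


# Gate compilation behind the same kind of env flag the commit introduces,
# and apply it once at definition time rather than inside the driver.
if os.getenv("AR_TORCH_COMPILE", "0") == "1":
    hot_kernel = torch.compile(hot_kernel)


def driver(batches):
    # The driver itself stays un-compiled, so re-entering it never re-wraps anything.
    return [hot_kernel(b) for b in batches]


print(driver([torch.randn(4, 4) for _ in range(3)])[0].shape)
```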

auto_round/envs.py

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
 environment_variables: dict[str, Callable[[], Any]] = {
     # this is used for configuring the default logging level
     "AR_LOG_LEVEL": lambda: os.getenv("AR_LOG_LEVEL", "INFO").upper(),
+    "AR_TORCH_COMPILE": lambda: os.getenv("AR_TORCH_COMPILE", "0"),
 }

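For context, each entry in `environment_variables` is a zero-argument callable, so the new `AR_TORCH_COMPILE` entry is resolved lazily at lookup time rather than at import time. A minimal sketch of how such a registry behaves (standalone illustration, not auto_round's actual accessor):

```python
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "AR_LOG_LEVEL": lambda: os.getenv("AR_LOG_LEVEL", "INFO").upper(),
    "AR_TORCH_COMPILE": lambda: os.getenv("AR_TORCH_COMPILE", "0"),
}

# Because the value is read only when the callable is invoked, a variable
# set later in the process (as base.py's __init__ now does) is still seen.
os.environ["AR_TORCH_COMPILE"] = "1"
print(environment_variables["AR_TORCH_COMPILE"]())  # -> "1"
```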

auto_round/utils.py

Lines changed: 9 additions & 0 deletions
@@ -171,6 +171,12 @@ def is_hpex_available():
     return _hpex_available


+@torch._dynamo.disable()
+@lru_cache(None)
+def is_torch_compile_enabled():
+    return os.getenv("AR_TORCH_COMPILE", "0") in ("1", "true", "True")
+
+
 def get_module(module, key):
     """Get module from model by key name.

@@ -506,6 +512,9 @@ def block_forward(
         output = output[output_return_id]
     return output

+if is_torch_compile_enabled():
+    block_forward = torch.compile(block_forward)
+

 def check_to_quantized(config):
     """Checks if the configuration is valid for quantization.

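A note on the two decorators on `is_torch_compile_enabled`: `lru_cache(None)` means the environment variable is read at most once per process, and `torch._dynamo.disable()` keeps that lookup out of any compiled graph, which fits with base.py setting the flag early in `__init__` and is presumably part of the recompile fix. A small sketch demonstrating the cached behavior (assuming PyTorch 2.x; `torch._dynamo.disable` is a private API and may change):

```python
import os
from functools import lru_cache

import torch


@torch._dynamo.disable()  # keep the env lookup out of dynamo graphs/guards
@lru_cache(None)          # evaluate the flag at most once per process
def is_torch_compile_enabled() -> bool:
    return os.getenv("AR_TORCH_COMPILE", "0") in ("1", "true", "True")


os.environ["AR_TORCH_COMPILE"] = "1"
print(is_torch_compile_enabled())  # True

os.environ["AR_TORCH_COMPILE"] = "0"
print(is_torch_compile_enabled())  # still True: the first result is cached
```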
auto_round/wrapper.py

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,7 @@
     is_mx_fp,
     is_nv_fp,
     set_module,
+    is_torch_compile_enabled,
 )

 if deepspeed_exists:

@@ -143,11 +144,15 @@ def _init_tuning_params_and_quant_func(self):
         self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

         self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+        if is_torch_compile_enabled():
+            self.weight_quant_func = torch.compile(self.weight_quant_func)

         if self.enable_act_quant:
             self.act_quant_func, self.act_data_type = get_quant_func(
                 orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
             )
+            if is_torch_compile_enabled():
+                self.act_quant_func = torch.compile(self.act_quant_func)
             self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic)

         ## bias tuning
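
Here the weight and activation quant functions are compiled once when the wrapper is constructed, so every tuning step reuses the same compiled callables. A condensed, hypothetical sketch of that shape (`fake_quant` and `WrappedLinearSketch` are illustrative names, not auto_round's classes):

```python
import torch


def fake_quant(x: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Illustrative stand-in for a weight/activation quant function.
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().amax().clamp(min=1e-8) / qmax
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale


class WrappedLinearSketch(torch.nn.Module):
    def __init__(self, layer: torch.nn.Linear, compile_enabled: bool):
        super().__init__()
        self.layer = layer
        self.weight_quant_func = fake_quant
        if compile_enabled:
            # Compile once at init; forward() then reuses the compiled callable.
            self.weight_quant_func = torch.compile(self.weight_quant_func)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_q = self.weight_quant_func(self.layer.weight)
        return torch.nn.functional.linear(x, w_q, self.layer.bias)


m = WrappedLinearSketch(torch.nn.Linear(8, 8), compile_enabled=False)
print(m(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```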
