38 changes: 21 additions & 17 deletions auto_round/compressors/base.py
@@ -111,7 +111,7 @@ class BaseCompressor(object):
sym (bool): Whether to use symmetric weight quantization.
layer_config (dict): Per-layer quantization configuration.
nsamples (int): Number of calibration samples.
- enable_torch_compile (bool): Whether to enable torch.compile for quant blocks/layers.
+ enable_torch_compile (bool): Whether to enable compile_func for quant blocks/layers.
"""

bits: int | None
@@ -361,6 +361,7 @@ def __init__(
self.infer_bs_coeff = 1
self.enable_torch_compile = enable_torch_compile
self._adjust_torch_compile(enable_torch_compile)
+ self.block_forward = compile_func(block_forward, self.device) if self.enable_torch_compile else block_forward
self._check_configs()
torch.set_printoptions(precision=3, sci_mode=True)

@@ -1428,6 +1429,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
enable_minmax_tuning=False,
enable_norm_bias_tuning=False,
enable_round_tuning=False,
+ enable_torch_compile=self.enable_torch_compile,
)
m = m.unwrapper({})
m.to("cpu")
@@ -1443,6 +1445,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
enable_minmax_tuning=False,
enable_norm_bias_tuning=False,
enable_round_tuning=False,
+ enable_torch_compile=self.enable_torch_compile,
)
m = m.unwrapper({})
except Exception as e:
@@ -1882,6 +1885,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
enable_round_tuning=False,
enable_minmax_tuning=False,
enable_norm_bias_tuning=False,
+ enable_torch_compile=self.enable_torch_compile,
device=self.device,
)
new_layer = wrapper_layer.unwrapper({})
@@ -1911,10 +1915,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:

self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
clear_memory()
- if self.enable_torch_compile:
-     quant_layer = compile_func(self._quantize_layer, self.device)
- else:
-     quant_layer = self._quantize_layer
+ quant_layer = self._quantize_layer
for layer_name in layer_names:
layer_input = layer_inputs[layer_name]
layer_input = to_device(layer_input, self.cache_device)
@@ -2093,9 +2094,9 @@ def _get_block_outputs(
tmp_input_ids, tmp_input_others = self._sampling_inputs(
input_ids, input_others, indices, self.seqlen, self.batch_dim, share_cache_keys=self.shared_cache_keys
)
- tmp_output = block_forward(block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device).to(
-     cache_device
- )
+ tmp_output = self.block_forward(
+     block, tmp_input_ids, tmp_input_others, self.amp, self.amp_dtype, device
+ ).to(cache_device)
if save_output:
if self.batch_size == 1:
output.append(tmp_output)
@@ -2518,7 +2519,12 @@ def _quantize_layer(
if q_inputs is not None:
q_inputs[i] = q_inputs[i].to(layer.weight.dtype)

- wrapper_linear = WrapperLinear(layer, enable_minmax_tuning=self.enable_minmax_tuning, device=device).to(device)
+ wrapper_linear = WrapperLinear(
+     layer,
+     enable_minmax_tuning=self.enable_minmax_tuning,
+     enable_torch_compile=self.enable_torch_compile,
+     device=device,
+ ).to(device)
round_params = []
minmax_params = []
for key in wrapper_linear.params.keys():
@@ -2696,7 +2702,7 @@ def _get_current_q_output(
batch_dim=self.batch_dim,
share_cache_keys=self.shared_cache_keys,
)
- output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
+ output_q = self.block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device)
return output_q

def _get_current_num_elm(
@@ -2781,7 +2787,11 @@ def _quantize_block(
input_ids = q_input

quantized_layer_names, unquantized_layer_names = wrapper_block(
- block, self.enable_minmax_tuning, self.enable_norm_bias_tuning, device=self.device
+ block,
+ self.enable_minmax_tuning,
+ self.enable_norm_bias_tuning,
+ enable_torch_compile=self.enable_torch_compile,
+ device=self.device,
)
if is_nv_fp(self.data_type): # enable qkv and moe structure global_scale fuse
from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -3008,14 +3018,8 @@ def _quantize_blocks(
logger.info("using algorithm extension for quantization.")
except (ImportError, ModuleNotFoundError):
quantize_block = self._quantize_block
- if self.enable_torch_compile:
-     quantize_block = compile_func(quantize_block, device)
- else:
-     quantize_block = quantize_block
else:
quantize_block = self._quantize_block
- if self.enable_torch_compile:
-     quantize_block = compile_func(quantize_block, device)

if pbar is None:
pbar = tqdm(range(0, len(block_names), nblocks))
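Net effect in base.py: instead of wrapping the whole tuning routines (_quantize_layer, _quantize_block) with compile_func, the compressor now compiles the shared forward helper once at init and leaves the rest eager. A minimal sketch of that conditional-compile pattern follows; the compile_func body and the autocast details are illustrative assumptions, not the library implementation, though the call signature of block_forward matches the diff.

# Sketch (not the library code): compile the shared forward helper once,
# mirroring the new __init__ logic above.
import torch

def compile_func(fn, device):
    # Assumption: compile_func roughly amounts to torch.compile on supported devices.
    return torch.compile(fn)

def block_forward(block, input_ids, input_others, amp, amp_dtype, device):
    # Illustrative forward matching the call signature used in the diff.
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=amp):
        return block(input_ids, **input_others)

class CompressorSketch:
    def __init__(self, enable_torch_compile=False, device="cpu"):
        self.enable_torch_compile = enable_torch_compile
        self.device = device
        # Compile block_forward once here instead of compiling
        # _quantize_layer/_quantize_block wholesale as the removed branches did.
        self.block_forward = (
            compile_func(block_forward, device) if enable_torch_compile else block_forward
        )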
28 changes: 25 additions & 3 deletions auto_round/wrapper.py
@@ -22,6 +22,7 @@
from .utils import (
SUPPORTED_LAYER_TYPES,
check_to_quantized,
+ compile_func,
deepspeed_exists,
get_scale_shape,
is_mx_fp,
@@ -67,6 +68,7 @@ class WrapperLinear(torch.nn.Module):
orig_layer (torch.nn.Module): The original layer to be wrapped (linear or conv1d).
enable_minmax_tuning (bool): Whether to enable min-max scale tuning.
enable_norm_bias_tuning (bool): Whether to enable normalization and tuning of the bias term.
+ enable_torch_compile (bool): Whether to enable torch compilation.
device (str): Device on which to run computations (e.g., 'cpu' or 'cuda').
"""

@@ -77,6 +79,7 @@ def __init__(
enable_norm_bias_tuning=False,
device="cpu",
enable_round_tuning=True,
+ enable_torch_compile=False,
**kwargs,
):
"""Initializes the WrapperLinear module.
@@ -93,6 +96,7 @@ def __init__(
self.device = self.orig_layer.tuning_device if hasattr(self.orig_layer, "tuning_device") else device
self.enable_minmax_tuning = enable_minmax_tuning
self.enable_round_tuning = enable_round_tuning
+ self.enable_torch_compile = enable_torch_compile
self.enable_norm_bias_tuning = enable_norm_bias_tuning and (orig_layer.bias is not None)
self.enable_act_quant = self.orig_layer.act_bits <= 8
self.weight_global_scale = getattr(self.orig_layer, "weight_global_scale", None)
@@ -143,11 +147,15 @@ def _init_tuning_params_and_quant_func(self):
self._init_params("max_scale", p_dtype, shape, 1.0, (self.enable_minmax_tuning and self.orig_layer.bits < 16))

self.weight_quant_func, self.data_type = get_quant_func(orig_layer.data_type, orig_layer.bits, orig_layer.sym)
+ if self.enable_torch_compile:
+     self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

if self.enable_act_quant:
self.act_quant_func, self.act_data_type = get_quant_func(
orig_layer.act_data_type, orig_layer.act_bits, orig_layer.act_sym
)
+ if self.enable_torch_compile:
+     self.act_quant_func = compile_func(self.act_quant_func, self.device)
self._init_params("act_max_scale", p_dtype, (1), 1.0, not orig_layer.act_dynamic)

## bias tuning
@@ -372,7 +380,11 @@ def _set_dict_attr(attr_dict, attr_name):

self.orig_layer.act_data_type = self.act_data_type
self.orig_layer.act_quant_func = self.act_quant_func
- wrapper_layer = WrapperWALayer(self.orig_layer)
+ wrapper_layer = WrapperWALayer(
+     self.orig_layer,
+     enable_torch_compile=self.enable_torch_compile,
+     device=self.device,
+ )
return wrapper_layer

return self.orig_layer
@@ -452,12 +464,16 @@ def forward(self, x):


class WrapperWALayer(torch.nn.Module):
- def __init__(self, orig_layer):
+ def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"):
super(WrapperWALayer, self).__init__()
self.orig_layer = orig_layer
+ self.enable_torch_compile = enable_torch_compile
+ self.device = device
self.data_type = orig_layer.data_type if hasattr(orig_layer, "data_type") else None
self.act_data_type = orig_layer.act_data_type if hasattr(orig_layer, "act_data_type") else None
self.act_quant_func = self.orig_layer.act_quant_func
+ if self.enable_torch_compile:
+     self.act_quant_func = compile_func(self.act_quant_func, self.device)
self.extra_repr_org = orig_layer.extra_repr

def forward(self, x):
@@ -609,12 +625,17 @@ def forward(self, x, **kwargs):
return hidden_states


- def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="cpu", **kwargs):
+ def wrapper_block(
+     block, enable_minmax_tuning, enable_norm_bias_tuning, enable_torch_compile=False, device="cpu", **kwargs
+ ):
"""Wraps the layers in the given block with a custom Wrapper module.

Args:
block: The input block containing linear and conv1d layers to be wrapped.
enable_minmax_tuning: A boolean indicating whether min-max tuning is enabled.
+ enable_norm_bias_tuning: A boolean indicating whether normalization and bias tuning is enabled.
+ enable_torch_compile: A boolean indicating whether to enable torch compilation.
+ device: The device to which the wrapped layers should be moved.

Returns:
list: A list of names of the wrapped layers and unwrapped layers.
@@ -630,6 +651,7 @@ def wrapper_block(block, enable_minmax_tuning, enable_norm_bias_tuning, device="
m,
enable_minmax_tuning=enable_minmax_tuning,
enable_norm_bias_tuning=enable_norm_bias_tuning,
+ enable_torch_compile=enable_torch_compile,
device=device,
**kwargs,
)
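In wrapper.py the flag reaches each wrapper, which compiles only its weight/activation quant functions rather than the surrounding tuning loop. A self-contained sketch of that per-wrapper pattern is below; the toy quantizer and the compile_func stand-in are assumptions (the real code obtains its quant functions from get_quant_func), while the enable_torch_compile branching mirrors the diff.

# Sketch (assumptions marked): per-wrapper conditional compilation, as
# WrapperLinear/WrapperWALayer now do. The lambda stands in for get_quant_func's result.
import torch

def compile_func(fn, device):
    # Assumption: stand-in for auto_round.utils.compile_func.
    return torch.compile(fn)

class WrapperLinearSketch(torch.nn.Module):
    def __init__(self, orig_layer, enable_torch_compile=False, device="cpu"):
        super().__init__()
        self.orig_layer = orig_layer
        self.device = device
        self.enable_torch_compile = enable_torch_compile
        # Toy rounding quantizer standing in for the real weight_quant_func.
        self.weight_quant_func = lambda w: torch.round(w * 16) / 16
        if self.enable_torch_compile:
            # Compile only the quant function, not the whole tuning path.
            self.weight_quant_func = compile_func(self.weight_quant_func, self.device)

    def forward(self, x):
        qweight = self.weight_quant_func(self.orig_layer.weight)
        return torch.nn.functional.linear(x, qweight, self.orig_layer.bias)

# Usage: wrap a plain Linear and run one forward pass.
layer = torch.nn.Linear(8, 4)
wrapped = WrapperLinearSketch(layer, enable_torch_compile=False)
out = wrapped(torch.randn(2, 8))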