Merged
auto_round/__main__.py: 9 changes (0 additions, 9 deletions)
@@ -164,14 +164,6 @@ def __init__(self, *args, **kwargs):
             type=float,
             help="Learning rate specifically for min-max tuning. " "If None, uses the same value as --lr. ",
         )
-        tuning.add_argument(
-            "--mem_per_param_scale",
-            default=13,
-            type=float,
-            help="Memory scaling factor for parameter memory estimation. "
-            "Adjust this if you need to control memory usage during tuning. "
-            "Lower values reduce memory usage but may affect accuracy.",
-        )
         tuning.add_argument(
             "--gradient_accumulate_steps",
             default=1,
@@ -529,7 +521,6 @@ def tune(args):
         enable_deterministic_algorithms=args.enable_deterministic_algorithms,
         lr=args.lr,
         minmax_lr=args.minmax_lr,
-        mem_per_param_scale=args.mem_per_param_scale,
         nblocks=args.nblocks,
         to_quant_block_names=args.to_quant_block_names,
         scale_dtype=args.scale_dtype,
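Reviewer note: the deleted `--mem_per_param_scale` flag was a multiplier on a per-parameter memory estimate used when placing a block for tuning; its compressor-side default (13 for iterative tuning, 1 for RTN-only runs) is removed in base.py below as well. The sketch that follows is only a hypothetical illustration of the kind of estimate such a factor scales; the function name, the bytes-per-parameter constant, and the GiB conversion are invented here and are not the auto_round implementation.

```python
# Hypothetical sketch of the per-parameter memory estimate that a scale factor
# like the removed --mem_per_param_scale would multiply. Illustrative only.
import torch


def estimate_block_memory_gib(block: torch.nn.Module, mem_per_param_scale: float = 13.0) -> float:
    """Rough tuning-time memory for one block: #params * bytes/param * scale."""
    n_params = sum(p.numel() for p in block.parameters())
    bytes_per_param = 2  # assume bf16/fp16 weights
    # The scale factor folds in gradients, optimizer state, and activation overhead.
    return n_params * bytes_per_param * mem_per_param_scale / 1024**3


if __name__ == "__main__":
    block = torch.nn.Sequential(torch.nn.Linear(4096, 11008), torch.nn.Linear(11008, 4096))
    print(f"estimated tuning memory: {estimate_block_memory_gib(block):.2f} GiB")
```

After this change, memory pressure during block tuning is managed without a user-tuned scale factor: the device-map helper is given the batch size (and target device) instead, and a periodic cache-clearing guard is added in the tuning loop (see base.py below).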
auto_round/compressors/base.py: 21 changes (9 additions, 12 deletions)
@@ -93,6 +93,7 @@
     unsupported_meta_device,
 )
 from auto_round.utils.device import (
+    clear_memory_if_reached_threshold,
     get_major_device,
     set_auto_device_map_for_block_with_tuning,
     set_non_auto_device_map,
@@ -229,9 +230,6 @@ def __init__(
         enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
         device = kwargs.pop("device", None)
-        # Scale factor for RAM usage per parameter.
-        mem_per_param_scale = kwargs.pop("mem_per_param_scale", None)
-
         if envs.AR_USE_MODELSCOPE:
             platform = "model_scope"
         self.platform = platform
@@ -336,10 +334,6 @@ def __init__(
         self.optimizer = self._get_optimizer(None)
         self.disable_opt_rtn = disable_opt_rtn
         self.is_packing_immediate = False  # whether to pack the layer immediately after tuning
-        if mem_per_param_scale is None:
-            self.mem_per_param_scale = 13 if self.iters != 0 else 1
-        else:
-            self.mem_per_param_scale = mem_per_param_scale
 
         # KV cache, this one does not affect tuning but will collect some infos during tuning
         self.static_kv_dtype = static_kv_dtype
@@ -1436,7 +1430,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
 
             if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
                 set_auto_device_map_for_block_with_tuning(
-                    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale
+                    block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size
                 )
             # Dispatch model if needed
             if self.device_map is not None:
@@ -2454,10 +2448,12 @@ def _quantize_block(
                 new_layer = convert_fp8_layer_to_linear(m, self.amp_dtype).to(device)
                 set_module(block, n, new_layer)
 
-        if self.device_map == "auto" or (isinstance(self.device_map, str) and "," in self.device_map):
+        if self.device_map == "auto" or ((isinstance(self.device_map, str) and "," in self.device_map)):
             set_auto_device_map_for_block_with_tuning(
-                block, self.device_map, input_ids, self.low_gpu_mem_usage, self.mem_per_param_scale
+                block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
             )
+        else:
+            block = block.to(device)
 
         if self.device_map is not None:
             for n, m in block.named_modules():
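The hunk above changes how a block is placed before tuning: the multi-device path (`device_map == "auto"` or a comma-separated device list) now receives `self.batch_size` and `device` instead of `self.mem_per_param_scale`, and a new `else` branch moves the block to the single target device explicitly. The sketch below illustrates that control flow; the accelerate-based branch is just one way such a placement could be realized and is not necessarily how `set_auto_device_map_for_block_with_tuning` works (the real helper is also given the calibration inputs and the `low_gpu_mem_usage` flag, presumably to budget activation memory, which is omitted here).

```python
# Simplified sketch of the block-placement decision shown in the hunk above.
# The multi-device branch uses Hugging Face accelerate purely for illustration.
import torch
from accelerate import dispatch_model, infer_auto_device_map


def place_block(block: torch.nn.Module, device_map, device: str = "cuda:0") -> torch.nn.Module:
    multi_device = device_map == "auto" or (isinstance(device_map, str) and "," in device_map)
    if multi_device:
        # Spread the block's submodules across the visible GPUs.
        auto_map = infer_auto_device_map(block)
        block = dispatch_model(block, device_map=auto_map)
    else:
        # Single-device fallback, mirroring the new `block = block.to(device)` line.
        block = block.to(device)
    return block
```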
@@ -2508,7 +2504,7 @@ def _quantize_block(
             self.enable_minmax_tuning,
             self.enable_norm_bias_tuning,
             enable_torch_compile=self.enable_torch_compile,
-            device=self.device,
+            device=device,
         )
         if is_nv_fp(self.data_type):  # enable qkv and moe structure global_scale fuse
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -2588,6 +2584,7 @@ def _quantize_block(
                 current_output = to_device(current_output, device)
 
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+
                 if self.attention_mask:
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
                     tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
@@ -2607,6 +2604,7 @@
 
                 total_loss += loss.item() / num_elm
                 self._scale_loss_and_backward(scaler, loss)
+                clear_memory_if_reached_threshold(threshold=0.85)
 
             if i == 0:
                 init_loss = total_loss
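`clear_memory_if_reached_threshold(threshold=0.85)` (imported from `auto_round.utils.device` at the top of this file) is now called after every backward step. Its implementation is not shown in this diff; the sketch below is one plausible way such a guard could work, checking reserved versus total memory on each CUDA device and emptying the cache past the threshold. It is an assumption, not the auto_round implementation.

```python
# Plausible sketch of a threshold-based cache-clearing guard, in the spirit of the
# clear_memory_if_reached_threshold(threshold=0.85) call added above. Not the real code.
import gc

import torch


def clear_memory_if_reached_threshold_sketch(threshold: float = 0.85) -> bool:
    """Empty the CUDA cache if any visible GPU has reserved more than
    `threshold` of its total memory. Returns True if a cleanup ran."""
    if not torch.cuda.is_available():
        return False
    for idx in range(torch.cuda.device_count()):
        total = torch.cuda.get_device_properties(idx).total_memory
        reserved = torch.cuda.memory_reserved(idx)
        if reserved / total >= threshold:
            gc.collect()
            torch.cuda.empty_cache()
            return True
    return False
```

Calling such a guard once per step keeps allocator growth in check while only paying the `empty_cache()` cost when usage is actually high.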
@@ -2762,7 +2760,6 @@ def _quantize_blocks(
         modules = [get_module(model, n) for n in names]
         m = WrapperMultiblock(modules)
 
-        m = m.to(device)
         q_input, input_ids = quantize_block(
             m,
             input_ids,
auto_round/compressors/config.py: 5 changes (0 additions, 5 deletions)
@@ -41,7 +41,6 @@ def __init__(
         lr: float = None,
         lr_scheduler: Callable = None,
         minmax_lr: float = None,
-        mem_per_param_scale: int = None,
         nblocks: int = 1,
         to_quant_block_names: Union[str, list, None] = None,
         scale_dtype: str = "fp16",
@@ -84,8 +83,6 @@ def __init__(
             lr (float): The learning rate (default is 0.005).
             lr_scheduler: The learning rate scheduler to be used.
             minmax_lr (float): The learning rate for min-max tuning (default is None).
-            mem_per_param_scale (int): Scale factor for memory per parameter,
-                used to adjust memory usage estimation for tuning.
             nblocks (int): Number of blocks (default is 1).
             quant_lm_head (bool): Whether to quant lm_head.
             to_quant_block_names (str|list): Names of quantitative blocks, please use commas to separate them.
@@ -124,7 +121,6 @@ def __init__(
             lr=lr,
             lr_scheduler=lr_scheduler,
             minmax_lr=minmax_lr,
-            mem_per_param_scale=mem_per_param_scale,
             nblocks=nblocks,
             to_quant_block_names=to_quant_block_names,
             scale_dtype=scale_dtype,
@@ -260,7 +256,6 @@ class TuningExtraConfig(BaseExtraConfig):
     lr: float = None
     lr_scheduler: Callable = None
     minmax_lr: float = None
-    mem_per_param_scale: int = None
     nblocks: int = 1
     to_quant_block_names: Union[str, list, None] = None
     scale_dtype: str = "fp16"
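With `mem_per_param_scale` removed from the constructor, the docstring, and the `TuningExtraConfig` fields, tuning memory is no longer a user-facing knob in the config surface. A hedged usage sketch, assuming `TuningExtraConfig` can be constructed with keyword arguments for the remaining fields shown above (the exact construction API may differ):

```python
# Usage sketch under the assumption that TuningExtraConfig accepts these fields
# as keyword arguments; field names are taken from the diff above.
from auto_round.compressors.config import TuningExtraConfig

tuning_cfg = TuningExtraConfig(
    lr=5e-3,
    minmax_lr=5e-3,
    nblocks=1,
    scale_dtype="fp16",
)
# mem_per_param_scale is no longer a field here or a CLI flag; memory during block
# tuning is handled by the batch-size-aware device-map helper and the periodic
# clear_memory_if_reached_threshold(threshold=0.85) guard added in base.py.
```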