Commit 70233bc
fix extra config (#847)
1 parent 0751337
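In short: quant_lm_head and fp_layers move from the tuning side of the extra config to the scheme side (both files are updated in step), fp_layers stops being passed as a top-level argument in tune() and instead travels through the scheme config, mem_per_param_scale is newly forwarded from the CLI, and the TuningExtraConfig keyword arguments in tune() are reordered alphabetically.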

2 files changed: 15 additions, 10 deletions


auto_round/__main__.py

Lines changed: 9 additions & 7 deletions
@@ -436,16 +436,17 @@ def tune(args):
     extra_config = ExtraConfig()
     tuning_config = TuningExtraConfig(
         amp=not args.disable_amp,
-        lr=args.lr,
-        minmax_lr=args.minmax_lr,
-        enable_quanted_input=not args.disable_quanted_input,
-        nblocks=args.nblocks,
+        disable_opt_rtn=args.disable_opt_rtn,
+        enable_alg_ext=args.enable_alg_ext,
         enable_minmax_tuning=not args.disable_minmax_tuning,
         enable_norm_bias_tuning=args.enable_norm_bias_tuning,
+        enable_quanted_input=not args.disable_quanted_input,
         enable_deterministic_algorithms=args.enable_deterministic_algorithms,
+        lr=args.lr,
+        minmax_lr=args.minmax_lr,
+        mem_per_param_scale=args.mem_per_param_scale,
+        nblocks=args.nblocks,
         to_quant_block_names=args.to_quant_block_names,
-        disable_opt_rtn=args.disable_opt_rtn,
-        enable_alg_ext=args.enable_alg_ext,
         scale_dtype=args.scale_dtype,
     )
     scheme_config = SchemeExtraConfig(

@@ -459,6 +460,8 @@ def tune(args):
         act_dynamic=act_dynamic,
         super_bits=args.super_bits,
         super_group_size=args.super_group_size,
+        quant_lm_head=args.quant_lm_head,
+        fp_layers=args.fp_layers,
     )
     mllm_config = MLLMExtraConfig(
         quant_nontext_module=args.quant_nontext_module, extra_data_dir=args.extra_data_dir, template=args.template

@@ -480,7 +483,6 @@ def tune(args):
         device_map=args.device_map,
         enable_torch_compile=enable_torch_compile,
         seed=args.seed,
-        fp_layers=args.fp_layers,
         not_use_best_mse=args.not_use_best_mse,
         enable_adam=args.adam,
         extra_config=extra_config,
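Read against tune(), the effect is that what-to-quantize flags now travel with the scheme settings, while the tuning block keeps only optimizer and loop options (and gains mem_per_param_scale). The sketch below illustrates that routing in isolation: the flag names come from the diff, but the argparse wiring and the kwargs dicts are illustrative stand-ins, not the project's actual CLI definition.

import argparse

# Flag names are taken from the diff; defaults and parser wiring here
# are illustrative stand-ins, not the project's actual CLI definition.
parser = argparse.ArgumentParser()
parser.add_argument("--quant_lm_head", action="store_true")
parser.add_argument("--fp_layers", type=str, default=None)
parser.add_argument("--mem_per_param_scale", type=int, default=None)
args = parser.parse_args(["--quant_lm_head", "--fp_layers", "lm_head"])

# Post-fix routing: what-to-quantize flags ride with the scheme settings,
# and fp_layers is no longer passed at the top level of the round call.
scheme_kwargs = {
    "quant_lm_head": args.quant_lm_head,
    "fp_layers": args.fp_layers,
}
tuning_kwargs = {
    "mem_per_param_scale": args.mem_per_param_scale,  # newly forwarded
}
print(scheme_kwargs)  # {'quant_lm_head': True, 'fp_layers': 'lm_head'}
print(tuning_kwargs)  # {'mem_per_param_scale': None}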

auto_round/compressors/config.py

Lines changed: 6 additions & 3 deletions
@@ -42,7 +42,6 @@ def __init__(
         minmax_lr: float = None,
         mem_per_param_scale: int = None,
         nblocks: int = 1,
-        quant_lm_head: bool = False,
         to_quant_block_names: Union[str, list, None] = None,
         scale_dtype: str = "fp16",
         # scheme

@@ -58,6 +57,8 @@ def __init__(
         super_bits: int = None,
         super_group_size: int = None,
         static_kv_dtype: Union[str, torch.dtype] = None,
+        quant_lm_head: bool = False,
+        fp_layers: str = None,
         # mllm
         processor: Callable = None,
         image_processor: Callable = None,

@@ -116,7 +117,6 @@ def __init__(
             minmax_lr=minmax_lr,
             mem_per_param_scale=mem_per_param_scale,
             nblocks=nblocks,
-            quant_lm_head=quant_lm_head,
             to_quant_block_names=to_quant_block_names,
             scale_dtype=scale_dtype,
         )

@@ -133,6 +133,8 @@ def __init__(
             super_bits=super_bits,
             super_group_size=super_group_size,
             static_kv_dtype=static_kv_dtype,
+            quant_lm_head=quant_lm_head,
+            fp_layers=fp_layers,
         )
         self.mllm_config = MLLMExtraConfig(
             processor=processor,

@@ -232,7 +234,6 @@ class TuningExtraConfig(BaseExtraConfig):
     minmax_lr: float = None
     mem_per_param_scale: int = None
     nblocks: int = 1
-    quant_lm_head: bool = False
     to_quant_block_names: Union[str, list, None] = None
     scale_dtype: str = "fp16"

@@ -251,6 +252,8 @@ class SchemeExtraConfig(BaseExtraConfig):
     super_bits: int = None
     super_group_size: int = None
     static_kv_dtype: Union[str, torch.dtype] = None
+    quant_lm_head: bool = False
+    fp_layers: str = None


 @dataclass
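config.py follows a facade pattern: ExtraConfig.__init__ takes one flat keyword list and routes each argument into a per-topic dataclass sub-config (TuningExtraConfig, SchemeExtraConfig, MLLMExtraConfig), which is why the same two fields had to move in two places. Below is a minimal self-contained reconstruction of that pattern with the post-commit layout; field lists are trimmed to the ones this commit touches, BaseExtraConfig is omitted, and the fp_layers value in the usage lines is hypothetical.

from dataclasses import dataclass
from typing import Optional

# Trimmed stand-ins for the real dataclasses: only the fields this
# commit touches are kept.

@dataclass
class TuningExtraConfig:
    lr: Optional[float] = None
    minmax_lr: Optional[float] = None
    mem_per_param_scale: Optional[int] = None
    nblocks: int = 1
    scale_dtype: str = "fp16"

@dataclass
class SchemeExtraConfig:
    super_bits: Optional[int] = None
    super_group_size: Optional[int] = None
    quant_lm_head: bool = False       # moved here from TuningExtraConfig
    fp_layers: Optional[str] = None   # new scheme-side field

class ExtraConfig:
    """Facade: one flat keyword list, routed into per-topic sub-configs."""

    def __init__(self, lr=None, minmax_lr=None, mem_per_param_scale=None,
                 nblocks=1, scale_dtype="fp16", super_bits=None,
                 super_group_size=None, quant_lm_head=False, fp_layers=None):
        self.tuning_config = TuningExtraConfig(
            lr=lr,
            minmax_lr=minmax_lr,
            mem_per_param_scale=mem_per_param_scale,
            nblocks=nblocks,
            scale_dtype=scale_dtype,
        )
        self.scheme_config = SchemeExtraConfig(
            super_bits=super_bits,
            super_group_size=super_group_size,
            quant_lm_head=quant_lm_head,  # routed to the scheme side post-fix
            fp_layers=fp_layers,
        )

# Hypothetical usage: the fp_layers value is illustrative only.
cfg = ExtraConfig(quant_lm_head=True, fp_layers="lm_head")
assert cfg.scheme_config.quant_lm_head is True
assert cfg.scheme_config.fp_layers == "lm_head"
assert cfg.tuning_config.nblocks == 1

The facade keeps the flat keyword surface stable for callers while letting fields migrate between sub-configs, which is exactly the kind of move this commit makes.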
