@@ -436,16 +436,17 @@ def tune(args):
436436 extra_config = ExtraConfig ()
437437 tuning_config = TuningExtraConfig (
438438 amp = not args .disable_amp ,
439- lr = args .lr ,
440- minmax_lr = args .minmax_lr ,
441- enable_quanted_input = not args .disable_quanted_input ,
442- nblocks = args .nblocks ,
439+ disable_opt_rtn = args .disable_opt_rtn ,
440+ enable_alg_ext = args .enable_alg_ext ,
443441 enable_minmax_tuning = not args .disable_minmax_tuning ,
444442 enable_norm_bias_tuning = args .enable_norm_bias_tuning ,
443+ enable_quanted_input = not args .disable_quanted_input ,
445444 enable_deterministic_algorithms = args .enable_deterministic_algorithms ,
445+ lr = args .lr ,
446+ minmax_lr = args .minmax_lr ,
447+ mem_per_param_scale = args .mem_per_param_scale ,
448+ nblocks = args .nblocks ,
446449 to_quant_block_names = args .to_quant_block_names ,
447- disable_opt_rtn = args .disable_opt_rtn ,
448- enable_alg_ext = args .enable_alg_ext ,
449450 scale_dtype = args .scale_dtype ,
450451 )
451452 scheme_config = SchemeExtraConfig (
@@ -459,6 +460,8 @@ def tune(args):
459460 act_dynamic = act_dynamic ,
460461 super_bits = args .super_bits ,
461462 super_group_size = args .super_group_size ,
463+ quant_lm_head = args .quant_lm_head ,
464+ fp_layers = args .fp_layers ,
462465 )
463466 mllm_config = MLLMExtraConfig (
464467 quant_nontext_module = args .quant_nontext_module , extra_data_dir = args .extra_data_dir , template = args .template
@@ -480,7 +483,6 @@ def tune(args):
480483 device_map = args .device_map ,
481484 enable_torch_compile = enable_torch_compile ,
482485 seed = args .seed ,
483- fp_layers = args .fp_layers ,
484486 not_use_best_mse = args .not_use_best_mse ,
485487 enable_adam = args .adam ,
486488 extra_config = extra_config ,
0 commit comments