diff --git a/llm/utils/argument.py b/llm/utils/argument.py index bd822261c283..099c08ebe7e8 100644 --- a/llm/utils/argument.py +++ b/llm/utils/argument.py @@ -299,11 +299,6 @@ class QuantArgument: do_ptq: bool = field(default=False, metadata={"help": "Whether to use PTQ"}) ptq_step: int = field(default=32, metadata={"help": "Step for PTQ"}) - weight_quant_method: str = field( - default="abs_max_channel_wise", - metadata={"help": "Weight quantization method, choosen from ['abs_max_channel_wise', 'groupwise']"}, - ) - # Pre-quant method Shift related parameters shift: bool = field(default=False, metadata={"help": "Whether to use Shift"}) shift_all_linears: bool = field(default=False, metadata={"help": "Whether to shift all linears"}) diff --git a/llm/utils/quant.py b/llm/utils/quant.py index b9520e72edc3..11a22fe6fa5b 100644 --- a/llm/utils/quant.py +++ b/llm/utils/quant.py @@ -69,7 +69,7 @@ WEIGHT_OBSERVER = dict( abs_max_channel_wise=AbsMaxChannelWiseWeightObserver, - group_wise=GroupWiseWeightObserver, + groupwise=GroupWiseWeightObserver, ) CACHEKV_OBSERVER = dict( @@ -259,12 +259,12 @@ def prepare_qconfig(args): activation = act_observer(quant_bits=a_quant_bit) weight = weight_observer(quant_bits=w_quant_bit) - elif quant_type in ["wint4", "w4a16", "weight_only_int8"]: + elif quant_type in ["wint4", "w4a16", "weight_only_int4"]: activation = None - weight = GroupWiseWeightObserver(quant_bits=4, group_size=args.group_size) # TODO + weight = weight_observer(quant_bits=4) + elif quant_type in ["wint8", "w8a16", "weight_only_int8"]: activation = None - if "w" in args.use_fp8: weight = weight_observer(quant_bits=(4, 3)) else: