Skip to content

Commit

Permalink
[Inference] Fix weight_only_int4 bug (#9073)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixcli authored Sep 4, 2024
1 parent fbbc0a2 commit 70da482
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
5 changes: 0 additions & 5 deletions llm/utils/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,6 @@ class QuantArgument:
do_ptq: bool = field(default=False, metadata={"help": "Whether to use PTQ"})
ptq_step: int = field(default=32, metadata={"help": "Step for PTQ"})

weight_quant_method: str = field(
default="abs_max_channel_wise",
metadata={"help": "Weight quantization method, choosen from ['abs_max_channel_wise', 'groupwise']"},
)

# Pre-quant method Shift related parameters
shift: bool = field(default=False, metadata={"help": "Whether to use Shift"})
shift_all_linears: bool = field(default=False, metadata={"help": "Whether to shift all linears"})
Expand Down
8 changes: 4 additions & 4 deletions llm/utils/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

WEIGHT_OBSERVER = dict(
abs_max_channel_wise=AbsMaxChannelWiseWeightObserver,
group_wise=GroupWiseWeightObserver,
groupwise=GroupWiseWeightObserver,
)

CACHEKV_OBSERVER = dict(
Expand Down Expand Up @@ -259,12 +259,12 @@ def prepare_qconfig(args):
activation = act_observer(quant_bits=a_quant_bit)
weight = weight_observer(quant_bits=w_quant_bit)

elif quant_type in ["wint4", "w4a16", "weight_only_int8"]:
elif quant_type in ["wint4", "w4a16", "weight_only_int4"]:
activation = None
weight = GroupWiseWeightObserver(quant_bits=4, group_size=args.group_size) # TODO
weight = weight_observer(quant_bits=4)

elif quant_type in ["wint8", "w8a16", "weight_only_int8"]:
activation = None

if "w" in args.use_fp8:
weight = weight_observer(quant_bits=(4, 3))
else:
Expand Down

0 comments on commit 70da482

Please sign in to comment.