support gptq true_sequential and quant_lm_head #1977

Merged
merged 18 commits on Aug 19, 2024
3 changes: 2 additions & 1 deletion .azure-pipelines/scripts/ut/run_itrex.sh
@@ -18,7 +18,8 @@ bash /intel-extension-for-transformers/.github/workflows/script/install_binary.s
sed -i '/neural-compressor.git/d' /intel-extension-for-transformers/tests/requirements.txt
pip install -r /intel-extension-for-transformers/tests/requirements.txt
# workaround
pip install onnx==1.15.0
pip install onnx==1.16.0
pip install onnxruntime==1.18.0
echo "pip list itrex ut deps..."
pip list
LOG_DIR=/neural-compressor/log_dir
5 changes: 3 additions & 2 deletions docs/source/3x/PT_WeightOnlyQuant.md
@@ -111,9 +111,10 @@ model = convert(model)
| model_path (str) | Model path that is used to load state_dict per layer | |
| use_double_quant (bool) | Enables double quantization | False |
| act_order (bool) | Whether to sort Hessian's diagonal values to rearrange channel-wise quantization order | False |
| percdamp (float) | Percentage of Hessian's diagonal values' average, which will be added to Hessian's diagonal to increase numerical stability | 0.01. |
| percdamp (float) | Percentage of Hessian's diagonal values' average, which will be added to Hessian's diagonal to increase numerical stability | 0.01 |
| block_size (int) | Execute GPTQ quantization per block, block shape = [C_out, block_size] | 128 |
| static_groups (bool) | Whether to calculate group wise quantization parameters in advance. This option mitigate actorder's extra computational requirements. | False. |
| static_groups (bool) | Whether to calculate group-wise quantization parameters in advance. This option mitigates act_order's extra computational requirements. | False |
| true_sequential (bool) | Whether to quantize layers within a transformer block in their original order. This can lead to higher accuracy but a slower overall quantization process. | False |
> **Note:** `model_path` is only used when `use_layer_wise=True`. Further `layer-wise` documentation is coming soon.

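A minimal usage sketch of the two new options follows. The model name and the calibration input are illustrative placeholders, not part of this PR; `GPTQConfig`, `prepare`, and `convert` follow the API exercised by this PR's docs and tests.

``` python
# Sketch only: model name and calibration data are placeholders.
import torch
from transformers import AutoModelForCausalLM
from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

quant_config = GPTQConfig(
    true_sequential=True,  # quantize layers inside each transformer block in their original order
    quant_lm_head=True,    # also quantize the lm_head layer (newly supported for GPTQ by this PR)
)

model = prepare(model, quant_config)
with torch.no_grad():
    model(torch.randint(0, 100, (1, 32)))  # dummy token ids as a calibration pass
model = convert(model)
```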
422 changes: 334 additions & 88 deletions neural_compressor/torch/algorithms/weight_only/gptq.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -177,6 +177,8 @@ def convert(
if dtype != "int" and "int" in dtype:
bits = int(dtype.lstrip("int"))
dtype = "int"
else:
continue
log_msg = (
f"RTN quantization config: bits={bits}, group_size={group_size}, "
+ f"scheme={scheme}, quantile={quantile}"
2 changes: 2 additions & 0 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -159,11 +159,13 @@ def gptq_entry(
"percdamp": quant_config.percdamp,
"block_size": quant_config.block_size,
"static_groups": quant_config.static_groups,
"true_sequential": quant_config.true_sequential,
}
kwargs.update(
{
"use_layer_wise": quant_config.use_layer_wise,
"model_path": quant_config.model_path,
"quant_lm_head": quant_config.quant_lm_head,
}
)
kwargs.pop("example_inputs")
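Since `gptq_entry` now forwards `use_layer_wise`, `model_path`, and `quant_lm_head` into the GPTQ kwargs, a layer-wise run can be sketched as below. The pattern mirrors this PR's tests; the calibration input is a placeholder.

``` python
# Sketch of the layer-wise path; the checkpoint name comes from this PR's tests.
import torch
from neural_compressor.torch import load_empty_model
from neural_compressor.torch.quantization import GPTQConfig, convert, prepare

model_path = "hf-internal-testing/tiny-random-GPTJForCausalLM"
model = load_empty_model(model_path)  # shell model; weights are loaded per layer during quantization

quant_config = GPTQConfig(
    use_layer_wise=True,    # quantize layer by layer to bound peak memory
    model_path=model_path,  # where each layer's state_dict is loaded from
    quant_lm_head=True,
)

model = prepare(model, quant_config)
with torch.no_grad():
    model(torch.randint(0, 100, (1, 32)))  # dummy calibration batch
model = convert(model)
```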
9 changes: 7 additions & 2 deletions neural_compressor/torch/quantization/config.py
@@ -351,6 +351,7 @@ class GPTQConfig(TorchBaseConfig):
"percdamp",
"block_size",
"static_groups",
"true_sequential",
]

def __init__(
@@ -376,6 +377,7 @@ def __init__(
percdamp: float = 0.01,
block_size: int = 2048,
static_groups: bool = False,
true_sequential: bool = False,
# Tuning space
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
):
@@ -404,10 +406,12 @@ def __init__(
static_groups (bool): Whether to calculate group wise quantization parameters in advance.
This option mitigate actorder's extra computational requirements.
Default is False.
true_sequential (bool): Whether to quantize layers within a transformer block in their original order.
This can lead to higher accuracy but a slower overall quantization process.
Default is False.
white_list (Optional[List[OP_NAME_OR_MODULE_TYPE]]): White list of operator names or module types.
Default is DEFAULT_WHITE_LIST.
"""
assert not quant_lm_head, "GPTQ doesn't support lm_head quantization currently, it's coming soon!"
super().__init__(white_list=white_list)
self.dtype = dtype
self.bits = bits
Expand All @@ -428,6 +432,7 @@ def __init__(
self.percdamp = percdamp
self.block_size = block_size
self.static_groups = static_groups
self.true_sequential = true_sequential
self.quant_lm_head = quant_lm_head
self._post_init() # initialize global & local configuration

@@ -599,7 +604,7 @@ def __init__(
double_quant_bits (int): Number of bits used to represent double_quant scale, default is 4.
double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
double_quant_group_size (int): Size of double_quant groups, default is 32.
quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False.
quant_lm_head (bool): Indicates whether to quantize the lm_head layer in transformers, default is False.
use_auto_scale (bool): Enables best scales search based on activation distribution, default is True.
use_auto_clip (bool): Enables clip range search. Defaults to True.
folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
73 changes: 69 additions & 4 deletions test/3x/torch/quantization/weight_only/test_gptq.py
@@ -182,9 +182,10 @@ def test_act_order(self):
# compare atol, this case is an ideal case.
assert atol_false > atol_true, "act_order=True doesn't help accuracy, maybe is reasonable, please double check."

def test_layer_wise(self):
@pytest.mark.parametrize("quant_lm_head", [False, True])
def test_layer_wise(self, quant_lm_head):
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig()
quant_config = GPTQConfig(quant_lm_head=quant_lm_head)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
@@ -194,12 +195,76 @@ def test_layer_wise(self):

model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM")

quant_config = GPTQConfig(use_layer_wise=True, model_path="hf-internal-testing/tiny-random-GPTJForCausalLM")
quant_config = GPTQConfig(
use_layer_wise=True,
quant_lm_head=quant_lm_head,
model_path="hf-internal-testing/tiny-random-GPTJForCausalLM",
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]

# remove lwq tmp directory
from neural_compressor.torch.algorithms.layer_wise.utils import LWQ_WORKSPACE

shutil.rmtree(LWQ_WORKSPACE, ignore_errors=True)
assert torch.equal(
out, q_label
), f"use_layer_wise=True and quant_lm_head={quant_lm_head} output should be same. Please double check."

def test_true_sequential(self):
# true_sequential=False
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig(
true_sequential=False,
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
atol_false = (out - self.label).amax()
# true_sequential=True
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig(
true_sequential=True,
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
atol_true = (out - self.label).amax()
# compare atol, this case is an ideal case.
assert (
atol_false < atol_true
), "true_sequential=True doesn't help accuracy, maybe is reasonable, please double check."

def test_quant_lm_head(self):
# quant_lm_head=False
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig(
quant_lm_head=False,
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
assert torch.equal(out, q_label), "use_layer_wise=True output should be same. Please double check."
atol_false = (out - self.label).amax()
# quant_lm_head=True
model = copy.deepcopy(self.tiny_gptj)
quant_config = GPTQConfig(
quant_lm_head=True,
)
model = prepare(model, quant_config)
run_fn(model)
model = convert(model)
out = model(self.example_inputs)[0]
atol_true = (out - self.label).amax()
# compare atol, this case is an ideal case.
assert (
atol_false < atol_true
), "quant_lm_head=True doesn't help accuracy, maybe is reasonable, please double check."
assert get_woq_linear_num(model, "INCWeightOnlyLinear") == 31, "Incorrect number of INCWeightOnlyLinear modules"

@pytest.mark.parametrize("dtype", ["nf4", "int4"])
@pytest.mark.parametrize("double_quant_bits", [6])
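The quant_lm_head test above pins the expected module count with `get_woq_linear_num`. A minimal stand-in for that helper, assuming it simply counts submodules by class name, could look like this (hypothetical reimplementation, not the project's own code):

``` python
import torch.nn as nn

def count_modules_by_class_name(model: nn.Module, cls_name: str) -> int:
    """Count submodules whose class name matches cls_name (e.g. "INCWeightOnlyLinear")."""
    return sum(1 for m in model.modules() if type(m).__name__ == cls_name)
```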
13 changes: 11 additions & 2 deletions test/3x/torch/quantization/weight_only/test_rtn.py
@@ -174,6 +174,15 @@ def test_quant_lm_head(self):
), "The tied lm_head weight is not deep copied, please check!"

def test_layer_wise(self):
# use_layer_wise=False
model = copy.deepcopy(self.tiny_gptj)
quant_config = RTNConfig(
use_layer_wise=False,
)
model = prepare(model, quant_config)
model = convert(model)
out0 = model(self.example_inputs)[0]

from neural_compressor.torch import load_empty_model

model = load_empty_model("hf-internal-testing/tiny-random-GPTJForCausalLM")
@@ -182,8 +191,8 @@
)
model = prepare(model, quant_config)
model = convert(model)
out = model(self.example_inputs)[0]
assert torch.equal(out, self.q_label), "use_layer_wise=True output should be same. Please double check."
out1 = model(self.example_inputs)[0]
assert torch.equal(out1, out0), "use_layer_wise=True and use_layer_wise=False outputs should be the same. Please double check."

@pytest.mark.parametrize(
"dtype",