
Commit 4728fdc

Support auto_round integration 2.x (#1806)
Signed-off-by: Kaihui-intel <kaihui.tang@intel.com>
1 parent 24508d0 commit 4728fdc

5 files changed: +85 -74 lines changed

.azure-pipelines/scripts/ut/env_setup.sh

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@ elif [[ $(echo "${test_case}" | grep -c "tf pruning") != 0 ]]; then
 fi
 
 if [[ $(echo "${test_case}" | grep -c "api") != 0 ]] || [[ $(echo "${test_case}" | grep -c "adaptor") != 0 ]]; then
-    pip install auto-round
+    pip install git+https://github.com/intel/auto-round.git@ecca5349981044e1278773a251b3fc5c0a11fe7b
 fi
 
 # test deps

neural_compressor/adaptor/pytorch.py

Lines changed: 15 additions & 11 deletions
@@ -4916,12 +4916,12 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
             weight_config[op_name]["sym"] = config["weight"]["scheme"] == "sym"
 
         # auto round recipes
+
         enable_full_range = self.recipes["autoround_args"].get("enable_full_range", False)
         batch_size = self.recipes["autoround_args"].get("batch_size", 8)
         lr_scheduler = self.recipes["autoround_args"].get("lr_scheduler", None)
-        dataset_name = self.recipes["autoround_args"].get("dataset_name", "NeelNanda/pile-10k")
-        dataset_split = self.recipes["autoround_args"].get("dataset_split", "train")
-        use_quant_input = self.recipes["autoround_args"].get("use_quant_input", True)
+        dataset = self.recipes["autoround_args"].get("dataset", "NeelNanda/pile-10k")
+        enable_quanted_input = self.recipes["autoround_args"].get("enable_quanted_input", True)
         enable_minmax_tuning = self.recipes["autoround_args"].get("enable_minmax_tuning", True)
         lr = self.recipes["autoround_args"].get("lr", None)
         minmax_lr = self.recipes["autoround_args"].get("minmax_lr", None)
@@ -4938,22 +4938,26 @@ def autoround_quantize(self, model, tune_cfg, dataloader):
         data_type = self.recipes["autoround_args"].get("data_type", "int")  ##only support data_type
         scale_dtype = self.recipes["autoround_args"].get("scale_dtype", "fp16")
         amp = self.recipes["autoround_args"].get("amp", True)
+        device = self.recipes["autoround_args"].get("device", None)
+        bits = self.recipes["autoround_args"].get("bits", 4)
+        group_size = self.recipes["autoround_args"].get("group_size", 128)
+        sym = self.recipes["autoround_args"].get("scheme", "asym") == "sym"
 
+        if dataloader is not None:
+            dataset = dataloader
         model, autoround_config = autoround_quantize(
             model=model,
-            tokenizer=None,
-            bits=4,
-            group_size=128,
-            sym=False,
+            bits=bits,
+            group_size=group_size,
+            sym=sym,
             weight_config=weight_config,
             enable_full_range=enable_full_range,
             batch_size=batch_size,
             amp=amp,
+            device=device,
             lr_scheduler=lr_scheduler,
-            dataloader=dataloader,
-            dataset_name=dataset_name,
-            dataset_split=dataset_split,
-            use_quant_input=use_quant_input,
+            dataset=dataset,
+            enable_quanted_input=enable_quanted_input,
             enable_minmax_tuning=enable_minmax_tuning,
             lr=lr,
             minmax_lr=minmax_lr,
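
For context, a minimal sketch of how the renamed autoround_args recipe keys might be passed through the INC 2.x quantization API. The op_type_dict pattern and the "AUTOROUND" algorithm label are illustrative assumptions rather than part of this diff, and fp32_model/dataloader are assumed to be created as in the updated unit test at the end of this commit.

    # Sketch only: exercising the renamed recipe keys ("dataset", "enable_quanted_input")
    # and the bits/group_size/scheme entries now read from the recipes by the adaptor.
    from neural_compressor import PostTrainingQuantConfig, quantization

    conf = PostTrainingQuantConfig(
        approach="weight_only",
        op_type_dict={
            ".*": {  # illustrative: apply to all matched op types
                "weight": {"bits": 4, "group_size": 128, "scheme": "asym", "algorithm": "AUTOROUND"},
            },
        },
        recipes={
            "autoround_args": {
                "dataset": "NeelNanda/pile-10k",   # replaces dataset_name/dataset_split
                "enable_quanted_input": True,      # replaces use_quant_input
                "seqlen": 2048,
                "iters": 200,
            },
        },
    )
    # fp32_model and dataloader are assumed to be defined as in the unit test below;
    # when a calib_dataloader is given, the adaptor passes it on as the dataset.
    q_model = quantization.fit(fp32_model, conf, calib_dataloader=dataloader)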

neural_compressor/adaptor/torch_utils/auto_round.py

Lines changed: 19 additions & 6 deletions
@@ -12,14 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from auto_round.calib_dataset import CALIB_DATASETS  # pylint: disable=E0401
 
+def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=512):
+    """Generate a DataLoader for calibration using specified parameters.
+
+    Args:
+        tokenizer (Tokenizer): The tokenizer to use for tokenization.
+        seqlen (int): The exact sequence length. samples < seqlen will be dropped,
+            samples longer than seqlen will be truncated
+        dataset_name (str, optional): The name of the dataset or datasets separated by commas.
+            Defaults to "NeelNanda/pile-10k".
+        split (str, optional): The data split to use. Defaults to None.
+        seed (int, optional): The random seed for reproducibility. Defaults to 42.
+        bs (int, optional): The batch size. Defaults to 4.
+        n_samples (int, optional): The total number of samples to include. Defaults to 512.
+
+    Returns:
+        DataLoader: The DataLoader for the calibrated dataset.
+    """
+    from auto_round.calib_dataset import get_dataloader  # pylint: disable=E0401
 
-def get_dataloader(
-    tokenizer, seqlen=2048, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
-):
-    get_dataloader = CALIB_DATASETS.get(dataset_name, CALIB_DATASETS["NeelNanda/pile-10k"])
     dataloader = get_dataloader(
-        tokenizer, seqlen=seqlen, seed=seed, bs=train_bs, split=dataset_split, dataset_name=dataset_name
+        tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=seed, bs=bs, n_samples=n_samples
     )
     return dataloader
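
A hedged usage sketch of the reworked helper follows; the tokenizer checkpoint is the tiny test model used in the unit test below, and the small seqlen/n_samples values are chosen only to keep the example fast.

    # Sketch: seqlen is now the second positional argument, and bs/n_samples replace
    # the old train_bs/dataset_split parameters.
    import transformers
    from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
    )
    dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=20)
    batch = next(iter(dataloader))  # one tokenized calibration batch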

neural_compressor/adaptor/torch_utils/weight_only.py

Lines changed: 48 additions & 52 deletions
@@ -20,7 +20,7 @@
 
 import math
 from copy import deepcopy
-from typing import OrderedDict
+from typing import Optional, OrderedDict, Union
 
 from ...utils import logger
 from ...utils.utility import LazyImport
@@ -679,7 +679,7 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1):
 
 def autoround_quantize(
     model,
-    tokenizer,
+    tokenizer=None,
     bits: int = 4,
     group_size: int = 128,
     sym: bool = False,
@@ -689,10 +689,8 @@
     amp: bool = True,
     device=None,
     lr_scheduler=None,
-    dataloader=None,  ## to support later
-    dataset_name: str = "NeelNanda/pile-10k",
-    dataset_split: str = "train",
-    use_quant_input: bool = True,
+    dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k",
+    enable_quanted_input: bool = True,
     enable_minmax_tuning: bool = True,
     lr: float = None,
     minmax_lr: float = None,
@@ -706,52 +704,52 @@ def autoround_quantize(
     gradient_accumulate_steps: int = 1,
     not_use_best_mse: bool = False,
     dynamic_max_gap: int = -1,
-    data_type: str = "int",  ##only support data_type
-    scale_dtype="fp16",
+    data_type: str = "int",  ##only support int for now
+    scale_dtype: str = "fp16",
     **kwargs,
 ):
     """Run autoround weight-only quantization.
     Args:
-        model: The PyTorch model to be quantized.
-        tokenizer: Tokenizer for processing input data. Temporarily set as a mandatory parameter.
-        bits (int): Number of bits for quantization (default is 4).
-        group_size (int): Size of the quantization group (default is 128).
-        sym (bool): Whether the symmetric quantization is to be used.
-        weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
-        weight_config={
-                    'layer1':##layer_name
-                    {
-                        'data_type': 'int',
-                        'bits': 4,
-                        'group_size': 32,
-                        'scheme': "asym", ## or sym
-                    }
-                    ...
-                }
-        enable_full_range (bool): Whether to enable full range quantization (default is False).
-        bs (int): Batch size for training (default is 8).
-        amp (bool): Whether to use automatic mixed precision (default is True). Automatically detect and set.
-        device: The device to be used for tuning (default is None). Automatically detect and set.
-        lr_scheduler: The learning rate scheduler to be used.
-        dataloader: The dataloader for input data (to be supported in future).
-        dataset_name (str): The default dataset name (default is "NeelNanda/pile-10k").
-        dataset_split (str): The split of the dataset to be used (default is "train").
-        use_quant_input (bool): Whether to use quantized input data (default is True).
-        enable_minmax_tuning (bool): Whether to enable min-max tuning (default is True).
-        lr (float): The learning rate (default is 0.005).
-        minmax_lr (float): The learning rate for min-max tuning (default is None).
-        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
-        iters (int): Number of iterations (default is 200).
-        seqlen (int): Length of the sequence.
-        n_samples (int): Number of samples (default is 512).
-        sampler (str): The sampling method (default is "rand").
-        seed (int): The random seed (default is 42).
-        n_blocks (int): Number of blocks (default is 1).
-        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
-        not_use_best_mse (bool): Whether to use mean squared error (default is False).
-        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
-        data_type (str): The data type to be used (default is "int").
-        **kwargs: Additional keyword arguments.
+        model: The PyTorch model to be quantized.
+        tokenizer: An optional tokenizer for processing input data. If none is provided, a dataloader must be supplied.
+        bits (int): Number of bits for quantization (default is 4).
+        group_size (int): Size of the quantization group (default is 128).
+        sym (bool): Whether symmetric quantization is to be used (default is False).
+        weight_config (dict): Configuration for weight quantization (default is an empty dictionary).
+        weight_config={
+                    'layer1':##layer_name
+                    {
+                        'data_type': 'int',
+                        'bits': 4,
+                        'group_size': 32,
+                        'sym': False
+                    }
+                    ...
+                }
+        enable_full_range (bool): Whether to enable full range quantization (default is False).
+        batch_size (int): Batch size for training (default is 8).
+        amp (bool): Whether to use automatic mixed precision (default is True).
+        device: The device to be used for tuning (default is "auto").
+        lr_scheduler: The learning rate scheduler to be used.
+        dataset (str): The default dataset name (default is "NeelNanda/pile-10k").
+        enable_quanted_input (bool): Whether to use the output of the previous quantized block as
+            the input for the current block (default is True).
+        enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).
+        lr (float): The learning rate (default is None, will be set to 1.0/iters).
+        minmax_lr (float): The learning rate for min-max tuning (default is None, it will be set to lr automatically).
+        low_gpu_mem_usage (bool): Whether to use low GPU memory (default is True).
+        iters (int): Number of iterations (default is 200).
+        seqlen (int): Data length of the sequence for tuning (default is 2048).
+        n_samples (int): Number of samples (default is 512).
+        sampler (str): The sampling method (default is "rand").
+        seed (int): The random seed (default is 42).
+        n_blocks (int): Number of blocks (default is 1).
+        gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).
+        not_use_best_mse (bool): Whether to use mean squared error (default is False).
+        dynamic_max_gap (int): The dynamic maximum gap (default is -1).
+        data_type (str): The data type to be used (default is "int").
+        scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
+            have different choices.
 
     Returns:
         The quantized model.
@@ -770,10 +768,8 @@ def autoround_quantize(
         amp=amp,
         device=device,
         lr_scheduler=lr_scheduler,
-        dataloader=dataloader,  ## to support later
-        dataset_name=dataset_name,
-        dataset_split=dataset_split,
-        use_quant_input=use_quant_input,
+        dataset=dataset,
+        enable_quanted_input=enable_quanted_input,
         enable_minmax_tuning=enable_minmax_tuning,
         lr=lr,
         minmax_lr=minmax_lr,
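
For reference, a minimal sketch of calling the updated function directly, passing the calibration data as a dataloader so the now-optional tokenizer can be omitted. fp32_model and dataloader stand in for objects built as in the unit test below, the default empty weight_config is relied on, and the small iters/seqlen values only keep the sketch cheap.

    # Sketch only: direct call with the 2.x signature shown in this diff.
    from neural_compressor.adaptor.torch_utils.weight_only import autoround_quantize

    q_model, autoround_config = autoround_quantize(
        model=fp32_model,            # torch.nn.Module, assumed to exist
        bits=4,
        group_size=128,
        sym=False,
        dataset=dataloader,          # a str, list, tuple, or DataLoader is accepted now
        enable_quanted_input=True,   # renamed from use_quant_input
        seqlen=32,
        n_samples=20,
        iters=10,
        scale_dtype="fp32",
        amp=False,
    )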

test/adaptor/pytorch_adaptor/test_weight_only_adaptor_pytorch.py

Lines changed: 2 additions & 4 deletions
@@ -752,9 +752,7 @@ def test_AutoRound_quant(self):
         tokenizer = transformers.AutoTokenizer.from_pretrained(
             "hf-internal-testing/tiny-random-GPTJForCausalLM", trust_remote_code=True
         )
-        dataloader = get_dataloader(
-            tokenizer, seqlen=10, seed=42, train_bs=8, dataset_split="train", dataset_name="NeelNanda/pile-10k"
-        )
+        dataloader = get_dataloader(tokenizer, 32, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, n_samples=20)
         fp32_model = copy.deepcopy(self.gptj)
         conf = PostTrainingQuantConfig(
             approach="weight_only",
@@ -777,7 +775,7 @@ def test_AutoRound_quant(self):
             recipes={
                 "autoround_args": {
                     "n_samples": 20,
-                    "seq_len": 10,
+                    "seqlen": 32,
                     "iters": 10,
                     "scale_dtype": "fp32",
                     "amp": False,
