
Commit 3d11b5e

Support NF4/FP4 data type in weight-only (#1185)
* Support the NF4/FP4 data types in the weight-only RTN & AWQ algorithms; allow tuning the dtype and compressing models in nf4/fp4 mode.

Signed-off-by: Xin He <xin3.he@intel.com>
1 parent ffe47d9 commit 3d11b5e

File tree

15 files changed: +258 -78 lines changed

.azure-pipelines/scripts/codeScan/pyspelling/inc_dict.txt

Lines changed: 11 additions & 1 deletion
@@ -2716,4 +2716,14 @@ xgb
 xgboost
 hpo
 HPO
-arange
+arange
+nf
+Dettmers
+Qlora
+llms
+NormalFloat
+QLoRA
+TimDettmers
+bitsandbytes
+bnb
+ccedc

docs/source/quantization_weight_only.md

Lines changed: 6 additions & 1 deletion
@@ -35,11 +35,14 @@ There are many excellent works for weight only quantization to improve its accur
 ### **Quantization Capability**:
 | Config | Capability |
 | :---: | :---:|
+| dtype | ['int', 'nf4', 'fp4'] |
 | bits | [1-8] |
 | group_size | [-1, 1-N] |
 | scheme | ['asym', 'sym'] |
 | algorithm | ['RTN', 'AWQ', 'GPTQ'] |
 
+Notes: 4-bit NormalFloat (NF4) is proposed in QLoRA[5]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb.
+
 **RTN arguments**:
 | rtn_args | default value | comments |
 |:----------:|:-------------:|:-------------------------------------------------------------------:|
@@ -95,7 +98,7 @@ conf = PostTrainingQuantConfig(
     },
     recipes={
         # 'gptq_args':{'percdamp': 0.01, 'actorder':True, 'block_size': 128, 'nsamples': 128, 'use_full_length': False},
-        'awq_args':{'auto_scale': True, 'mse_range': True, 'n_blocks': 5},
+        # 'awq_args':{'auto_scale': True, 'mse_range': True},
     },
 )
 q_model = quantization.fit(model, conf, eval_func=eval_func)
@@ -119,3 +122,5 @@ The saved_results folder contains two files: `best_model.pt` and `qconfig.json`,
 [3]. Lin, Ji, et al. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." arXiv preprint arXiv:2306.00978 (2023).
 
 [4]. Frantar, Elias, et al. "Gptq: Accurate post-training quantization for generative pre-trained transformers." arXiv preprint arXiv:2210.17323 (2022).
+
+[5]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023).
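As a quick orientation for the capability table and notes above, here is a hedged usage sketch of selecting the new dtype. The PostTrainingQuantConfig / quantization.fit pattern follows the example earlier in this document; `model` and `eval_func` are placeholders defined by the user, not names introduced by this commit.

from neural_compressor import PostTrainingQuantConfig, quantization

conf = PostTrainingQuantConfig(
    approach='weight_only',
    op_type_dict={
        '.*': {  # regex over op types; only Linear is supported for weight-only
            'weight': {
                'dtype': 'nf4',      # new knob: 'int' (default), 'nf4', or 'fp4'
                'bits': 4,           # nf4/fp4 imply 4 bits
                'group_size': 32,    # -1 means per-channel
                'scheme': 'sym',     # nf4/fp4 codebooks are symmetric
                'algorithm': 'RTN',  # this commit wires dtype into RTN and AWQ
            },
        },
    },
)
q_model = quantization.fit(model, conf, eval_func=eval_func)  # model/eval_func defined elsewhere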

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_clm.py

Lines changed: 1 addition & 1 deletion
@@ -618,7 +618,7 @@ def eval_func_for_nc(model_tuned):
     if model_args.int8:
         from neural_compressor.utils.pytorch import load
         new_model = load(
-            os.path.abspath(os.path.expanduser(training_args.output_dir)), model)
+            os.path.abspath(os.path.expanduser(training_args.output_dir)), model, weight_only=True)
     else:
         new_model = model

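The one-line change above matters at reload time: a weight-only checkpoint must be restored through the weight-only path. A minimal sketch, assuming the quantized model was saved to a local directory beforehand (the directory name and the `model` object are placeholders):

import os
from neural_compressor.utils.pytorch import load

output_dir = 'saved_results'  # placeholder for training_args.output_dir
new_model = load(os.path.abspath(os.path.expanduser(output_dir)), model, weight_only=True)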
neural_compressor/adaptor/pytorch.py

Lines changed: 11 additions & 1 deletion
@@ -4349,6 +4349,14 @@ def quantize(self, tune_cfg, model, dataloader, calib_func=None):
                 if config['weight']['dtype'] == 'fp32':
                     continue
                 else:
+                    dtype = config['weight']['dtype']
+                    if dtype in ['nf4', 'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1']:
+                        config['weight']['bits'] = 4
+                        config['weight']['scheme'] = 'sym'
+                    elif dtype in ['int4']:
+                        config['weight']['bits'] = 4
+                    elif dtype in ['int8']:
+                        config['weight']['bits'] = 8
                     algorithm = config['weight']['algorithm']
                     all_algo.add(algorithm)
         if len(all_algo):
@@ -4385,15 +4393,17 @@ def rtn_quantize(self, model, tune_cfg):
                 if config['weight']['dtype'] == 'fp32':
                     continue
                 else:
+                    dtype = config['weight']['dtype']
                     num_bits = config['weight']['bits']
-                    group_size = config['weight']['group_size']
                     scheme = config['weight']['scheme']
+                    group_size = config['weight']['group_size']
                     algorithm = config['weight']['algorithm']
                     if algorithm != 'RTN':
                         continue
                     m = fetch_module(model, op_name)
                     m = rtn_quantize(m, num_bits, group_size, scheme,
                                      return_int=False,
+                                     data_type=dtype,
                                      sym_full_range=sym_full_range,
                                      mse_range=mse_range)
                     set_module(model, op_name, m)
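A standalone sketch of the dtype normalization the first hunk introduces, useful for seeing how a tuned dtype string overrides bits/scheme before the algorithm dispatch; this mirrors the same branching for illustration and is not the adaptor code itself:

def normalize_weight_cfg(cfg):
    # Float-style 4-bit codebooks imply 4 bits and a symmetric scheme;
    # int4/int8 only pin the bit width.
    dtype = cfg['weight']['dtype']
    if dtype in ['nf4', 'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1']:
        cfg['weight']['bits'] = 4
        cfg['weight']['scheme'] = 'sym'
    elif dtype == 'int4':
        cfg['weight']['bits'] = 4
    elif dtype == 'int8':
        cfg['weight']['bits'] = 8
    return cfg

cfg = {'weight': {'dtype': 'nf4', 'bits': 8, 'scheme': 'asym', 'algorithm': 'RTN'}}
print(normalize_weight_cfg(cfg)['weight'])  # bits becomes 4, scheme becomes 'sym'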

neural_compressor/adaptor/pytorch_cpu.yaml

Lines changed: 1 addition & 2 deletions
@@ -262,7 +262,7 @@
     weight_only_integer: &cap_weight_only_integer {
       'Linear': &cap_weight_only_integer_linear { # only Linear now
         'weight': {
-            'dtype': ['int'], # no need to care uint
+            'dtype': ['int', 'int4', 'nf4', 'fp4', 'fp4_e2m1_bnb', 'fp4_e2m1'],
             'bits': [4, 1, 2, 3, 5, 6, 7, 8], # [1-8], # 4
             # group_size=-1 means per-channel, others means per-group
             'group_size': [32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], # [1-inf], # 32
@@ -273,7 +273,6 @@
             'dtype': ['fp32'],
           },
       },
-      'Conv2d': *cap_weight_only_integer_linear,
     }

neural_compressor/adaptor/torch_utils/awq.py

Lines changed: 11 additions & 3 deletions
@@ -91,7 +91,8 @@ def _get_act_scale(input_val):
 class ActAwareWeightQuant:
     """Implementation of Activation-aware Weight quantization (AWQ) algo."""
     def __init__(self, model, example_inputs=None, calib_func=None, dataloader=None, n_samples=128,
-                 bits=4, group_size=32, scheme='asym', sym_full_range=False, weight_config={},):
+                 data_type='int', bits=4, group_size=32, scheme='asym', sym_full_range=False,
+                 weight_config={},):
         self.example_inputs = example_inputs
         if example_inputs is None:
             assert dataloader is not None, "datalaoder or example_inputs is required."
@@ -103,6 +104,7 @@ def __init__(self, model, example_inputs=None, calib_func=None, dataloader=None,
         # Step 2: get block list and block prefix, number
         self.block_prefix, self.block_num = get_block_prefix(model)
         self.block_list = fetch_module(model, self.block_prefix)
+        self.data_type = data_type
         self.bits = bits
         self.group_size = group_size
         self.scheme = scheme
@@ -188,11 +190,13 @@ def search_scale(self, block, block_name, module_list, input_values):
         for module_tuple in module_list:
             # Step 1: Initailize quantization configuration.
             if module_tuple[0] in self.weight_config:
+                cur_dtype = self.weight_config[module_tuple[0]]['dtype']
                 cur_bits = self.weight_config[module_tuple[0]]['bits']
                 cur_group_size = self.weight_config[module_tuple[0]]['group_size']
                 cur_scheme = self.weight_config[module_tuple[0]]['scheme']
             else:
-                cur_bits, cur_group_size, cur_scheme = self.bits, self.group_size, self.scheme
+                cur_dtype, cur_bits, cur_group_size, cur_scheme = \
+                    self.data_type, self.bits, self.group_size, self.scheme
             if cur_bits < 0:
                 continue
             logger.info(f"[SCALE] Processing module: {module_tuple}")
@@ -231,6 +235,7 @@ def search_scale(self, block, block_name, module_list, input_values):
                 module.weight.data = module.weight.data.mul(scales.view(1, -1))
                 module.weight.data = quant_weight(
                     module.weight.data,
+                    data_type=cur_dtype,
                     num_bits=cur_bits,
                     group_size=cur_group_size,
                     scheme=cur_scheme,
@@ -310,11 +315,13 @@ def search_clip(self, block_name, module_list, input_values):
             for module_name in module_tuple:
                 # Step 1: Initailize quantization configuration.
                 if module_name in self.weight_config:
+                    cur_dtype = self.weight_config[module_name]['dtype']
                     cur_bits = self.weight_config[module_name]['bits']
                     cur_group_size = self.weight_config[module_name]['group_size']
                     cur_scheme = self.weight_config[module_name]['scheme']
                 else:
-                    cur_bits, cur_group_size, cur_scheme = self.bits, self.group_size, self.scheme
+                    cur_dtype, cur_bits, cur_group_size, cur_scheme = \
+                        self.data_type, self.bits, self.group_size, self.scheme
                 if cur_bits < 0:
                     continue
                 logger.info(f"[CLIP] Processing module: {module_name}")
@@ -335,6 +342,7 @@ def search_clip(self, block_name, module_list, input_values):
                 # MulLinear can also work with @weight.setter
                 module.weight.data = quant_weight(
                     module.weight.data,
+                    data_type=cur_dtype,
                     num_bits=cur_bits,
                     group_size=cur_group_size,
                     scheme=cur_scheme,
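The recurring pattern in search_scale()/search_clip() is a per-module lookup with a fallback to the instance defaults, now extended with data_type so NF4/FP4 rounding is simulated during the scale and clip search. A small self-contained sketch of that resolution logic (the function name and example entries are illustrative, not from the commit):

def resolve_quant_cfg(weight_config, name, data_type='int', bits=4, group_size=32, scheme='asym'):
    # Prefer the per-module entry when present, otherwise fall back to the defaults.
    if name in weight_config:
        entry = weight_config[name]
        return entry['dtype'], entry['bits'], entry['group_size'], entry['scheme']
    return data_type, bits, group_size, scheme

cfg = {'fc1': {'dtype': 'nf4', 'bits': 4, 'group_size': 32, 'scheme': 'sym'}}
print(resolve_quant_cfg(cfg, 'fc1'))  # ('nf4', 4, 32, 'sym')
print(resolve_quant_cfg(cfg, 'fc2'))  # ('int', 4, 32, 'asym'), the defaults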

neural_compressor/adaptor/torch_utils/model_wrapper.py

Lines changed: 14 additions & 1 deletion
@@ -194,10 +194,18 @@ def _wrapper_qdq_linear(tmp_model, module_name_list=[]):
 
 class WeightOnlyLinear(torch.nn.Module):
     def __init__(self, in_features, out_features, bits, groupsize,
-                 zp=False, bias=False, scale_dtype=torch.float32,
+                 dtype='int', zp=False, bias=False, scale_dtype=torch.float32,
                  compression_dtype=torch.int32, compression_dim=1,
                  gptq_perm=False, device='cpu'):
         super().__init__()
+        self.dtype = dtype
+        if 'int' not in self.dtype:  # for nf4, fp4
+            from neural_compressor.adaptor.torch_utils.weight_only import FLOAT_MAPPING, INT_MAPPING
+            float_list = FLOAT_MAPPING[self.dtype]
+            int_list = INT_MAPPING[self.dtype]
+            self.int2float_mapping = {}
+            for k, v in zip(int_list, float_list):
+                self.int2float_mapping[k] = v
         self.device = device
         self.in_features = in_features
         self.out_features = out_features
@@ -346,6 +354,11 @@ def recover(self):
             weight[:, index] = tmp.type(weight_dtype)
         if self.compression_dim == 0:
             weight = weight.T
+        if 'int' not in self.dtype:
+            new_weight = torch.zeros(self.out_features, self.in_features).to(device)
+            for k, v in self.int2float_mapping.items():
+                new_weight += torch.where(weight == k, v, 0)
+            weight = new_weight
         # unpack zero_point
         if hasattr(self, 'packed_zp'):
             zp_dtype = self.compressed_dtype  # to avoid overflow when weight-zp
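The recover() hunk above rebuilds float weights from stored integer codes with a lookup table. Below is a self-contained sketch of the same idea; the integer codes are simply 0..15 for illustration (the library's INT_MAPPING may use a different encoding), and the NF4 levels are the approximate 16-entry NormalFloat table from QLoRA/bitsandbytes:

import torch

# Approximate NF4 codebook values; see the QLoRA paper / bitsandbytes for the exact table.
NF4_LEVELS = [-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
              0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0]

def recover_nf4(codes: torch.Tensor) -> torch.Tensor:
    # Accumulate value-by-value with torch.where, like recover() does with int2float_mapping.
    out = torch.zeros(codes.shape, dtype=torch.float32)
    for k, v in enumerate(NF4_LEVELS):
        out += torch.where(codes == k, v, 0.0)
    return out

print(recover_nf4(torch.tensor([0, 7, 15])))  # tensor([-1., 0., 1.])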
