
Commit 22daf80

Update gptj example with the newest GPTQ API. (#1277)
Signed-off-by: YIYANGCAI <yiyang.cai@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 550cee2 commit 22daf80

4 files changed: +81 −39 lines changed

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/README.md

Lines changed: 19 additions & 2 deletions
@@ -37,7 +37,7 @@ sh run_quant.sh --topology=gpt_j_wikitext_weight_only --input_model=EleutherAI/g
 >
 > `weight_only_bits`, `weight_only_group`, `weight_only_scheme`, and `weight_only_algorithm` can be modified by user. For details, please refer to [README](../../../../../../../docs/source/quantization_weight_only.md).
 
-### Run MLPerf on GPT-J-6B
+### Run MLPerf on GPT-J-6B using GPTQ
 Use the following link to get
 [**CNN Daily Mail** datasets](https://github.com/intel-innersource/frameworks.ai.benchmarking.mlperf.submission.inference-submission-v3-1/tree/master/closed/Intel/code/gpt-j/pytorch-cpu#download-and-prepare-dataset)
 and [gpt-j-6B mlperf model](https://github.com/mlcommons/inference/tree/master/language/gpt-j#download-gpt-j-model)
@@ -54,11 +54,28 @@ python -u examples/pytorch/nlp/huggingface_models/language-modeling/quantization
     --val-data-path /your/data/validation-data/cnn_dailymail_validation.json \
     --calib-iters 128 \
     --use_max_length \
-    --use_fp16 \
+    --pad_max_length 2048 \
     --use_gpu
 ```
 Notes: for per channel quantization, set group_size to **-1**, otherwise 32, 128, etc. More comprehensive details about user-defined arguments are available at our [weight_onlly quantization documentations](https://github.com/intel/neural-compressor/blob/master/docs/source/quantization_weight_only.md#quantization-capability)
 
+### Run general examples on a wide variety of LLMs using GPTQ
+We also support GPTQ algorithm on various language models (OPTs, Blooms, LLaMAs, MPTs, Falcons, ChatGLMs, etc.) in a generalized code. Please refer to script *run-gptq-llm.py* for more information.
+
+You can simply use following command to do quantization (please refer to *run-gptq-llm.sh*).
+```shell
+python examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run-gptq-llm.py \
+    --model_name_or_path facebook/opt-125m \
+    --weight_only_algo GPTQ \
+    --dataset NeelNanda/pile-10k \
+    --wbits 4 \
+    --group_size 128 \
+    --pad_max_length 2048 \
+    --use_max_length \
+    --seed 0 \
+    --gpu
+```
+
 ## 2. Benchmark
 ```bash
 # int8
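
Editor's note: the "Notes" line above hinges on what `group_size` means for weight-only quantization. Below is a minimal, self-contained sketch (not part of this commit and not the Neural Compressor implementation) of how per-group scales relate to `group_size`, with `-1` collapsing to per-channel, i.e. one group spanning each output channel's full input dimension.

```python
import torch

def group_quant_scales(weight: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Illustrative sketch: per-group max-abs scales for 4-bit symmetric
    weight-only quantization. group_size=-1 means per-channel quantization."""
    out_features, in_features = weight.shape
    if group_size == -1:
        group_size = in_features  # one group per output channel
    assert in_features % group_size == 0, "in_features must be divisible by group_size"
    # Split each row into consecutive groups of `group_size` columns.
    groups = weight.reshape(out_features, in_features // group_size, group_size)
    # Symmetric 4-bit integers span [-8, 7]; derive a scale from the per-group max-abs.
    return groups.abs().amax(dim=-1) / 7.0

w = torch.randn(16, 256)
print(group_quant_scales(w, group_size=128).shape)  # torch.Size([16, 2])
print(group_quant_scales(w, group_size=-1).shape)   # torch.Size([16, 1])
```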

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.py

Lines changed: 11 additions & 8 deletions
@@ -31,8 +31,8 @@ def skip(*args, **kwargs):
     torch.nn.init.uniform_ = skip
     torch.nn.init.normal_ = skip
     from transformers import GPTJForCausalLM, AutoModelForCausalLM
-    model = GPTJForCausalLM.from_pretrained(model) # load the model with fp32 precision
-    #model = AutoModelForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
+    # model = GPTJForCausalLM.from_pretrained(model) # load the model with fp32 precision
+    model = GPTJForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16)
     return model
 
 def postprocess_text(preds, targets):
@@ -90,7 +90,6 @@ def sync():
     predictions = []
     ground_truths = []
 
-    # import pdb;pdb.set_trace()
     with torch.no_grad(), torch.inference_mode():
         times = []
         #for i, (input_ids, labels) in enumerate(benchmark_dataset):# in range(input_ids.numel()):
@@ -266,7 +265,6 @@ def forward(self, *inp, **kwargs):
 args = parser.parse_args()
 # method 1: directly import AutoModelForCausalLM
 model = get_gptj(args.model_name_or_path)
-model.seqlen = args.pad_max_length
 model.eval()
 
 if args.use_gpu and torch.cuda.is_available():
@@ -288,15 +286,19 @@ def forward(self, *inp, **kwargs):
 
 # # do the quantization
 print('Starting ...')
+if args.sym:
+    sym_opt = 'sym'
+else:
+    sym_opt = 'asym'
 
 conf = PostTrainingQuantConfig(
     approach='weight_only',
     op_type_dict={
         '.*':{ # re.match
             "weight": {
-                'bits': 4, # 1-8 bits
-                'group_size': 128, # -1 (per-channel)
-                'scheme': 'sym',
+                'bits': args.wbits, # 1-8 bits
+                'group_size': args.group_size, # -1 (per-channel)
+                'scheme': sym_opt,
                 'algorithm': 'GPTQ',
             },
         },
@@ -314,7 +316,8 @@ def forward(self, *inp, **kwargs):
             'act_order':args.act_order,
             'block_size': args.block_size,
             'nsampeles': args.nsamples,
-            'use_max_length': args.use_max_length
+            'use_max_length': args.use_max_length,
+            'pad_max_length': args.pad_max_length
         },
     },
 )
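
Editor's note: the hunks above replace the hard-coded GPTQ settings with `args.wbits`, `args.group_size`, and the new `sym`/`asym` switch. A hedged sketch of how such a config is typically consumed end to end follows; `quantization.fit` is Neural Compressor's post-training entry point, while `model` and `dataloader` stand in for the GPT-J model and calibration dataloader the script builds elsewhere (those names are assumptions, not taken from the diff). The GPTQ-specific options from the last hunk (act_order, block_size, nsamples, use_max_length, pad_max_length) would be passed through the same config (via its recipes in current Neural Compressor releases) and are omitted here for brevity.

```python
# Hedged sketch, not part of the commit: wiring the CLI-driven values into a
# weight-only GPTQ config and running post-training quantization.
from neural_compressor import PostTrainingQuantConfig, quantization

wbits, group_size, sym = 4, 128, False  # normally args.wbits / args.group_size / args.sym
conf = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {  # applies to every matched op type
            "weight": {
                "bits": wbits,
                "group_size": group_size,
                "scheme": "sym" if sym else "asym",
                "algorithm": "GPTQ",
            },
        },
    },
)
# `model` is the bf16 GPT-J returned by get_gptj(); `dataloader` yields the
# calibration batches. Both are assumed to exist, as in the script above.
q_model = quantization.fit(model, conf, calib_dataloader=dataloader)
```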

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/run_gptj_mlperf_int4.sh

Lines changed: 1 addition & 1 deletion
@@ -12,5 +12,5 @@ python -u examples/pytorch/nlp/huggingface_models/language-modeling/quantization
     --val-data-path ${VALIDATION_DATA} \
     --calib-iters 128 \
     --use_max_length \
-    --use_fp16 \
+    --pad_max_length 2048 \
     --use_gpu

neural_compressor/adaptor/torch_utils/gptq.py

Lines changed: 50 additions & 28 deletions
@@ -236,10 +236,11 @@ def prepare_dataloader(self):
             # general selection, no padding, not GPTQ original implementation.
             self.obtain_first_n_samples()
         try:
-            self.inp = [torch.zeros(1) for _ in range(len(self.dataloader))]
-            self.cache = {"i": 0} # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.)
+            self.cache_key_arguments = {
+                "i": 0
+            } # a dict of list, keyword arguments ("attention_masks", "position_ids", etc.)
+            # Note that the first elements in cache_positional_arguments is main input: hidden_states
             self.cache_positional_arguments = [] # a list of list, positional arguments ("rotary_pos_emb" in chatglm)
-            self.out = [torch.zeros(1) for _ in range(len(self.dataloader))]
             self.is_ready = True
         except:
             logger.warning("GPTQ Quantizer initialization failed!")
@@ -259,9 +260,14 @@ def obtain_first_n_samples(self, seed=0):
                 if batch[0].shape[-1] > self.pad_max_length:
                     i = random.randint(0, batch[0].shape[-1] - self.pad_max_length - 1)
                     j = i + self.pad_max_length
-                    batch_final = batch[0][:, i:j]
+                    batch_final = []
+                    for item in batch:
+                        if isinstance(item, torch.Tensor) and item.shape.__len__() == 2:
+                            batch_final.append(item[:, i:j])
+                        else:
+                            batch_final.append(item)
                 else:
-                    batch_final = batch[0]
+                    batch_final = batch[:]
             # dict
             elif isinstance(batch, dict):
                 try:
@@ -302,18 +308,22 @@ def obtain_first_n_samples_fulllength(self, seed=0):
             if len(self.dataloader) == self.nsamples:
                 logger.info(f"Successfully collect {self.nsamples} calibration samples.")
                 break
-            # list & tuple
+            # list & tuple, gpt-j-6b mlperf, etc.
             if isinstance(batch, list) or isinstance(batch, tuple):
                 if batch[0].shape[-1] == unified_length:
-                    batch_final = batch[0]
+                    batch_final = batch[:]
                 elif batch[0].shape[-1] > unified_length:
                     i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
                     j = i + unified_length
-                    batch_final = batch[0][:, i:j]
+                    batch_final = []
+                    for item in batch:
+                        if isinstance(item, torch.Tensor) and item.shape.__len__() == 2:
+                            batch_final.append(item[:, i:j])
+                        else:
+                            batch_final.append(item)
                 else:
                     # not match max length, not include in target dataset
                     continue
-                self.dataloader.append(batch_final)
             # dict
             elif isinstance(batch, dict):
                 try:
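
Editor's note before the remaining hunks: both `obtain_first_n_samples*` changes above switch from cropping only `batch[0]` to cropping every 2-D tensor in a list/tuple batch with the same random window. A small self-contained illustration of that cropping behaviour (a sketch, not the library code):

```python
import random
import torch

def crop_list_batch(batch, max_len=2048):
    """Every 2-D tensor in the batch (e.g. input_ids, attention_mask) is cut to
    the same random [i:j] window; all other items pass through unchanged."""
    if batch[0].shape[-1] <= max_len:
        return list(batch)
    i = random.randint(0, batch[0].shape[-1] - max_len - 1)
    j = i + max_len
    return [item[:, i:j] if isinstance(item, torch.Tensor) and item.dim() == 2 else item
            for item in batch]

ids = torch.randint(0, 50257, (1, 4096))
mask = torch.ones(1, 4096, dtype=torch.long)
cropped = crop_list_batch((ids, mask), max_len=2048)
print(cropped[0].shape, cropped[1].shape)  # torch.Size([1, 2048]) torch.Size([1, 2048])
```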
@@ -406,17 +416,16 @@ def pre_quantization(self):
         """Prepare input calibration data and other attributes which are critical for gptq execution."""
 
         # critical: hooker function which collects inputs
-        def forward(layer, hidden_states, *args, **kwargs):
+        def forward(layer, *args, **kwargs):
             # inputs[inputs_info['idx']] = input_ids # TODO solve the problem of batchsize!=1
-            self.inp[self.cache["i"]] = hidden_states
-            self.cache["i"] += 1
+            self.cache_key_arguments["i"] += 1
             for arg in kwargs:
                 # TODO: investigate include parameters
                 # each outputs can be different shape, hence also use list to store
                 if isinstance(kwargs[arg], torch.Tensor) or arg == "alibi":
-                    if self.cache.get(arg, None) is None:
-                        self.cache[arg] = []
-                    self.cache[arg].append(kwargs[arg])
+                    if self.cache_key_arguments.get(arg, None) is None:
+                        self.cache_key_arguments[arg] = []
+                    self.cache_key_arguments[arg].append(kwargs[arg])
                 continue
             # copy positional arguments, positional arguments are sensitive for their order, be cautious!
             # Most models in HF has avoid this, but some models still use positional arguments other than
@@ -454,8 +463,12 @@ def forward(layer, hidden_states, *args, **kwargs):
             pass
         # output inp data shape
         logger.info("All calibration data's shape =>")
-        for idx in range(len(self.dataloader)):
-            logger.info(self.inp[idx].shape)
+        # check all hidden_states shape
+        try:
+            for hidden_states in self.cache_positional_arguments[0]:
+                logger.info(hidden_states.shape)
+        except:
+            pass
         logger.info("Done.")
 
         # Step 4: restore original forward function, relocate layers back to cpu.
@@ -481,12 +494,20 @@ def gather_single_batch_from_list(self, data_list, idx):
             single_batch.append(data_item[idx])
         return single_batch
 
+    def update_blockwise_hidden_states(self, outs):
+        if "hidden_states" in self.cache_key_arguments:
+            self.cache_key_arguments["hidden_states"] = outs[:]
+        else:
+            self.cache_positional_arguments[0] = outs[:]
+
     @torch.no_grad()
     def execute_quantization(self, means=None, stds=None):
         """Run quantization."""
         # Step1: prepare quantization (calibration datasets)
+
         logger.info("Begin ====>")
         self.pre_quantization()
+
         # Step2: run gptq quantization in a transformer block-wise manner.
         gptq_config = {}
         tblock_length = len(self.gptq_related_blocks["transformers"])
@@ -533,13 +554,13 @@ def tmp(_, inp, out):
             handles = [] # register handles which add inputs and outputs to gptq object
             for layer_name in sub_layers:
                 handles.append(sub_layers[layer_name].register_forward_hook(add_batch(layer_name)))
-            idx = self.cache.pop("i")
+            idx = self.cache_key_arguments.pop("i")
+            # import pdb;pdb.set_trace()
             for j in range(len(self.dataloader)):
-                # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
-                cache_batch = self.gather_single_batch_from_dict(self.cache, j)
+                cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
                 cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
-                self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0]
-            self.cache["i"] = idx
+                out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
+            self.cache_key_arguments["i"] = idx
             for h in handles:
                 h.remove()
             # Step 2.4: everything is prepared, so start quantization!
@@ -565,18 +586,19 @@ def tmp(_, inp, out):
                     gptq_for_this_block[layer_name].free()
 
             # Step 2.5: replace output data with quantized weights
-            idx = self.cache.pop("i")
+            outs = []
+            idx = self.cache_key_arguments.pop("i")
             for j in range(len(self.dataloader)):
-                # self.inp[j] shape: [1, seq_len, hidden_size] (batchsize is 1 by default)
-                cache_batch = self.gather_single_batch_from_dict(self.cache, j)
+                cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
                 cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
-                self.out[j] = transformer_block(self.inp[j], *cache_positional_batch, **cache_batch)[0]
-            self.cache["i"] = idx
+                out = transformer_block(*cache_positional_batch, **cache_keyword_batch)[0]
+                outs.append(out)
+            self.cache_key_arguments["i"] = idx
             self.gptq_related_blocks["transformers"][block_idx] = transformer_block.cpu()
             del gptq_for_this_block
             torch.cuda.empty_cache()
             # iteratively replace the input with output, thus layerwise quantization can continue.
-            self.inp, self.out = self.out, self.inp
+            self.update_blockwise_hidden_states(outs)
             logger.info("------------------------------")
 
         logger.info("Quantization done")
