
Commit aa457a3

Add GPTQ-GPTJ examples (#1091)
Signed-off-by: YIYANGCAI <yiyang.cai@intel.com>
1 parent 684aef4 commit aa457a3

File tree

4 files changed: +533 −2 lines changed

examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_weight_only/README.md

Lines changed: 13 additions & 2 deletions
@@ -34,13 +34,23 @@ python run_clm.py \
 or
 ```bash
-sh run_quant.sh --topology=topology_name --input_model=model_name_or_path --weight_only_bits=8 --weight_only_group=-1 --weight_only_scheme=sym --weight_only_algorithm=RTN
+sh run_quant.sh --topology=gpt_j_wikitext_weight_only --input_model=EleutherAI/gpt-j-6B --weight_only_bits=8 --weight_only_group=-1 --weight_only_scheme=sym --weight_only_algorithm=RTN
 ```

 > NOTE
 >
 > `weight_only_bits`, `weight_only_group`, `weight_only_scheme`, and `weight_only_algorithm` can be modified by the user. For details, please refer to [README](../../../../../../../docs/source/quantization_weight_only.md).
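
To build intuition for these knobs, below is a minimal sketch of symmetric round-to-nearest (RTN) weight quantization in plain PyTorch. It is illustrative only, not the code path `run_quant.sh` invokes, and `rtn_quantize` is a name made up for this sketch; `group_size=-1` mirrors `--weight_only_group=-1` (one group per output channel).

```python
import torch

def rtn_quantize(weight, bits=8, group_size=-1):
    """Symmetric RTN fake-quantization of a 2-D weight [out, in]."""
    orig_shape = weight.shape
    if group_size > 0:
        # Split each row into groups of `group_size` elements.
        weight = weight.reshape(-1, group_size)
    qmax = 2 ** (bits - 1) - 1  # 127 for 8-bit symmetric
    scale = weight.abs().max(dim=-1, keepdim=True).values / qmax
    scale = scale.clamp(min=1e-9)  # avoid division by zero for all-zero rows
    q = torch.round(weight / scale).clamp(-qmax - 1, qmax)  # round to nearest
    return (q * scale).reshape(orig_shape)  # dequantized weights

w = torch.randn(4, 16)
print((w - rtn_quantize(w, bits=8)).abs().max().item())  # small round-off error
```
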
+### Run MLPerf on GPT-J-6B
+Use the following links to get the [**CNN Daily Mail** dataset](https://github.com/intel-innersource/frameworks.ai.benchmarking.mlperf.submission.inference-submission-v3-1/tree/master/closed/Intel/code/gpt-j/pytorch-cpu#download-and-prepare-dataset) and the [gpt-j-6B MLPerf model](https://github.com/mlcommons/inference/tree/master/language/gpt-j#download-gpt-j-model).
+
+Then run the following command to do quantization:
+```shell
+sh run_gptj_mlperf_int4.sh
+```
 ## 2. Benchmark
 ```bash
 # int8
@@ -59,7 +69,7 @@ sh run_benchmark.sh --topology=topology_name --mode=performance --input_model=mo
 </thead>
 <tbody align="center">
 <tr>
-<td>gpt_j_wikitext</td>
+<td>gpt_j_wikitext_weight_only</td>
 <td><a href="https://huggingface.co/EleutherAI/gpt-j-6B">EleutherAI/gpt-j-6B</a></td>
 <td><a href="https://huggingface.co/datasets/wikitext">wikitext</a></td>
 </tr>
@@ -102,6 +112,7 @@ from neural_compressor.utils.pytorch import load
 quantized_model = load(tuned_checkpoint, model)
 ```
 --------
+
 For more details, please refer to the [sample code](./run_clm.py).

 # (May Remove Later) Run GPTQ algorithm
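
For context, a self-contained version of that load step might look like the sketch below. The checkpoint directory is a placeholder; the only call taken from the diff itself is `load(tuned_checkpoint, model)` from `neural_compressor.utils.pytorch`.

```python
from transformers import AutoModelForCausalLM
from neural_compressor.utils.pytorch import load

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")
tuned_checkpoint = "./saved_results"  # placeholder: directory written by run_quant.sh
quantized_model = load(tuned_checkpoint, model)
quantized_model.eval()
```
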
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
```python
import sys
import argparse
import os
import time
import json
import fnmatch

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence


import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset, load_from_disk
from torch.nn.functional import pad
from torch.utils.data import DataLoader
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import random
random.seed(9973)

# Bucketize sequence lengths
MaxLens = range(0, 64, 1919)  # NOTE: unused below
Buckets = dict()
cutoff_step = 64
min_cutoff = 64
min_len = 1
for cutoff in range(min_cutoff, 1921, cutoff_step):  # All input sequences
    Buckets[cutoff] = list(range(min_len, cutoff, 1))
    min_len = cutoff

#Buckets[1920] = list(range(min_len, 1921, 1))

input_buckets = dict()
for cutoff, seq_lens in Buckets.items():
    for seq_len in seq_lens:
        input_buckets[seq_len] = cutoff

#print("Buckets: {}".format(input_buckets))

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


class CNNDAILYMAIL(object):
    def __init__(self, model_path, data_path, device="cpu", is_calib=False, num_samples=20, max_len=1920):
        self.model_path = model_path
        self.data_path = data_path
        self.device = device
        self.num_samples = num_samples
        self.is_calib = is_calib

        self.padding = "max_length" if self.is_calib else False
        self.max_len = 2048 if self.is_calib else max_len

        self.calib_collator = self.collate_batch
        self.pad_max = max_len
        self.load_tokenizer()
        self.load_dataset()

    def load_dataset(self):
        """Loads the dataset."""
        with open(self.data_path, "r") as fid:
            list_data_dict = json.load(fid)
        self.list_data_dict = copy.deepcopy(list_data_dict)

        if self.num_samples is not None:
            self.num_samples = min(self.num_samples, len(list_data_dict))

        if self.is_calib:
            list_data_dict = list_data_dict[:self.num_samples]
        else:
            list_data_dict = random.choices(list_data_dict, k=self.num_samples)

        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [prompt_input.format_map(example) for example in list_data_dict]
        targets = [f"{example['output']}" for example in list_data_dict]

        self.input_ids = []
        self.input_lens = []
        for i in range(len(sources)):
            tok_input = self.tokenize_function(sources[i])
            self.input_ids.append(tok_input.input_ids)

        self.sources = sources
        self.targets = targets

    def load_tokenizer(self):
        """Returns the tokenizer."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

    @torch.no_grad()
    def tokenize_function(self, text):
        example = self.tokenizer(text, truncation=True, max_length=self.max_len,
                                 return_tensors="pt", padding=self.padding)
        return example

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        input_ids = self.input_ids[i]
        input_len = input_ids.shape[-1]
        #pad_size = input_buckets[input_len] - input_len
        #input_ids = F.pad(input_ids, pad=(0, pad_size))
        return (input_ids, input_len)

    @torch.no_grad()
    def collate_batch(self, batch):
        input_ids_padded = []

        for input_ids, input_lens in batch:  # input_ids are returned by this dataset (see __getitem__)
            # During calibration the tokenizer already pads every sample to
            # max_length, so no extra padding is applied here.
            pad_len = self.pad_max - input_ids.shape[0]
            #input_ids = F.pad(input_ids, pad=(0, pad_len), value=self.tokenizer.pad_token_id)
            input_ids_padded.append(input_ids)

        input_ids_padded = torch.vstack(input_ids_padded)
        return (input_ids_padded, input_ids_padded)

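    # Builds one warmup sample per bucket cutoff in range(128, 1920, 64),
    # each right-padded to its bucket length, so a runtime can see every
    # input shape once before timed inference.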
    def get_warmup_samples(self):
        cutoff_set = set(range(128, 1920, 64))
        warmup_samples = []
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [prompt_input.format_map(example) for example in self.list_data_dict]
        for source in sources:
            tok_input = self.tokenize_function(source)
            input_ids = tok_input.input_ids
            input_len = input_ids.shape[-1]
            bucket = input_buckets[input_len]
            if bucket in cutoff_set:
                pad_size = bucket - input_len
                input_ids = F.pad(input_ids, pad=(0, pad_size), value=0)
                warmup_samples.append(input_ids)
                cutoff_set.remove(bucket)
            if len(cutoff_set) == 0:
                break

        return warmup_samples
```
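
A usage sketch for the class above, assuming a prepared CNN/DailyMail JSON file (a list of dicts with `instruction`, `input`, and `output` keys, as `PROMPT_DICT` and `load_dataset` imply); the file path here is a placeholder, not one shipped with the example.

```python
from torch.utils.data import DataLoader

# Placeholder path to the prepared CNN/DailyMail JSON.
calib_set = CNNDAILYMAIL(
    model_path="EleutherAI/gpt-j-6B",
    data_path="cnn_dailymail_calib.json",
    is_calib=True,
    num_samples=20,
)
calib_loader = DataLoader(
    calib_set,
    batch_size=1,
    collate_fn=calib_set.calib_collator,
)
for input_ids, _ in calib_loader:
    print(input_ids.shape)  # calibration samples are padded to 2048 tokens
    break
```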
