
Commit 60700f4

Sample code of SignRound (#1313)
1 parent 7537830 commit 60700f4

6 files changed: +1201 -0 lines changed

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
This is sample code for SignRound ([arXiv](https://arxiv.org/abs/2309.05516)), which currently supports only LLaMA, OPT, and BLOOM models. We will provide a unified API in Intel Neural Compressor that supports a broader range of models.

![overview](./overview.png)
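
The sketch below is only a rough, illustrative outline of the core idea behind SignRound — learning a small, bounded rounding offset for each weight with signed gradient descent on a block-wise reconstruction loss. All names here (`ste_round`, `quantize`, `signround_step`, the scale `s`, the offset `V`, and the learning rate) are hypothetical and are not the API of `signround.py`; see the paper and the script for the actual implementation.

```python
# Illustrative sketch only -- not the implementation in signround.py.
import torch

def ste_round(x):
    # Straight-through estimator: round in the forward pass, identity gradient in backward.
    return (x.round() - x).detach() + x

def quantize(W, s, V, num_bits=4):
    # Fake quantization with a learnable rounding offset V added before rounding.
    qmax = 2 ** (num_bits - 1) - 1
    return torch.clamp(ste_round(W / s + V), -qmax - 1, qmax) * s

def signround_step(W, s, V, x, float_out, lr=2.5e-3):
    # One signed-gradient update of V on the block output reconstruction error.
    # Example setup (hypothetical): V = torch.zeros_like(W, requires_grad=True)
    loss = ((x @ quantize(W, s, V).T - float_out) ** 2).mean()
    (grad,) = torch.autograd.grad(loss, V)
    with torch.no_grad():
        V -= lr * grad.sign()     # update with the sign of the gradient only
        V.clamp_(-0.5, 0.5)       # keep the offset within half a quantization step
    return loss.item()
```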
# Prerequisite

Python 3.9 or higher

pip install -r requirements.txt

# Run

```bash
CUDA_VISIBLE_DEVICES=0 python3 signround.py --model_name facebook/opt-125m --amp --num_bits 4 --group_size -1 --seqlen 512
```

To reduce GPU memory usage, you can enable the `low_gpu_mem_usage` option. You can also lower the training batch size (`train_bs`) and increase `gradient_accumulate_steps` accordingly, so that the effective batch size (train_bs × gradient_accumulate_steps) stays the same.

```bash
CUDA_VISIBLE_DEVICES=0 python3 signround.py --model_name facebook/opt-125m --amp --num_bits 4 --group_size -1 --seqlen 512 --low_gpu_mem_usage --train_bs 1 --gradient_accumulate_steps 8
```

## Known issue

To work around the lambada evaluation bug in older versions of lm-eval, we use the lm-eval shipped with Intel Extension for Transformers (ITREX). This discrepancy may lead to slight variations in the reported results.

To reproduce the results in our paper, please install ITREX:

```bash
pip install intel-extension-for-transformers
```

## Reference

If you find SignRound useful or relevant to your research, please kindly cite our paper:

```
@article{cheng2023optimize,
  title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs},
  author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao},
  journal={arXiv preprint arXiv:2309.05516},
  year={2023}
}
```
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import os.path
import torch
import torch.nn as nn


def eval_model(model, model_name, tokenizer, tasks=["lambada_openai", "hellaswag", "winogrande", "piqa"], eval_bs=32):
    # Prefer the lm-eval fork shipped with Intel Extension for Transformers (ITREX);
    # fall back to the official lm-eval if ITREX is not installed.
    try:
        from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate as lm_evaluate
        print("evaluation with itrex lm-eval", flush=True)

        if str(model.device) == "cpu":
            model = model.to(torch.bfloat16)
            dtype = 'bfloat16'
        else:
            model = model.half()
            dtype = 'float16'
        model.eval()
        results = lm_evaluate(model="hf-causal",
                              model_args=f'pretrained="{model_name}",tokenizer="{model_name}",dtype={dtype}',
                              user_model=model,
                              tasks=tasks,
                              device=str(model.device),
                              batch_size=eval_bs)

    except:
        print("evaluation with official lm-eval", flush=True)
        from lm_eval.evaluator import simple_evaluate
        import json
        import shutil

        # Save the model to disk so the official lm-eval can load it by path.
        output_dir = "./tmp_signround"
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        if output_dir is not None:
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
        if str(model.device) == "cpu":
            dtype = 'bfloat16'
        else:
            dtype = 'float16'
        results = simple_evaluate(model="hf-causal",
                                  model_args=f'pretrained="{output_dir}",tokenizer="{output_dir}",dtype={dtype}',
                                  tasks=tasks,
                                  device=str(model.device),
                                  batch_size=eval_bs,
                                  no_cache=True)
        dumped = json.dumps(results, indent=2)
        print(dumped)

        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)

    @torch.no_grad()
    def eval_same_with_gptq(model, testenc, dev):
        # Perplexity evaluation following the GPTQ protocol: split the test set
        # into non-overlapping windows of model.seqlen tokens and accumulate the
        # negative log-likelihood of each window.
        print('Evaluating ...', flush=True)
        # model.eval()
        model.to(dev)

        testenc = testenc.input_ids
        nsamples = testenc.numel() // model.seqlen

        use_cache = model.config.use_cache
        model.config.use_cache = False

        testenc = testenc.to(dev)
        nlls = []
        for i in range(nsamples):
            batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
            lm_logits = model(batch).logits
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = testenc[
                           :, (i * model.seqlen):((i + 1) * model.seqlen)
                           ][:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            neg_log_likelihood = loss.float() * model.seqlen
            nlls.append(neg_log_likelihood)
        # ppl = exp(total NLL / total number of predicted tokens)
        ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
        print(ppl.item())

        model.config.use_cache = use_cache
        return ppl.item()

    # Perplexity on the datasets used by GPTQ, loaded with the GPTQ data loaders.
    datasets = ['wikitext2', 'ptb-new', 'c4-new']

    from gptq_data_loader import get_loaders
    for dataset in datasets:
        dataloader, testloader = get_loaders(
            dataset, seed=0, model=model_name, seqlen=model.seqlen
        )
        print(dataset, flush=True)
        ppl = eval_same_with_gptq(model, testloader, str(model.device))
        results.update({dataset: ppl})
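
For reference, here is a minimal, hypothetical usage sketch for the evaluation helper above, assuming a Hugging Face causal LM loaded with `transformers`; the model name and the manual `model.seqlen` assignment are illustrative only, not part of the committed sample.

```python
# Hypothetical usage sketch; not part of the committed sample.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"  # illustrative model choice
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.seqlen = 2048  # the GPTQ-style perplexity loop reads model.seqlen

eval_model(model, model_name, tokenizer,
           tasks=["lambada_openai", "piqa"], eval_bs=32)
```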
