Merge pull request #5 from cylinbao/opt
Adding OPT support for simulated quantization.
happierpig authored Jan 28, 2024
2 parents 3ced20c + cc02297 commit e89479c
Showing 16 changed files with 991 additions and 176 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -110,6 +110,15 @@ To reproduce end-to-end throughput and latency evaluation, please check [e2e/REA
### Perplexity
* Atom achieves strong perplexity results on the WikiText2, PTB, and C4 datasets across the Llama model family.
![perplexity](figures/atom_ppl.png)
* Below is Atom's WikiText2 perplexity on OPT models, compared with SmoothQuant and OmniQuant. Note that for OPT-66B, Atom's result is obtained without the GPTQ optimization.

|#Bit|Method|OPT-6.7B|OPT-13B|OPT-30B|OPT-66B|
|-|-|-|-|-|-|
|FP16|-|10.86|10.13|9.56|9.34|
|W4A4|SmoothQ|1.80E+04|7.40E+03|1.20E+04|2.20E+05|
|W4A4|OmniQ|12.24|11.65|10.6|10.29|
|W4A4|Atom|11.23|10.44|9.70|9.57|

### End-to-end throughput and latency
* Atom achieves up to 7.7x higher throughput than `FP16` with similar latency under a fixed GPU memory budget in serving scenarios.
![e2e](figures/atom_e2e_eval.png)
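For reference, the perplexity numbers above are the exponential of the average per-token negative log-likelihood over non-overlapping `seqlen`-token segments, which is exactly the reduction performed by `llama_eval` and the new `opt_eval` in `model/eval.py` below. A minimal sketch of that final step, assuming `nlls` holds one per-segment summed NLL as in those functions:

```python
import torch

def perplexity(nlls, nsamples, seqlen):
    # nlls[i] is segment i's mean token loss multiplied by seqlen,
    # i.e. a per-segment sum of token negative log-likelihoods.
    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)).item()
```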
28 changes: 16 additions & 12 deletions model/LMClass.py
@@ -24,18 +24,22 @@ def __init__(self, args, model=None):
# We use default dtype float16
config.torch_dtype = torch.float16

# Fix for transformer 4.28.0.dev0 compatibility
# See: https://github.com/Vahe1994/SpQR/blob/main/datautils.py#L164
from transformers import LlamaTokenizer
self.tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False)
if self.tokenizer.bos_token_id != 1 or self.tokenizer.eos_token_id != 2:
try:
self.tokenizer.bos_token_id = 1
self.tokenizer.eos_token_id = 2
print(f"bos/eos tokens updated: {self.tokenizer.bos_token_id=}, {self.tokenizer.eos_token_id=}")
except AttributeError:
pass
print(f"bos/eos tokens unchanged: {self.tokenizer.bos_token_id=}, {self.tokenizer.eos_token_id=}")
if "llama" in args.model.lower():
# Fix for transformer 4.28.0.dev0 compatibility
# See: https://github.com/Vahe1994/SpQR/blob/main/datautils.py#L164
from transformers import LlamaTokenizer
self.tokenizer = LlamaTokenizer.from_pretrained(args.model, use_fast=False)
if self.tokenizer.bos_token_id != 1 or self.tokenizer.eos_token_id != 2:
try:
self.tokenizer.bos_token_id = 1
self.tokenizer.eos_token_id = 2
print(f"bos/eos tokens updated: {self.tokenizer.bos_token_id=}, {self.tokenizer.eos_token_id=}")
except AttributeError:
pass
print(f"bos/eos tokens unchanged: {self.tokenizer.bos_token_id=}, {self.tokenizer.eos_token_id=}")
else:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False, legacy=False)

if model != None:
self.model = model
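A quick sanity check of why the OPT branch above can skip the bos/eos fixup (a hedged sketch; the printed ids reflect what public `facebook/opt-*` checkpoints typically ship, not something this PR asserts):

```python
from transformers import AutoTokenizer

# OPT checkpoints use a GPT-2-style BPE tokenizer whose special tokens are
# already consistent, so no manual id override like the Llama path is needed.
tok = AutoTokenizer.from_pretrained("facebook/opt-6.7b", use_fast=False, legacy=False)
print(tok.bos_token_id, tok.eos_token_id)  # typically 2, 2 ("</s>") for OPT
```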
32 changes: 18 additions & 14 deletions model/datautils.py
@@ -148,20 +148,24 @@ def __init__(self, input_ids):
def get_loaders(
name, nsamples=128, seed=0, seqlen=2048, model=''
):
assert "llama" in model.lower(), "Only llama models are supported."

from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False)
# Fix for transformer 4.28.0.dev0 compatibility
# See: https://github.com/Vahe1994/SpQR/blob/main/datautils.py#L164
if tokenizer.bos_token_id != 1 or tokenizer.eos_token_id != 2:
try:
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
print(f"bos/eos tokens updated: {tokenizer.bos_token_id=}, {tokenizer.eos_token_id=}")
except AttributeError:
pass
print(f"bos/eos tokens unchanged: {tokenizer.bos_token_id=}, {tokenizer.eos_token_id=}")
# assert "llama" in model.lower(), "Only llama models are supported."

if "llama" in model.lower():
from transformers import LlamaTokenizer
tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False)
# Fix for transformer 4.28.0.dev0 compatibility
# See: https://github.com/Vahe1994/SpQR/blob/main/datautils.py#L164
if tokenizer.bos_token_id != 1 or tokenizer.eos_token_id != 2:
try:
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
print(f"bos/eos tokens updated: {tokenizer.bos_token_id=}, {tokenizer.eos_token_id=}")
except AttributeError:
pass
print(f"bos/eos tokens unchanged: {tokenizer.bos_token_id=}, {tokenizer.eos_token_id=}")
else:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, legacy=False)

if 'wikitext2' in name:
return get_wikitext2(nsamples, seed, seqlen, model, tokenizer)
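The same Llama/OPT tokenizer branch now appears in both `LMClass.py` and `datautils.py`. If a third call site ever shows up, one option (a sketch only, not part of this commit) is a small shared helper:

```python
from transformers import AutoTokenizer, LlamaTokenizer

def get_tokenizer(model_path):
    """Hypothetical helper mirroring the branches added in this PR."""
    if "llama" in model_path.lower():
        # Fix for transformers 4.28.0.dev0 compatibility (see SpQR's datautils).
        tok = LlamaTokenizer.from_pretrained(model_path, use_fast=False)
        if tok.bos_token_id != 1 or tok.eos_token_id != 2:
            try:
                tok.bos_token_id = 1
                tok.eos_token_id = 2
            except AttributeError:
                pass
        return tok
    # OPT (and other) checkpoints: defer to AutoTokenizer.
    return AutoTokenizer.from_pretrained(model_path, use_fast=False, legacy=False)
```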
93 changes: 93 additions & 0 deletions model/eval.py
@@ -81,4 +81,97 @@ def forward(self, inp, **kwargs):
nlls.append(neg_log_likelihood)
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))

return ppl.item()

@torch.no_grad()
def opt_eval(model, testenc, dev):
print('Evaluating ...')

testenc = testenc.input_ids
nsamples = testenc.numel() // model.seqlen

use_cache = model.config.use_cache
model.config.use_cache = False
layers = model.model.decoder.layers

model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
layers[0] = layers[0].to(dev)

dtype = next(iter(model.parameters())).dtype
inps = torch.zeros(
(nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
)
cache = {'i': 0, 'attention_mask': None}

class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps[cache['i']] = inp
cache['i'] += 1
cache['attention_mask'] = kwargs['attention_mask']
raise ValueError
layers[0] = Catcher(layers[0])
for i in range(nsamples):
batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
try:
model(batch)
except ValueError:
pass
layers[0] = layers[0].module

layers[0] = layers[0].cpu()
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
model.model.decoder.project_out = model.model.decoder.project_out.cpu()
if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
model.model.decoder.project_in = model.model.decoder.project_in.cpu()
torch.cuda.empty_cache()

outs = torch.zeros_like(inps)
attention_mask = cache['attention_mask']

for i in tqdm(range(len(layers))):
layer = layers[i].to(dev)
for j in range(nsamples):
outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
layers[i] = layer.cpu()
del layer
torch.cuda.empty_cache()
inps, outs = outs, inps

if model.model.decoder.final_layer_norm is not None:
model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
if model.model.decoder.project_out is not None:
model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
model.lm_head = model.lm_head.to(dev)

testenc = testenc.to(dev)
nlls = []
for i in range(nsamples):
hidden_states = inps[i].unsqueeze(0)
if model.model.decoder.final_layer_norm is not None:
hidden_states = model.model.decoder.final_layer_norm(hidden_states)
if model.model.decoder.project_out is not None:
hidden_states = model.model.decoder.project_out(hidden_states)
lm_logits = model.lm_head(hidden_states)
shift_logits = lm_logits[:, :-1, :].contiguous()
shift_labels = testenc[
:, (i * model.seqlen):((i + 1) * model.seqlen)
][:, 1:]
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
neg_log_likelihood = loss.float() * model.seqlen
nlls.append(neg_log_likelihood)
ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
print(ppl.item())

model.config.use_cache = use_cache
return ppl.item()
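A minimal usage sketch for the new `opt_eval`, assuming the `get_opt` loader added in `model/main.py` below and the WikiText2 loader from `datautils.py` (the model path and device are placeholders, and the import paths assume the scripts are run from the `model/` directory):

```python
import torch
from datautils import get_loaders
from eval import opt_eval
from main import get_opt  # helper added to model/main.py in this PR

dev = torch.device("cuda:0")
model = get_opt("facebook/opt-6.7b")      # fp16 weights, seqlen from the OPT config
_, testenc = get_loaders("wikitext2", seed=0,
                         model="facebook/opt-6.7b", seqlen=model.seqlen)
ppl = opt_eval(model, testenc, dev)       # layer-by-layer evaluation on dev
print(f"WikiText2 perplexity: {ppl:.3f}")
```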
108 changes: 67 additions & 41 deletions model/llama.py → model/main.py
@@ -4,11 +4,13 @@
from eval import *
from collections import defaultdict
from pprint import pprint
from modelutils import quantize_model, quantize_model_gptq, add_act_quant_wrapper, reorder_model
from modelutils_llama import quantize_model_llama, reorder_model_llama, quantize_model_gptq_llama, add_act_quant_wrapper_llama
from modelutils_opt import quantize_model_opt, reorder_model_opt, quantize_model_gptq_opt, add_act_quant_wrapper_opt
from parallel_utils import map_layers_to_multi_gpus
from LMClass import LMClass
import lm_eval
from eval import pattern_match
from lm_eval import tasks as lm_tasks
from lm_eval import evaluator as lm_evaluator


def get_llama(model):
@@ -23,6 +25,19 @@ def skip(*args, **kwargs):
model.seqlen = 2048
return model

def get_opt(model):
import torch
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
from transformers import OPTForCausalLM
model = OPTForCausalLM.from_pretrained(model, torch_dtype=torch.float16)
model.seqlen = model.config.max_position_embeddings
return model


if __name__ == '__main__':
import argparse
from datautils import *
@@ -159,14 +174,26 @@ def skip(*args, **kwargs):
args = parser.parse_args()

model_name = args.model.lower().split('/')[-1]
if "llama" not in model_name:
model_name = args.model.split('/')[-2]
assert model_name != None, "Please check the model path."

model = get_llama(args.model)
if "llama" in args.model.lower():
model = get_llama(args.model)
get_act_stats_func = get_act_stats_llama
reorder_model_func = reorder_model_llama
add_act_quant_wrapper_func = add_act_quant_wrapper_llama
quantize_model_gptq_func = quantize_model_gptq_llama
quantize_model_func = quantize_model_llama
eval_func = llama_eval
elif "opt" in args.model.lower():
model = get_opt(args.model)
get_act_stats_func = get_act_stats_opt
reorder_model_func = reorder_model_opt
add_act_quant_wrapper_func = add_act_quant_wrapper_opt
quantize_model_gptq_func = quantize_model_gptq_opt
quantize_model_func = quantize_model_opt
eval_func = opt_eval
model.eval()

from pathlib import Path
import pathlib
import os

if args.reorder:
@@ -175,7 +202,7 @@ def skip(*args, **kwargs):
args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
)
print("Getting activation stats...")
act_scales = get_act_stats(
act_scales = get_act_stats_func(
model, dataloader, DEV, metric=args.act_sort_metric
)

@@ -193,78 +220,77 @@ def skip(*args, **kwargs):
reorder_index = torch.load(index_filename)

print("Reordering model...")
model = reorder_model(
model = reorder_model_func(
model, device=DEV, args=args, reorder_index=reorder_index
)

if args.static == True:
assert args.abits < 16, "Static quantization should quantize A."
if args.cache_index == False:
dataloader, testloader = get_loaders(
args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
)
print("Getting scales...")
scales = get_act_scales(model, dataloader, DEV, args)
torch.save(scales, f'../saved/{model_str}_scales_{args.dataset}_{args.act_group_size}.pt')
else:
print("Getting cached scales...")
scales = torch.load(f'../saved/{model_str}_scales_{args.dataset}_{args.act_group_size}.pt')
else:
scales = defaultdict(lambda: None)

if args.abits < 16:
print("Inserting activations quantizers ...")
model = add_act_quant_wrapper(model, device=DEV, args=args, scales=scales)
scales = defaultdict(lambda: None)
model = add_act_quant_wrapper_func(model, device=DEV, args=args, scales=scales)

if args.wbits < 16:
print("Quantizing...")
if args.use_gptq:
dataloader, testloader = get_loaders(
args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
)
model = quantize_model_gptq(model, device=DEV, args=args, dataloader=dataloader)
model = quantize_model_gptq_func(model, device=DEV, args=args, dataloader=dataloader)
else:
model = quantize_model(model, device=DEV, args=args)
model = quantize_model_func(model, device=DEV, args=args)


if args.eval_ppl:
datasets = ['wikitext2', 'ptb', 'c4', 'ptb-new', 'c4-new']
datasets = ['wikitext2', 'ptb', 'c4']

for dataset in datasets:
dataloader, testloader = get_loaders(
dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
)
print(f"Evaluating {dataset} ...")
ppl = llama_eval(model, testloader, DEV)
ppl = eval_func(model, testloader, DEV)

print(f"targetResult,{dataset},{ppl:.3f}")

# eval zero shot accuracy on commonsense datasets
if args.eval_common_sense:

lm = LMClass(args, model)
lm.seqlen = 2048
lm.model.eval()
for param in lm.model.parameters():
param.requires_grad = False

if args.multigpu:
map_layers_to_multi_gpus(lm.model.model.layers)
input_device = lm.model.model.layers[0].device
output_device = lm.model.model.layers[-1].device
assert input_device == output_device
lm._device = input_device
lm.model.model.embed_tokens.to(input_device)
lm.model.model.norm.to(output_device)
lm.model.lm_head.to(output_device)
if "llama" in args.model.lower():
map_layers_to_multi_gpus(lm.model.model.layers)
input_device = lm.model.model.layers[0].device
output_device = lm.model.model.layers[-1].device
assert input_device == output_device
lm._device = input_device
lm.model.model.embed_tokens.to(input_device)
lm.model.model.norm.to(output_device)
lm.model.lm_head.to(output_device)
elif "opt" in args.model.lower():
map_layers_to_multi_gpus(lm.model.model.decoder.layers)
input_device = lm.model.model.decoder.layers[0].device
output_device = lm.model.model.decoder.layers[-1].device
assert input_device == output_device
lm._device = input_device
lm.model.model.decoder.embed_tokens.to(input_device)
lm.model.model.decoder.embed_positions.to(input_device)
lm.model.model.decoder.final_layer_norm.to(input_device)
lm.model.lm_head.to(output_device)
else:
lm._device = DEV
lm.model = lm.model.to(lm.device)

results = {}
tasks_str = "piqa,arc_easy,arc_challenge,boolq,hellaswag,winogrande"
task_names = pattern_match(tasks_str.split(","), lm_eval.tasks.ALL_TASKS)
task_names = pattern_match(tasks_str.split(","), lm_tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")

task_dict = lm_eval.tasks.get_task_dict(task_names)
t_results = lm_eval.evaluator.evaluate(
task_dict = lm_tasks.get_task_dict(task_names)
t_results = lm_evaluator.evaluate(
lm,
task_dict,
num_fewshot=args.lm_eval_num_fewshot,
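The new `llama`/`opt` branching in `main.py` picks an entire family of functions at once (loader, activation stats, reordering, quantization, evaluation). A table-driven sketch of the same selection is shown below; it is not code from this commit, and it assumes it lives inside `model/main.py`, where every listed name is already defined or imported:

```python
# Hypothetical dispatch table mirroring the if/elif chain in model/main.py.
# get_llama/get_opt are defined in main.py itself; the other names are the
# ones main.py already imports for each model family.
MODEL_FAMILIES = {
    "llama": dict(get_model=get_llama,
                  get_act_stats=get_act_stats_llama,
                  reorder=reorder_model_llama,
                  add_act_quant=add_act_quant_wrapper_llama,
                  quantize_gptq=quantize_model_gptq_llama,
                  quantize=quantize_model_llama,
                  evaluate=llama_eval),
    "opt":   dict(get_model=get_opt,
                  get_act_stats=get_act_stats_opt,
                  reorder=reorder_model_opt,
                  add_act_quant=add_act_quant_wrapper_opt,
                  quantize_gptq=quantize_model_gptq_opt,
                  quantize=quantize_model_opt,
                  evaluate=opt_eval),
}

family = next((k for k in MODEL_FAMILIES if k in args.model.lower()), None)
assert family is not None, "Only Llama and OPT models are supported."
funcs = MODEL_FAMILIES[family]
model = funcs["get_model"](args.model)
```

With this shape, adding a third model family would amount to one new table entry rather than another if/elif branch.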