-
Notifications
You must be signed in to change notification settings - Fork 191
/
benchmark.py
123 lines (104 loc) · 3.55 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Usage:
# Please first install awq/kernels
# then directly run CUDA_VISIBLE_DEVICES=0 python benchmark.py
import argparse
import torch
import time
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
import tinychat.utils.constants
from tinychat.utils.load_quant import load_awq_model
from awq.quantize.quantizer import real_quantize_model_weight
from tinychat.utils.tune import tune_all_wqlinears, device_warmup
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
def skip(*args, **kwargs):
pass
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type", type=str, default="LLaMa", help="type of the model"
)
parser.add_argument(
"--model_path",
type=str,
default="/data/llm/checkpoints/vicuna-hf/vicuna-7b",
help="path to the model",
)
parser.add_argument("--q_group_size", type=int, default=128)
parser.add_argument(
"--verbose",
default=False,
action="store_true",
help="Wheter to print more information.",
)
parser.add_argument(
"--max_seq_len",
type=int,
default=2048,
help="maximum sequence length for kv cache",
)
parser.add_argument(
"--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
)
args = parser.parse_args()
tinychat.utils.constants.max_batch_size = args.max_batch_size
tinychat.utils.constants.max_seq_len = args.max_seq_len
from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM
modeling_utils._init_weights = False
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.kaiming_normal_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
device = "cuda:0"
# exLLaMA benchmarking parameters.
context_length = 4
gen_length = 200
input_ids = [1 for _ in range(context_length)]
model_type_dict = {
"llama": LlamaForCausalLM,
"falcon": FalconForCausalLM,
"mpt": MPTForCausalLM,
}
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
assert args.model_type.lower() in [
"llama",
"falcon",
"mpt",
], "We only support llama & falcon & mpt now"
model = model_type_dict[args.model_type.lower()](config).half()
real_quantize_model_weight(
model,
w_bit=4,
q_config=dict(q_group_size=args.q_group_size, zero_point=True),
init_only=True,
)
model = model.to(device)
# tune_all_wqlinears(model)
make_quant_attn(model, device)
make_quant_norm(model)
make_fused_mlp(model)
device_warmup(device)
print("huggingface ckpt loaded")
print(model)
time_lis = []
start_pos = 0
print("Benchmarking...")
with torch.inference_mode():
for i in range(gen_length):
torch.cuda.synchronize()
t_st = time.time()
if i == 0:
inputs = torch.as_tensor([input_ids], device=device)
else:
inputs = torch.as_tensor([[token]], device=device)
out = model(inputs, start_pos=start_pos)
start_pos += out.shape[1]
torch.cuda.synchronize()
t_ed = time.time()
time_lis.append(t_ed - t_st)
token = out[:, -1].max(1)[1].unsqueeze(1)
if args.verbose:
print(i, np.median(time_lis))
print(f"Speed: {1 / np.median(time_lis)} tokens per second.")
if __name__ == "__main__":
main()