-
Notifications
You must be signed in to change notification settings - Fork 191
/
benchmark_context.py
345 lines (324 loc) · 13.6 KB
/
benchmark_context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
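# Usage:
# Please first install awq/kernels,
# then run, e.g.: CUDA_VISIBLE_DEVICES=0 python benchmark_context.py --context_length 512 1024
# (add --chunk_prefilling --question_length N to time prefilling N new tokens on top of the history)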
import argparse
import torch
import time
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
import tinychat.utils.constants
from tinychat.utils.load_quant import load_awq_model
from awq.quantize.quantizer import real_quantize_model_weight
from tinychat.utils.tune import (
tune_all_wqlinears,
device_warmup,
tune_llava_patch_embedding,
)
from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
def skip(*args, **kwargs):
    """Accept any positional/keyword arguments and do nothing.

    ``main`` assigns this over the ``torch.nn.init`` initialisers so that
    constructing the model does not run weight initialisation.
    """
    return None
def _timed_forward(model, input_ids, device, start_pos, chunk_prefilling, **model_kwargs):
    """Return the wall-clock seconds of one forward pass over ``input_ids``.

    The host->device copy of the token ids is inside the timed region, and
    CUDA is synchronised before and after so all queued kernels are counted.
    Extra ``model_kwargs`` (e.g. ``images`` for VILA) are forwarded verbatim.
    """
    torch.cuda.synchronize()
    t_st = time.perf_counter()
    inputs = torch.as_tensor([input_ids], device=device)
    model(inputs, start_pos=start_pos, chunk_prefilling=chunk_prefilling, **model_kwargs)
    torch.cuda.synchronize()
    return time.perf_counter() - t_st


def _chunk_round(model, history_ids, question_ids, device, chunk_prefilling, history_kwargs):
    """Prefill ``history_ids`` untimed, then time the prefill of ``question_ids``.

    The question chunk starts at ``start_pos = len(history_ids)`` so it is
    processed on top of the history.  An empty history is skipped (the model
    is not called with a zero-length input).
    """
    start_pos = 0
    if history_ids:
        model(
            torch.as_tensor([history_ids], device=device),
            start_pos=0,
            chunk_prefilling=chunk_prefilling,
            **history_kwargs,
        )
        start_pos = len(history_ids)
    return _timed_forward(model, question_ids, device, start_pos, chunk_prefilling)


def _mean_latency(run_once, num_runs, verbose):
    """Call ``run_once()`` ``num_runs`` times under inference mode; return the mean latency.

    ``run_once`` must return the elapsed seconds of one run.  With ``verbose``
    every individual run is printed as ``<run index> <seconds>``.
    """
    latencies = []
    with torch.inference_mode():
        for run_idx in range(num_runs):
            elapsed = run_once()
            latencies.append(elapsed)
            if verbose:
                print(run_idx, elapsed)
    return np.mean(latencies)


def main():
    """Benchmark time-to-first-token (TTFT) of a 4-bit AWQ TinyChat model.

    Only the model *config* is read from ``--model_path``. The model is built
    with ``real_quantize_model_weight(..., init_only=True)`` and no checkpoint
    weights are loaded, so only speed is measured.

    Modes:
      * default: for every ``--context_length`` L, time the prefill of an
        L-token prompt (for VILA, ~3/4 of those tokens come from images).
      * ``--chunk_prefilling``: for every (``--context_length`` H,
        ``--question_length`` Q) pair, prefill H history tokens untimed, then
        time the prefill of the Q new tokens on top of that history.

    Each configuration is timed ``--num_runs`` times and the mean is printed.
    Invalid argument combinations exit through ``parser.error``.
    """
    import functools  # local: only used to bind per-configuration arguments

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type", type=str, default="LLaMa", help="type of the model"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/data/llm/checkpoints/vicuna-hf/vicuna-7b",
        help="path to the model",
    )
    parser.add_argument("--q_group_size", type=int, default=128)
    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="Whether to print more information.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=2048,
        help="maximum sequence length for kv cache",
    )
    parser.add_argument(
        "--max_batch_size", type=int, default=1, help="maximum batch size for kv cache"
    )
    parser.add_argument(
        "--flash_attn",
        action="store_true",
        help="whether to use flash attention",
    )
    parser.add_argument(
        "--chunk_prefilling",
        action="store_true",
        help="If used, in context stage, the history tokens will not be recalculated, greatly speeding up the calculation",
    )
    # type=int (not type=list): each value is parsed as one integer instead of
    # being split into a list of characters that must be re-joined later.
    parser.add_argument(
        "--context_length",
        type=int,
        nargs="+",
        required=True,
        help="The length of input. And if chunk_prefilling used, this serves as the length of tokens from history rounds.",
    )
    parser.add_argument(
        "--question_length",
        type=int,
        nargs="+",
        help="The length of new input. Only useful and necessary when benchmarking chunk_prefilling method",
    )
    parser.add_argument(
        "--num_runs",
        type=int,
        default=10,
        help="number of timed runs averaged for each configuration",
    )
    args = parser.parse_args()

    # Validate with parser.error rather than assert (asserts vanish under -O).
    model_type = args.model_type.lower()
    if model_type not in ("llama", "falcon", "mpt", "vila"):
        parser.error("We only support llama & falcon & mpt & vila now")
    if args.num_runs < 1:
        parser.error("--num_runs must be at least 1")
    if args.chunk_prefilling:
        if args.question_length is None:
            parser.error("--question_length is required when --chunk_prefilling is used")
        # We support fixing a certain kind of length: a single value on either
        # side is broadcast to the length of the other list.
        n_ctx, n_q = len(args.context_length), len(args.question_length)
        if n_ctx == 1 and n_q > 1:
            args.context_length = args.context_length * n_q
        elif n_q == 1 and n_ctx > 1:
            args.question_length = args.question_length * n_ctx
        elif n_ctx != n_q:
            parser.error(
                "The number of items in the question_length and context_length is expected to be either one or equal!"
            )
        if any(n < 0 for n in args.context_length) or any(
            n <= 0 for n in args.question_length
        ):
            parser.error("--context_length must be >= 0 and --question_length must be > 0")
    elif any(n <= 0 for n in args.context_length):
        parser.error("--context_length must be > 0")

    # These globals are set before the tinychat model modules are imported;
    # presumably they are read at import/construction time to size the KV
    # cache, so keep the assignments above the imports.
    tinychat.utils.constants.max_batch_size = args.max_batch_size
    tinychat.utils.constants.max_seq_len = args.max_seq_len
    from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM
    from tinychat.models.vila_llama import VilaLlamaForCausalLM

    # No checkpoint is loaded, so skip weight initialisation entirely.
    modeling_utils._init_weights = False
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    device = "cuda:0"
    model_type_dict = {
        "llama": LlamaForCausalLM,
        "falcon": FalconForCausalLM,
        "mpt": MPTForCausalLM,
    }
    config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    q_config = dict(q_group_size=args.q_group_size, zero_point=True)
    is_vila = model_type == "vila"

    image_num = None
    image_tensor = None
    if is_vila:
        model = VilaLlamaForCausalLM(config).half()
        real_quantize_model_weight(model.llm, w_bit=4, q_config=q_config, init_only=True)
        make_quant_attn(model.llm, device, args.flash_attn)
        make_quant_norm(model.llm)
        make_fused_mlp(model.llm)
        model = model.to(device)
        device_warmup(device)
        tune_llava_patch_embedding(model.get_vision_tower(), device=device)
        if not args.chunk_prefilling:
            # Treat about three quarters of each context as image tokens; the
            # input layout below assumes each image expands to 196 tokens.
            image_num = [int(n * 0.75 / 196) for n in args.context_length]
            if sum(image_num) > 0:
                image_tensor = 2 * torch.rand((max(image_num), 3, 384, 384)) - 1
                image_tensor = image_tensor.half().to(device)
    else:
        model = model_type_dict[model_type](config).half()
        real_quantize_model_weight(model, w_bit=4, q_config=q_config, init_only=True)
        model = model.to(device)
        # tune_all_wqlinears(model)
        make_quant_attn(model, device, args.flash_attn)
        make_quant_norm(model)
        make_fused_mlp(model)
        device_warmup(device)
    # NOTE(review): no checkpoint is actually loaded (see docstring); the
    # message is kept unchanged for compatibility with existing log parsing.
    print("huggingface ckpt loaded")

    # Warm up once; never exceed the configured KV-cache length.
    warmup_len = min(2048, args.max_seq_len)
    with torch.inference_mode():
        model(
            torch.as_tensor([[1] * warmup_len], device=device),
            start_pos=0,
            chunk_prefilling=args.chunk_prefilling,
        )

    if not args.chunk_prefilling:
        for idx, context_length in enumerate(args.context_length):
            input_ids = [1] * context_length
            model_kwargs = {}
            n_images = 0
            if is_vila:
                n_images = image_num[idx]
                images = None
                if n_images:
                    images = image_tensor[:n_images]
                    # -200 marks an image slot; the remaining positions are text.
                    input_ids = [-200] * n_images + [1] * (context_length - 196 * n_images)
                model_kwargs["images"] = images
            print("-" * 80)
            if is_vila:
                print("Context length: {} with {} pictures".format(context_length, n_images))
            else:
                print("Context length: {}".format(context_length))
            run_once = functools.partial(
                _timed_forward,
                model,
                input_ids,
                device,
                0,
                args.chunk_prefilling,
                **model_kwargs,
            )
            mean_latency = _mean_latency(run_once, args.num_runs, args.verbose)
            print(f"Time To First Token: {mean_latency:.5f} s.")
            print("-" * 80)
    else:
        # VILA's history call explicitly passes images=None; other models take no images.
        history_kwargs = {"images": None} if is_vila else {}
        for context_length, question_length in zip(args.context_length, args.question_length):
            print("-" * 80)
            print(
                "History length: {} ; Question length: {}".format(
                    context_length, question_length
                )
            )
            run_once = functools.partial(
                _chunk_round,
                model,
                [1] * context_length,
                [1] * question_length,
                device,
                args.chunk_prefilling,
                history_kwargs,
            )
            mean_latency = _mean_latency(run_once, args.num_runs, args.verbose)
            print(f"Time To First Token of this round: {mean_latency:.5f} s.")
            print("-" * 80)
if "__main__" == __name__:
    # Run the benchmark only when executed as a script, not when imported.
    main()