
Commit 04863a9

[example] Update Llama Inference example (#5629)
* [example] add inference benchmark llama3
* revise inference config - arg
* remove unused args
* add llama generation demo script
* fix init rope in llama policy
* add benchmark-llama3 - cleanup
1 parent 12f10d5 commit 04863a9

File tree

4 files changed: +323 additions, -12 deletions


colossalai/inference/modeling/policy/nopadding_llama.py

Lines changed: 1 addition & 1 deletion
@@ -100,5 +100,5 @@ def module_policy(self):
         return policy
 
     def postprocess(self):
-        init_to_get_rotary(self.model.model)
+        init_to_get_rotary(self.model.model, self.model.config.rope_theta)
         return self.model
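
The added rope_theta argument matters because it is the base used when precomputing the rotary-embedding frequencies, and Llama 3 ships with rope_theta=500000 rather than the 10000 default of earlier Llama models. A minimal sketch of the idea (not the actual init_to_get_rotary implementation, whose signature is only known from the diff above):

import torch

def rotary_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # Inverse frequencies for rotary position embeddings; a larger base (e.g. 500000
    # for Llama 3) stretches the wavelengths, which helps at longer context lengths.
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# e.g. head_dim = hidden_size // num_attention_heads = 4096 // 32 = 128 for llama3-8b
inv_freq = rotary_inv_freq(128, rope_theta=500000.0)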

examples/inference/benchmark_llama.py

Lines changed: 25 additions & 11 deletions
@@ -51,6 +51,22 @@
         num_key_value_heads=40,
         max_position_embeddings=4096,
     ),
+    "llama3-8b": transformers.LlamaConfig(
+        hidden_size=4096,
+        intermediate_size=14336,
+        num_attention_heads=32,
+        num_hidden_layers=32,
+        num_key_value_heads=8,
+        max_position_embeddings=8192,
+    ),
+    "llama3-70b": transformers.LlamaConfig(
+        hidden_size=8192,
+        intermediate_size=28672,
+        num_attention_heads=64,
+        num_hidden_layers=80,
+        num_key_value_heads=8,
+        max_position_embeddings=8192,
+    ),
 }
 
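Note that the two Llama 3 entries set num_key_value_heads=8 while keeping 32 or 64 attention heads, i.e. grouped-query attention. A quick sanity check of the derived shapes, using only the values from the config entries above:

import transformers

cfg = transformers.LlamaConfig(
    hidden_size=4096, num_attention_heads=32, num_key_value_heads=8  # llama3-8b entry
)
head_dim = cfg.hidden_size // cfg.num_attention_heads                 # 128
gqa_group_size = cfg.num_attention_heads // cfg.num_key_value_heads   # 4 query heads share each KV head
kv_entries_per_token_per_layer = 2 * cfg.num_key_value_heads * head_dim  # 2048 values (K and V)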

@@ -66,7 +82,7 @@ def print_details_info(model_config, args, whole_end2end, total_token_num):
     msg += "-------Perf Summary-------\n"
     whole_avg_latency = whole_end2end / (total_token_num)
     num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
-    num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
+    num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12
     if args.dtype in ["fp16", "bf16"]:
         num_bytes = 2
     else:
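
The 12 * hidden_size**2 per-layer figure used here is a rough dense-transformer parameter estimate (4*h^2 for the attention projections plus 8*h^2 for a classic 4*h feed-forward); Llama's SwiGLU MLP and grouped-query attention make it only approximate, but it is adequate for a perf summary. For the llama3-8b config above, roughly:

num_layers, hidden_size = 32, 4096                              # llama3-8b entry
num_parameters = num_layers * hidden_size * hidden_size * 12    # ~6.4e9 (actual model is ~8e9 incl. embeddings)
weight_bytes = num_parameters * 2                               # fp16/bf16: ~12.9e9 bytes of weights
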
@@ -90,11 +106,11 @@ def benchmark_inference(args):
     config = CONFIG_MAP[args.model]
     config.pad_token_id = config.eos_token_id
     if args.test_random_weight:
-        model = transformers.LlamaForCausalLM(config).cuda()
+        model = transformers.LlamaForCausalLM(config)
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     else:
         assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
-        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
+        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
 
     model = model.eval()
@@ -111,12 +127,12 @@ def benchmark_inference(args):
     if args.mode == "colossalai":
         inference_config = InferenceConfig(
             dtype=args.dtype,
-            micro_batch_size=args.mb_size,
             max_batch_size=mbsz,
             max_input_len=args.seq_len,
             max_output_len=args.output_len,
             prefill_ratio=1.2,
             block_size=32,
+            tp_size=args.tp_size,
             use_cuda_kernel=True,
         )
         engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
@@ -142,7 +158,8 @@ def benchmark_inference(args):
 
     generation_config = GenerationConfig(
         pad_token_id=tokenizer.pad_token_id,
-        max_new_tokens=args.output_len,
+        max_length=args.seq_len + args.output_len,
+        # max_new_tokens=args.output_len,
     )
 
     N_WARMUP_STEPS = 2
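
The switch from max_new_tokens to max_length changes what the limit counts: in Hugging Face's GenerationConfig, max_length caps prompt plus generated tokens, whereas max_new_tokens caps only the newly generated ones, so max_length=args.seq_len + args.output_len keeps the same output budget for a full-length prompt. A small illustration with hypothetical values (not taken from the script):

from transformers import GenerationConfig

seq_len, output_len = 256, 128
cfg_total = GenerationConfig(max_length=seq_len + output_len)  # at most 384 tokens including the prompt
cfg_new = GenerationConfig(max_new_tokens=output_len)          # at most 128 generated tokens, any prompt length
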
@@ -219,7 +236,7 @@ def hybrid_inference(rank, world_size, port, args):
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def benchmark(args):
-    spawn(hybrid_inference, nprocs=args.tp_size * args.pp_size, args=args)
+    spawn(hybrid_inference, nprocs=args.tp_size, args=args)
 
 
 if __name__ == "__main__":
@@ -229,18 +246,15 @@ def benchmark(args):
         "--model",
         default="toy",
         help="the size of model",
-        choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
+        choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b", "llama3-8b", "llama3-70b"],
     )
     parser.add_argument("--model_path", type=str, default=None, help="The pretrained weights path")
     parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
     parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step")
     parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
-    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
-    parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
-    parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
+    parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallelism size")
     parser.add_argument("--output_len", type=int, default=128, help="Output length")
     parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
-    parser.add_argument("-v", "--verbose", default=False, action="store_true")
     parser.add_argument(
         "--test_random_weight", default=False, action="store_true", help="whether to test random weight"
     )
Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
import argparse
import time
from contextlib import nullcontext

import torch
import transformers
from transformers import AutoTokenizer, GenerationConfig

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.cluster import DistCoordinator
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

GIGABYTE = 1024**3
MEGABYTE = 1024**2
N_WARMUP_STEPS = 2

CONFIG_MAP = {
    "toy": transformers.LlamaConfig(num_hidden_layers=4),
    "llama-7b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=32,
        max_position_embeddings=2048,
    ),
    "llama-13b": transformers.LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_attention_heads=40,
        num_hidden_layers=40,
        num_key_value_heads=40,
        max_position_embeddings=2048,
    ),
    "llama2-7b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=32,
        max_position_embeddings=4096,
    ),
    "llama2-13b": transformers.LlamaConfig(
        hidden_size=5120,
        intermediate_size=13824,
        num_attention_heads=40,
        num_hidden_layers=40,
        num_key_value_heads=40,
        max_position_embeddings=4096,
    ),
    "llama3-8b": transformers.LlamaConfig(
        hidden_size=4096,
        intermediate_size=14336,
        num_attention_heads=32,
        num_hidden_layers=32,
        num_key_value_heads=8,
        max_position_embeddings=8192,
    ),
    "llama3-70b": transformers.LlamaConfig(
        hidden_size=8192,
        intermediate_size=28672,
        num_attention_heads=64,
        num_hidden_layers=80,
        num_key_value_heads=8,
        max_position_embeddings=8192,
    ),
}


def data_gen(batch_size: int = 4, seq_len: int = 512):
    input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device())
    return input_ids.tolist()


def print_details_info(model_config, whole_end2end, total_token_num, dtype, coordinator=None):
    if coordinator is None:
        coordinator = DistCoordinator()
    msg = "-------Perf Summary-------\n"
    whole_avg_latency = whole_end2end / (total_token_num)
    num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
    num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12
    if dtype in ["fp16", "bf16"]:
        num_bytes = 2
    elif dtype == "fp32":
        num_bytes = 4
    else:
        raise ValueError(f"Unsupported dtype {dtype}")

    msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
    msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
    msg += f"Throughput: {total_token_num / whole_end2end:.2f} tokens/s\n"
    msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
    if torch.cuda.is_available():
        msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n"
        msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n"
        msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n"

    coordinator.print_on_master(msg)


def benchmark_inference(args):
    coordinator = DistCoordinator()

    config = CONFIG_MAP[args.model]
    config.pad_token_id = config.eos_token_id
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    if args.model_path is not None:
        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
    else:
        # Random weights
        model = transformers.LlamaForCausalLM(config)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    if args.dtype == "fp16":
        model = model.half()
    elif args.dtype == "bf16":
        model = model.to(torch.bfloat16)

    inference_config = InferenceConfig(
        dtype=args.dtype,
        max_batch_size=args.batch_size,
        max_input_len=args.max_seq_len,
        max_output_len=args.max_output_len,
        prefill_ratio=1.2,
        block_size=32,
        tp_size=args.tp_size,
        use_cuda_kernel=True,
    )
    engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

    data = data_gen(args.batch_size, args.max_seq_len)
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        max_length=args.max_seq_len + args.max_output_len,
        # max_new_tokens=args.max_output_len,
    )
    coordinator.print_on_master(f"Generation Config: \n{generation_config.to_dict()}")

    ctx = (
        torch.profiler.profile(
            record_shapes=True,
            with_stack=True,
            with_modules=True,
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f"./tb_log_{args.batch_size}_{args.max_seq_len}_{args.max_output_len}"
            ),
        )
        if args.profile
        else nullcontext()
    )
    with ctx:
        for _ in range(N_WARMUP_STEPS):
            engine.generate(prompts_token_ids=data, generation_config=generation_config)
            if args.profile:
                ctx.step()
        if args.nsys:
            torch.cuda.cudart().cudaProfilerStart()

        torch.cuda.synchronize()
        whole_end2end = time.perf_counter()
        output, output_tokens_list = engine.generate(
            prompts_token_ids=data, generation_config=generation_config, return_token_ids=True
        )
        torch.cuda.synchronize()
        whole_end2end = time.perf_counter() - whole_end2end

        total_token_num = sum([len(output_tokens) for output_tokens in output_tokens_list])
        coordinator.print_on_master(f"total_token_num: {total_token_num}")
        if args.nsys:
            torch.cuda.cudart().cudaProfilerStop()
        if args.profile:
            ctx.step()

    print_details_info(model.config, whole_end2end, total_token_num, args.dtype, coordinator=coordinator)


def inference(rank, world_size, port, args):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    benchmark_inference(args)


@rerun_if_address_is_in_use()
@clear_cache_before_run()
def benchmark(args):
    spawn(inference, nprocs=args.tp_size, args=args)


# python benchmark_llama3.py -m llama3-8b -b 16 -s 256 -o 256
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model",
        default="llama3-8b",
        help="The version of Llama model",
        choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b", "llama3-8b", "llama3-70b"],
    )
    parser.add_argument("-p", "--model_path", type=str, default=None, help="The pretrained weights path")
    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
    parser.add_argument("-s", "--max_seq_len", type=int, default=8, help="input sequence length")
    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Output length")
    parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
    parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
    parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")

    args = parser.parse_args()

    benchmark(args)
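
For reference, the script's own usage comment shows a single-process run; a hypothetical multi-GPU invocation only adds flags defined in the argument parser above (the file name benchmark_llama3.py is taken from that comment, not from this page's file list):

# single GPU, random weights (from the comment in the script)
python benchmark_llama3.py -m llama3-8b -b 16 -s 256 -o 256
# hypothetical: 2-way tensor parallelism with the torch profiler enabled
python benchmark_llama3.py -m llama3-8b -b 16 -s 256 -o 256 -t 2 --profile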
