 import argparse
 import time
-from typing import List

-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm

-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams


 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )

     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency

-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)

     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')


 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)