forked from vllm-project/vllm
Commit
Fix latency benchmark script (vllm-project#118)
1 parent 19d2899 · commit 3f942ac
Showing 2 changed files with 43 additions and 31 deletions.
@@ -1,71 +1,75 @@
 import argparse
 import time
-from typing import List
 
-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm
 
-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams
 
 
 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
 
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )
+
     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency
 
-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)
 
     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)
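
For reference, the measurement pattern introduced by this fix can be condensed into a short standalone snippet. This is only a sketch, assuming the cacheflow package from this commit (the predecessor of vllm) is installed and a CUDA GPU is available; the model name and the batch/length values are just the script's argparse defaults, and the generate() call mirrors the signature used in the diff above.

import time

import numpy as np

from cacheflow import LLM, SamplingParams

batch_size, input_len, output_len = 8, 32, 128  # the script's default settings

# Build the engine once so model loading is not counted in the latency.
llm = LLM(
    model="facebook/opt-125m",
    max_num_seqs=batch_size,
    max_num_batched_tokens=batch_size * input_len,
)
sampling_params = SamplingParams(
    n=1, temperature=1.0, top_p=1.0, ignore_eos=True, max_tokens=output_len)

dummy_prompts = [""] * batch_size
dummy_prompt_token_ids = [[0] * input_len] * batch_size

def run_to_completion() -> float:
    # Time one full batch from submission until the last token is generated.
    start = time.time()
    llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
                 use_tqdm=False)
    return time.time() - start

run_to_completion()  # warm-up iteration, excluded from the measurement
latencies = [run_to_completion() for _ in range(3)]
print(f"Avg latency: {np.mean(latencies)} seconds")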