Fix latency benchmark script (vllm-project#118)
WoosukKwon authored May 23, 2023
1 parent 19d2899 commit 3f942ac
Showing 2 changed files with 43 additions and 31 deletions.
62 changes: 33 additions & 29 deletions benchmark/benchmark_latency.py
@@ -1,71 +1,75 @@
 import argparse
 import time
-from typing import List
 
-from tqdm import tqdm
 import numpy as np
 import torch
+from tqdm import tqdm
 
-from cacheflow.core.server import (
-    add_server_arguments, process_server_arguments,
-    init_local_server_and_frontend_with_arguments)
-from cacheflow.sampling_params import SamplingParams
+from cacheflow import LLM, SamplingParams
 
 
 def main(args: argparse.Namespace):
-    server, frontend = init_local_server_and_frontend_with_arguments(args)
+    print(args)
+
+    # Process all the requests in a single batch if possible.
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the server will automatically process the request in multiple batches.
+    llm = LLM(
+        model=args.model,
+        tensor_parallel_size=args.tensor_parallel_size,
+        max_num_seqs=args.batch_size,
+        max_num_batched_tokens=args.batch_size * args.input_len,
+    )
 
     sampling_params = SamplingParams(
         n=args.n,
         temperature=0.0 if args.use_beam_search else 1.0,
         top_p=1.0,
         use_beam_search=args.use_beam_search,
-        stop_token_ids=set(),
+        ignore_eos=True,
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    input_token_ids = [0] * args.input_len
+    dummy_prompts = [""] * args.batch_size
+    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
 
-    def profile_step(profile=False):
+    def run_to_completion(profile: bool = False):
         if profile:
             torch.cuda.cudart().cudaProfilerStart()
-        for _ in range(args.batch_size):
-            dummy_prompt = ""
-            frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
-        server.add_sequence_groups(frontend.get_inputs())
         start_time = time.time()
-        while True:
-            server.step()
-            if not server.has_unfinished_requests():
-                break
+
+        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+                     use_tqdm=False)
+
         end_time = time.time()
         latency = end_time - start_time
         if profile:
             torch.cuda.cudart().cudaProfilerStop()
         return latency
 
-    print("Warm up step")
-    profile_step()
+    print("Warming up...")
+    run_to_completion(profile=False)
 
     # Benchmark.
     latencies = []
-    for _ in tqdm(range(3), desc="Profile step"):
-        latencies.append(profile_step())
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile=False))
     print(f'Avg latency: {np.mean(latencies)} seconds')
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(
-        description='Benchmark the latency of decoding a single sentence.')
-    parser = add_server_arguments(parser)
+        description='Benchmark the latency of processing a single batch of '
+                    'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n', type=int, default=1)
+    parser.add_argument('--n', type=int, default=1,
+                        help='Number of generated sequences per prompt.')
     parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters', type=int, default=3,
+                        help='Number of iterations to run.')
     args = parser.parse_args()
-    args = process_server_arguments(args)
-    args.max_num_batched_tokens = max(
-        args.max_num_batched_tokens, args.batch_size * args.input_len)
-    print(args)
     main(args)
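
Note (not part of the commit): the rewritten script drives the high-level LLM entrypoint instead of stepping the server and frontend by hand. A rough interactive equivalent of one timed iteration is sketched below; the values simply mirror the script's argparse defaults (facebook/opt-125m, batch size 8, input length 32, output length 128), and a working cacheflow GPU setup is assumed.

import time

from cacheflow import LLM, SamplingParams

# Mirror the script's defaults: --batch-size 8, --input-len 32, --output-len 128.
batch_size, input_len, output_len = 8, 32, 128

llm = LLM(
    model='facebook/opt-125m',
    tensor_parallel_size=1,
    max_num_seqs=batch_size,
    # Large enough for the whole batch to be processed in a single step.
    max_num_batched_tokens=batch_size * input_len,
)
sampling_params = SamplingParams(n=1, temperature=1.0, top_p=1.0,
                                 use_beam_search=False, ignore_eos=True,
                                 max_tokens=output_len)

dummy_prompts = [""] * batch_size
dummy_prompt_token_ids = [[0] * input_len] * batch_size

start = time.time()
llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
             use_tqdm=False)
print(f"Latency: {time.time() - start:.3f} seconds")
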
12 changes: 10 additions & 2 deletions cacheflow/entrypoints/llm.py
@@ -35,18 +35,26 @@ def generate(
         self,
         prompts: List[str],
         sampling_params: Optional[SamplingParams] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
         # Initialize tqdm.
         if use_tqdm:
             pbar = tqdm(total=len(prompts), desc="Processed prompts")
 
         # Add requests to the server.
-        for prompt in prompts:
+        for i in range(len(prompts)):
+            prompt = prompts[i]
+            if prompt_token_ids is None:
+                token_ids = None
+            else:
+                token_ids = prompt_token_ids[i]
             request_id = str(next(self.request_counter))
-            self.llm_server.add_request(request_id, prompt, sampling_params)
+            self.llm_server.add_request(request_id, prompt, sampling_params,
+                                        token_ids)
 
         # Run the server.
         outputs: List[RequestOutput] = []
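The entrypoint change is what the benchmark relies on: generate() now accepts an optional prompt_token_ids list that is paired with prompts by index, presumably so the benchmark can fix the prompt length to --input-len rather than depend on how the empty dummy strings tokenize. Continuing the sketch above (an illustration, not code from the commit), the two call forms look like this:

# Explicit token IDs: dummy_prompt_token_ids[i] is used for dummy_prompts[i].
outputs_with_ids = llm.generate(dummy_prompts, sampling_params,
                                dummy_prompt_token_ids, use_tqdm=False)

# Omitting prompt_token_ids (it defaults to None) passes token_ids=None to
# add_request(), effectively matching the pre-commit call pattern.
outputs_default = llm.generate(dummy_prompts, sampling_params)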
