forked from vllm-project/vllm
Add script for benchmarking serving throughput (vllm-project#145)
1 parent da5ddcd · commit 311490a
Showing 10 changed files with 421 additions and 415 deletions.
@@ -1,3 +1,4 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
@@ -0,0 +1,237 @@
"""Benchmark online serving throughput. | ||
On the server side, run one of the following commands: | ||
(CacheFlow backend) | ||
python -m cacheflow.entrypoints.simple_fastapi_frontend \ | ||
--disable-log-requests --model <your_model> | ||
(TGI backend) | ||
./launch_hf_server.sh <your_model> | ||
On the client side, run: | ||
python benchmarks/benchmark_serving.py \ | ||
--backend <backend> \ | ||
--tokenizer <your_model> --dataset <target_dataset> \ | ||
--request-rate <request_rate> | ||
""" | ||
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
    config = AutoConfig.from_pretrained(model_name)
    if config.model_type == "llama":
        # A workaround for potential protobuf errors.
        model_name = "hf-internal-testing/llama-tokenizer"
    return AutoTokenizer.from_pretrained(model_name)


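# sample_requests below assumes a ShareGPT-style dataset: a JSON list of
# records, each with a "conversations" list. The first turn of a conversation
# is used as the prompt; the token length of the second turn only determines
# how many output tokens to request.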
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with fewer than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


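# Note: drawing i.i.d. inter-arrival gaps from an exponential distribution with
# mean 1 / request_rate (as get_request does above) makes request arrivals a
# Poisson process with rate request_rate, which is the behavior described by
# the --request-rate help text below.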
async def send_request(
    backend: str,
    api_url: str,
    prompt: str,
    prompt_len: int,
    output_len: int,
    best_of: int,
    use_beam_search: bool,
) -> None:
    request_start_time = time.time()

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "cacheflow":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(api_url, headers=headers, json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.time()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))


async def benchmark(
    backend: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(backend, api_url, prompt,
                                                prompt_len, output_len,
                                                best_of, use_beam_search))
        tasks.append(task)
    await asyncio.gather(*tasks)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = get_tokenizer(args.tokenizer)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.time()
    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                          args.use_beam_search, args.request_rate))
    benchmark_end_time = time.time()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean([
        latency / output_len
        for _, output_len, latency in REQUEST_LATENCY
    ])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument("--backend", type=str, default="cacheflow",
                        choices=["cacheflow", "tgi"])
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Name or path of the tokenizer.")
    parser.add_argument("--best-of", type=int, default=1,
                        help="Generates `best_of` sequences per prompt and "
                             "returns the best one.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--request-rate", type=float, default=float("inf"),
                        help="Number of requests per second. If this is inf, "
                             "then all the requests are sent at time 0. "
                             "Otherwise, we use a Poisson process to synthesize "
                             "the request arrival times.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
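For intuition about the three summary numbers printed at the end of main(), here is a minimal, self-contained sketch that applies the same formulas to two hypothetical (prompt_len, output_len, latency) samples. The values are illustrative only, not measured output of this script.

import numpy as np

# Two made-up samples in the same format as REQUEST_LATENCY.
request_latency = [(128, 256, 3.2), (512, 128, 2.4)]

# Average end-to-end request latency: (3.2 + 2.4) / 2 = 2.80 s.
avg_latency = np.mean([latency for _, _, latency in request_latency])

# Average latency per (prompt + output) token: mean(3.2 / 384, 2.4 / 640) ≈ 0.006 s.
avg_per_token_latency = np.mean([
    latency / (prompt_len + output_len)
    for prompt_len, output_len, latency in request_latency
])

# Average latency per generated token: mean(3.2 / 256, 2.4 / 128) ≈ 0.016 s.
avg_per_output_token_latency = np.mean([
    latency / output_len
    for _, output_len, latency in request_latency
])

print(avg_latency, avg_per_token_latency, avg_per_output_token_latency)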