Add script for benchmarking serving throughput (vllm-project#145)
WoosukKwon authored Jun 15, 2023
1 parent da5ddcd commit 311490a
Showing 10 changed files with 421 additions and 415 deletions.
8 changes: 5 additions & 3 deletions benchmarks/benchmark_async_llm_server.py
@@ -10,6 +10,7 @@ def main(args: argparse.Namespace):
     prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
                for i in range(args.n_threads)]
 
+    api_url = f"http://{args.host}:{args.port}/generate"
     headers = {"User-Agent": "CacheFlow Benchmark Client"}
     ploads = [{
         "prompt": p,
@@ -19,8 +20,8 @@ def main(args: argparse.Namespace):
     } for p in prompts]
 
     def send_request(results, i):
-        response = requests.post(args.api_url, headers=headers,
-                                 json=ploads[i], stream=True)
+        response = requests.post(api_url, headers=headers, json=ploads[i],
+                                 stream=True)
         results[i] = response
 
     # use args.n_threads to prompt the backend
@@ -50,7 +51,8 @@ def send_request(results, i):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8001)
     parser.add_argument("--max-tokens", type=int, default=128)
     parser.add_argument("--n-threads", type=int, default=128)
     args = parser.parse_args()
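
Note (not part of the commit): the net effect of this hunk is that the client now derives the /generate endpoint from --host and --port instead of taking a full --api-url. A minimal sketch of the equivalence, using the defaults shown above:

# Old invocation (flag removed by this commit):
#   python benchmarks/benchmark_async_llm_server.py --api-url http://localhost:8001/generate
# New invocation (flags added by this commit):
#   python benchmarks/benchmark_async_llm_server.py --host localhost --port 8001
host, port = "localhost", 8001
api_url = f"http://{host}:{port}/generate"  # same URL the old flag supplied directly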
1 change: 1 addition & 0 deletions benchmarks/benchmark_latency.py
@@ -1,3 +1,4 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time

237 changes: 237 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -0,0 +1,237 @@
"""Benchmark online serving throughput.
On the server side, run one of the following commands:
(CacheFlow backend)
python -m cacheflow.entrypoints.simple_fastapi_frontend \
--disable-log-requests --model <your_model>
(TGI backend)
./launch_hf_server.sh <your_model>
On the client side, run:
python benchmarks/benchmark_serving.py \
--backend <backend> \
--tokenizer <your_model> --dataset <target_dataset> \
--request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
config = AutoConfig.from_pretrained(model_name)
if config.model_type == "llama":
# A workaround for potential protobuf errors.
model_name = "hf-internal-testing/llama-tokenizer"
return AutoTokenizer.from_pretrained(model_name)


def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]

# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))

# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests


async def get_request(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
input_requests = iter(input_requests)
for request in input_requests:
yield request

if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue
# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
# The next request will be sent after the interval.
await asyncio.sleep(interval)


async def send_request(
backend: str,
api_url: str,
prompt: str,
prompt_len: int,
output_len: int,
best_of: int,
use_beam_search: bool,
) -> None:
request_start_time = time.time()

headers = {"User-Agent": "Benchmark Client"}
if backend == "cacheflow":
pload = {
"prompt": prompt,
"n": 1,
"best_of": best_of,
"use_beam_search": use_beam_search,
"temperature": 0.0 if use_beam_search else 1.0,
"top_p": 1.0,
"max_tokens": output_len,
"ignore_eos": True,
"stream": False,
}
elif backend == "tgi":
assert not use_beam_search
params = {
"best_of": best_of,
"max_new_tokens": output_len,
"do_sample": True,
}
pload = {
"inputs": prompt,
"parameters": params,
}
else:
raise ValueError(f"Unknown backend: {backend}")

timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
chunks = []
async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk)
output = b"".join(chunks).decode("utf-8")
output = json.loads(output)

# Re-send the request if it failed.
if "error" not in output:
break

request_end_time = time.time()
request_latency = request_end_time - request_start_time
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))


async def benchmark(
backend: str,
api_url: str,
input_requests: List[Tuple[str, int, int]],
best_of: int,
use_beam_search: bool,
request_rate: float,
) -> None:
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt,
prompt_len, output_len,
best_of, use_beam_search))
tasks.append(task)
await asyncio.gather(*tasks)


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)

api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = get_tokenizer(args.tokenizer)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

benchmark_start_time = time.time()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
args.use_beam_search, args.request_rate))
benchmark_end_time = time.time()
benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s")
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

# Compute the latency statistics.
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
print(f"Average latency: {avg_latency:.2f} s")
avg_per_token_latency = np.mean([
latency / (prompt_len + output_len)
for prompt_len, output_len, latency in REQUEST_LATENCY
])
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
avg_per_output_token_latency = np.mean([
latency / output_len
for _, output_len, latency in REQUEST_LATENCY
])
print("Average latency per output token: "
f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument("--backend", type=str, default="cacheflow",
choices=["cacheflow", "tgi"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--dataset", type=str, required=True,
help="Path to the dataset.")
parser.add_argument("--tokenizer", type=str, required=True,
help="Name or path of the tokenizer.")
parser.add_argument("--best-of", type=int, default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
help="Number of prompts to process.")
parser.add_argument("--request-rate", type=float, default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()
main(args)
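
Note (not part of the commit): sample_requests() above expects the --dataset file to be a JSON list of records whose "conversations" field holds at least two turns; the first turn's "value" becomes the prompt and the second turn's "value" supplies the target output length (ShareGPT-style exports follow this layout). Below is a minimal sketch of a synthetic dataset for smoke-testing the script; the file name, tokenizer, and sizes are illustrative, not from the commit:

import json

# Each record mimics the structure sample_requests() reads:
# conversations[0]["value"] -> prompt, conversations[1]["value"] -> completion.
records = [
    {
        "conversations": [
            {"from": "human",
             "value": "Explain how a hash table handles collisions, step by step. " * 3},
            {"from": "gpt",
             "value": "A hash table resolves collisions with chaining or open addressing. " * 10},
        ]
    }
    for _ in range(200)
]

with open("tiny_dataset.json", "w") as f:
    json.dump(records, f)

# Example client run against a local CacheFlow server (values illustrative):
#   python benchmarks/benchmark_serving.py --backend cacheflow \
#       --tokenizer facebook/opt-125m --dataset tiny_dataset.json \
#       --num-prompts 100 --request-rate 4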