
Add script for benchmarking serving throughput #145


Merged: 43 commits, merged on Jun 15, 2023
Commits (43, all by WoosukKwon)
473c5b8  Minor fix (Jun 10, 2023)
a644a9b  Minor (Jun 10, 2023)
67ed51c  Minor (Jun 10, 2023)
83acd5e  Minor (Jun 10, 2023)
4957281  Add log-requests option to AsyncLLMServer (Jun 10, 2023)
c6b38d2  [WIP] Add benchmark_serving.py (Jun 10, 2023)
5210de0  Minor (Jun 10, 2023)
d4df348  Delete unused files (Jun 10, 2023)
fab12d6  Minor (Jun 10, 2023)
3ddadf4  Add docstring (Jun 10, 2023)
4269b11  Bugfix (Jun 10, 2023)
af8974d  Minor (Jun 10, 2023)
f8dee6e  Minor (Jun 10, 2023)
d181f10  Add script to launch HF server (Jun 10, 2023)
fc02a02  Add HF backend (Jun 10, 2023)
99d9ce3  Minor (Jun 10, 2023)
bc9ec63  Bugfix (Jun 10, 2023)
9477f2f  Filter out long prompts (Jun 10, 2023)
51a5332  Minor fix (Jun 10, 2023)
6b0d77b  Merge branch 'main' into benchmark-llama (Jun 10, 2023)
00d158d  Repeat failed requests (Jun 10, 2023)
0c55c40  Stream=False (Jun 10, 2023)
bcb8e16  Minor (Jun 10, 2023)
6a7baaa  Prune short sequences (Jun 10, 2023)
071b4aa  Add 1 hour timeout (Jun 10, 2023)
983cf97  Increase timeout (Jun 10, 2023)
b55b1ee  Add shortcut (Jun 11, 2023)
c45a2dd  Simplify (Jun 11, 2023)
66f8c60  Merge branch 'opt' into benchmark-llama (Jun 11, 2023)
a1b513e  n -> best_of (Jun 11, 2023)
72d6a63  Minor (Jun 11, 2023)
44bc461  Add latency stats (Jun 11, 2023)
6990fc5  Increase max_best_of in HF server (Jun 11, 2023)
2c610bd  Merge branch 'main' into benchmark-llama (Jun 11, 2023)
5687f10  hf -> tgi (Jun 13, 2023)
672fbbd  Add HF backend (Jun 13, 2023)
60bccc4  Fix batching (Jun 13, 2023)
b7fcade  Fix a bug & Add tqdm (Jun 13, 2023)
6accbfd  Minor (Jun 14, 2023)
c7360d1  Fix (Jun 15, 2023)
bf1bae6  Comment (Jun 15, 2023)
7bebe29  Add docstring (Jun 15, 2023)
5c1b852  Comment (Jun 15, 2023)
8 changes: 5 additions & 3 deletions benchmarks/benchmark_async_llm_server.py
@@ -10,6 +10,7 @@ def main(args: argparse.Namespace):
prompts = [f"Tell me a story with more than {''.join([str(i+1)] * 5)} words"
for i in range(args.n_threads)]

api_url = f"http://{args.host}:{args.port}/generate"
headers = {"User-Agent": "CacheFlow Benchmark Client"}
ploads = [{
"prompt": p,
@@ -19,8 +20,8 @@
     } for p in prompts]
 
     def send_request(results, i):
-        response = requests.post(args.api_url, headers=headers,
-                                 json=ploads[i], stream=True)
+        response = requests.post(api_url, headers=headers, json=ploads[i],
+                                 stream=True)
         results[i] = response
 
     # use args.n_threads to prompt the backend
@@ -50,7 +51,8 @@ def send_request(results, i):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--api-url", type=str, default="http://localhost:8001/generate")
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--max-tokens", type=int, default=128)
parser.add_argument("--n-threads", type=int, default=128)
args = parser.parse_args()
1 change: 1 addition & 0 deletions benchmarks/benchmark_latency.py
@@ -1,3 +1,4 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time

237 changes: 237 additions & 0 deletions benchmarks/benchmark_serving.py
@@ -0,0 +1,237 @@
"""Benchmark online serving throughput.

On the server side, run one of the following commands:
    (CacheFlow backend)
    python -m cacheflow.entrypoints.simple_fastapi_frontend \
        --disable-log-requests --model <your_model>

    (TGI backend)
    ./launch_hf_server.sh <your_model>

On the client side, run:
    python benchmarks/benchmark_serving.py \
        --backend <backend> \
        --tokenizer <your_model> --dataset <target_dataset> \
        --request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []


def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
Review comment (Member): Import from cacheflow.server.tokenizer_utils?

Reply (WoosukKwon, Collaborator/Author): I didn't use it because I thought that's an internal function that cacheflow does not expose to the users.

    config = AutoConfig.from_pretrained(model_name)
    if config.model_type == "llama":
        # A workaround for potential protobuf errors.
        model_name = "hf-internal-testing/llama-tokenizer"
    return AutoTokenizer.from_pretrained(model_name)


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))

    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            # This is because TGI causes errors when the input or output length
            # is too short.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))

    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests


async def get_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
    input_requests = iter(input_requests)
    for request in input_requests:
        yield request

        if request_rate == float("inf"):
            # If the request rate is infinity, then we don't need to wait.
            continue
        # Sample the request interval from the exponential distribution.
        interval = np.random.exponential(1.0 / request_rate)
        # The next request will be sent after the interval.
        await asyncio.sleep(interval)


async def send_request(
    backend: str,
    api_url: str,
    prompt: str,
    prompt_len: int,
    output_len: int,
    best_of: int,
    use_beam_search: bool,
) -> None:
    request_start_time = time.time()

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "cacheflow":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(api_url, headers=headers, json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.time()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))


async def benchmark(
    backend: str,
    api_url: str,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(backend, api_url, prompt,
                                                prompt_len, output_len,
                                                best_of, use_beam_search))
        tasks.append(task)
    await asyncio.gather(*tasks)


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = get_tokenizer(args.tokenizer)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.time()
    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                          args.use_beam_search, args.request_rate))
    benchmark_end_time = time.time()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean([
        latency / output_len
        for _, output_len, latency in REQUEST_LATENCY
    ])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument("--backend", type=str, default="cacheflow",
choices=["cacheflow", "tgi"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--dataset", type=str, required=True,
help="Path to the dataset.")
parser.add_argument("--tokenizer", type=str, required=True,
Review comment (Member): Does this argument actually ask for model_name?

Reply (WoosukKwon, Collaborator/Author): Yes, is it confusing?

help="Name or path of the tokenizer.")
parser.add_argument("--best-of", type=int, default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
help="Number of prompts to process.")
parser.add_argument("--request-rate", type=float, default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()
main(args)
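
The --request-rate flag above synthesizes request arrival times from a Poisson process by drawing exponentially distributed inter-request gaps, which is what get_request() samples with np.random.exponential. The following is a minimal standalone sketch (not part of the PR; the rate and sample count are illustrative assumptions) showing that gaps drawn this way average out to the requested rate:

# Standalone sketch of the Poisson arrival process behind --request-rate.
# The request_rate and num_requests values are arbitrary illustrative choices.
import numpy as np

request_rate = 4.0      # target requests per second (hypothetical value)
num_requests = 10_000

# A Poisson process has exponentially distributed inter-arrival gaps with
# mean 1 / rate; get_request() draws each gap the same way before sleeping.
gaps = np.random.exponential(1.0 / request_rate, size=num_requests)

total_time = gaps.sum()
print(f"Empirical rate: {num_requests / total_time:.2f} requests/s")  # close to 4.00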