Commit bacad9c

Support OpenAI API server in benchmark_serving.py (vllm-project#2172)
1 parent c3149dc commit bacad9c
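
This commit teaches benchmarks/benchmark_serving.py to talk to an OpenAI-compatible API server: a new --endpoint flag chooses the URL path (still /generate by default), a new --model flag injects a "model" field into the vLLM request payload, and tqdm.gather replaces asyncio.gather to show per-request progress.

A minimal sketch of the request shape this enables, assuming vLLM's OpenAI-compatible server is listening on the default port and exposes /v1/completions; the endpoint path, prompt, and sampling fields here are illustrative, not taken from the diff (only ignore_eos, stream, and the conditional model field appear in it):

    # sketch.py - hypothetical one-off request mirroring the benchmark's payload
    import requests  # illustrative synchronous client; the script itself uses aiohttp

    payload = {
        "prompt": "San Francisco is a",  # illustrative prompt
        "max_tokens": 128,               # illustrative sampling parameters
        "temperature": 0.0,
        "ignore_eos": True,  # kept from the existing vLLM payload (see diff)
        "stream": False,     # kept from the existing vLLM payload (see diff)
    }
    model = "facebook/opt-125m"  # hypothetical value of the new --model flag
    if model is not None:
        payload["model"] = model  # the diff adds this for OpenAI-style servers

    resp = requests.post("http://localhost:8000/v1/completions", json=payload)
    print(resp.json())

A benchmark run against such a server would then look like `python benchmarks/benchmark_serving.py --backend vllm --model facebook/opt-125m --endpoint /v1/completions --dataset <path> --tokenizer facebook/opt-125m` (flag values illustrative).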

File tree: 2 files changed (+51, −32 lines)

.gitignore
Lines changed: 3 additions & 0 deletions

@@ -181,3 +181,6 @@ _build/
 # hip files generated by PyTorch
 *.hip
 *_hip*
+
+# Benchmark dataset
+*.json

benchmarks/benchmark_serving.py
Lines changed: 48 additions & 32 deletions

@@ -24,6 +24,7 @@
 
 import aiohttp
 import numpy as np
+from tqdm.asyncio import tqdm
 from transformers import PreTrainedTokenizerBase
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -40,15 +41,10 @@ def sample_requests(
     with open(dataset_path) as f:
         dataset = json.load(f)
     # Filter out the conversations with less than 2 turns.
-    dataset = [
-        data for data in dataset
-        if len(data["conversations"]) >= 2
-    ]
+    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
+    dataset = [(data["conversations"][0]["value"],
+                data["conversations"][1]["value"]) for data in dataset]
 
     # Tokenize the prompts and completions.
     prompts = [prompt for prompt, _ in dataset]
@@ -98,6 +94,7 @@ async def get_request(
 
 async def send_request(
     backend: str,
+    model: str,
     api_url: str,
     prompt: str,
     prompt_len: int,
@@ -120,6 +117,8 @@ async def send_request(
             "ignore_eos": True,
             "stream": False,
         }
+        if model is not None:
+            pload["model"] = model
     elif backend == "tgi":
         assert not use_beam_search
         params = {
@@ -137,7 +136,8 @@ async def send_request(
     timeout = aiohttp.ClientTimeout(total=3 * 3600)
     async with aiohttp.ClientSession(timeout=timeout) as session:
         while True:
-            async with session.post(api_url, headers=headers, json=pload) as response:
+            async with session.post(api_url, headers=headers,
+                                    json=pload) as response:
                 chunks = []
                 async for chunk, _ in response.content.iter_chunks():
                     chunks.append(chunk)
@@ -155,6 +155,7 @@ async def send_request(
 
 async def benchmark(
     backend: str,
+    model: str,
     api_url: str,
     input_requests: List[Tuple[str, int, int]],
     best_of: int,
@@ -164,25 +165,27 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
-        task = asyncio.create_task(send_request(backend, api_url, prompt,
-                                                prompt_len, output_len,
-                                                best_of, use_beam_search))
+        task = asyncio.create_task(
+            send_request(backend, model, api_url, prompt, prompt_len,
+                         output_len, best_of, use_beam_search))
         tasks.append(task)
-    await asyncio.gather(*tasks)
+    await tqdm.gather(*tasks)
 
 
 def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
     np.random.seed(args.seed)
 
-    api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+    tokenizer = get_tokenizer(args.tokenizer,
+                              trust_remote_code=args.trust_remote_code)
     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
 
     benchmark_start_time = time.perf_counter()
-    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
-                          args.use_beam_search, args.request_rate))
+    asyncio.run(
+        benchmark(args.backend, args.model, api_url, input_requests,
+                  args.best_of, args.use_beam_search, args.request_rate))
     benchmark_end_time = time.perf_counter()
     benchmark_time = benchmark_end_time - benchmark_start_time
     print(f"Total time: {benchmark_time:.2f} s")
@@ -196,38 +199,51 @@ def main(args: argparse.Namespace):
         for prompt_len, output_len, latency in REQUEST_LATENCY
     ])
     print(f"Average latency per token: {avg_per_token_latency:.2f} s")
-    avg_per_output_token_latency = np.mean([
-        latency / output_len
-        for _, output_len, latency in REQUEST_LATENCY
-    ])
+    avg_per_output_token_latency = np.mean(
+        [latency / output_len for _, output_len, latency in REQUEST_LATENCY])
     print("Average latency per output token: "
           f"{avg_per_output_token_latency:.2f} s")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Benchmark the online serving throughput.")
-    parser.add_argument("--backend", type=str, default="vllm",
+    parser.add_argument("--backend",
+                        type=str,
+                        default="vllm",
                         choices=["vllm", "tgi"])
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=8000)
-    parser.add_argument("--dataset", type=str, required=True,
+    parser.add_argument("--endpoint", type=str, default="/generate")
+    parser.add_argument("--model", type=str, default=None)
+    parser.add_argument("--dataset",
+                        type=str,
+                        required=True,
                         help="Path to the dataset.")
-    parser.add_argument("--tokenizer", type=str, required=True,
+    parser.add_argument("--tokenizer",
+                        type=str,
+                        required=True,
                        help="Name or path of the tokenizer.")
-    parser.add_argument("--best-of", type=int, default=1,
+    parser.add_argument("--best-of",
+                        type=int,
+                        default=1,
                         help="Generates `best_of` sequences per prompt and "
-                             "returns the best one.")
+                        "returns the best one.")
     parser.add_argument("--use-beam-search", action="store_true")
-    parser.add_argument("--num-prompts", type=int, default=1000,
+    parser.add_argument("--num-prompts",
+                        type=int,
+                        default=1000,
                         help="Number of prompts to process.")
-    parser.add_argument("--request-rate", type=float, default=float("inf"),
+    parser.add_argument("--request-rate",
+                        type=float,
+                        default=float("inf"),
                         help="Number of requests per second. If this is inf, "
-                             "then all the requests are sent at time 0. "
-                             "Otherwise, we use Poisson process to synthesize "
-                             "the request arrival times.")
+                        "then all the requests are sent at time 0. "
+                        "Otherwise, we use Poisson process to synthesize "
+                        "the request arrival times.")
     parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument('--trust-remote-code', action='store_true',
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
                         help='trust remote code from huggingface')
     args = parser.parse_args()
     main(args)
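
The --request-rate help text above describes Poisson arrivals; a minimal sketch of that pacing logic, assuming numpy and asyncio (names illustrative, not the script's exact get_request implementation):

    # poisson_pacing.py - sketch of rate-limited request dispatch
    import asyncio
    from typing import AsyncGenerator, Iterable, Tuple

    import numpy as np

    async def paced_requests(
        requests: Iterable[Tuple[str, int, int]],
        request_rate: float,
    ) -> AsyncGenerator[Tuple[str, int, int], None]:
        """Yield requests with exponential inter-arrival gaps (a Poisson process)."""
        for request in requests:
            yield request
            if request_rate == float("inf"):
                continue  # inf rate: release every request at time 0
            # Exponential gaps with mean 1/rate produce Poisson arrivals.
            await asyncio.sleep(np.random.exponential(1.0 / request_rate))

Each yielded request is wrapped in asyncio.create_task(send_request(...)), so requests stay in flight concurrently while their start times follow the chosen rate.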
