
Commit 9459f0a

tlrmchlsmth authored and robertgshaw2-redhat committed
[P/D Disagg][Benchmarking] One request at a time benchmarking for P/D (#79)
* Benchmark one concurrent req
  Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
* Updates
  Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
* restore
  Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
* Improve random requests, switch up initial test
  Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

---------

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 964472b commit 9459f0a

3 files changed (+449, -1 lines changed)
Lines changed: 382 additions & 0 deletions
@@ -0,0 +1,382 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
import numpy as np
from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest
from tqdm import tqdm

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

logger = logging.getLogger(__name__)


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]


async def reset_cache(reset_url: str):
    """Send a POST request to reset the prefix cache."""
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (aiohttp.ClientSession() as session,
                    session.post(reset_url) as response):
            # Raise an exception for bad status codes (4xx or 5xx)
            response.raise_for_status()
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        logger.error("Failed to connect to cache reset endpoint %s: %s",
                     reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error("Cache reset request failed with status %s: %s",
                     e.status, e.message)
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)


async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to
    complete before starting the next one. Resets the prefix cache between
    requests.
    """
    outputs = []

    pbar = tqdm(total=len(input_requests))

    # Small request to force a forward pass.
    # Used for resetting the prefix cache.
    dummy_req_input = RequestFuncInput(
        model=model_id,
        prompt="0",
        api_url=api_url,
        prompt_len=1,
        output_len=1,
    )

    print("Starting initial single prompt test run...")
    test_output = await request_func(request_func_input=dummy_req_input)
    if not test_output.success:
        raise ValueError(
            "Initial test run failed - Please check your configuration. "
            f"Error: {test_output.error}")
    else:
        print("Initial test run completed. Starting sequential benchmark...")

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for request in input_requests:
        prompt, prompt_len, output_len = (request.prompt, request.prompt_len,
                                          request.expected_output_len)

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug("Request str: \"%s\"", request.prompt[:50])
        request_start_time = time.perf_counter()

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)

        request_end_time = time.perf_counter()
        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        logger.info("Finished request with latency %.4f s", output.latency)

        outputs.append(output)
        pbar.update(1)

        # Reset the prefix cache between requests if configured
        if cache_reset_url:
            await request_func(request_func_input=dummy_req_input)
            await reset_cache(cache_reset_url)

    pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens":
        [output.output_tokens if output.success else 0 for output in outputs],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts":
        [output.generated_text for output in outputs if output.success],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms")

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Calculate benchmark metrics from results."""
    total_input = 0
    completed = 0
    total_output = 0
    ttfts = []
    itls = []
    e2els = []

    for i, output in enumerate(outputs):
        if output.success:
            output_len = output.output_tokens

            if not output_len:
                # Use tokenizer to count output tokens if not provided
                output_len = len(
                    tokenizer(output.generated_text,
                              add_special_tokens=False).input_ids)

            total_output += output_len
            total_input += input_requests[i].prompt_len

            if hasattr(output, "ttft") and output.ttft is not None:
                ttfts.append(output.ttft)

            if hasattr(output, "itl") and output.itl:
                # Ensure itl is a list of floats
                if isinstance(output.itl, list):
                    itls.extend(output.itl)
                else:
                    logger.warning(
                        "Expected list for ITL but got %s. Appending as is.",
                        type(output.itl))
                    itls.append(output.itl)

            if hasattr(output, "latency") and output.latency is not None:
                e2els.append(output.latency)

            completed += 1

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
        median_ttft_ms=np.median(ttfts or [0]) * 1000,
        std_ttft_ms=np.std(ttfts or [0]) * 1000,
        percentiles_ttft_ms=[(p, np.percentile(ttfts or [0], p) * 1000)
                             for p in selected_percentiles],
        mean_itl_ms=np.mean(itls or [0]) * 1000,
        median_itl_ms=np.median(itls or [0]) * 1000,
        std_itl_ms=np.std(itls or [0]) * 1000,
        percentiles_itl_ms=[(p, np.percentile(itls or [0], p) * 1000)
                            for p in selected_percentiles],
        mean_e2el_ms=np.mean(e2els or [0]) * 1000,
        median_e2el_ms=np.median(e2els or [0]) * 1000,
        std_e2el_ms=np.std(e2els or [0]) * 1000,
        percentiles_e2el_ms=[(p, np.percentile(e2els or [0], p) * 1000)
                             for p in selected_percentiles],
    )


def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
                                    benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:",
                                 metrics.total_output))

    def print_metric_stats(metric_name, header):
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        print("{:<40} {:<10.2f}".format(
            f"Mean {metric_name} (ms):",
            getattr(metrics, f"mean_{metric_name.lower()}_ms")))
        print("{:<40} {:<10.2f}".format(
            f"Median {metric_name} (ms):",
            getattr(metrics, f"median_{metric_name.lower()}_ms")))

        for p, value in getattr(metrics,
                                f"percentiles_{metric_name.lower()}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):",
                                            value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)


async def main_async(args):
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id,
                              trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result


def main(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    asyncio.run(main_async(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Sequential benchmark for LLM serving")
    parser.add_argument("--backend",
                        type=str,
                        default="vllm",
                        help="Backend to use for requests")
    parser.add_argument("--base-url",
                        type=str,
                        default=None,
                        help="Server base URL (overrides --host and --port)")
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--endpoint",
                        type=str,
                        default="/v1/completions",
                        help="API endpoint")
    parser.add_argument("--model",
                        type=str,
                        required=True,
                        help="Name of the model")
    parser.add_argument("--tokenizer",
                        type=str,
                        help="Name of the tokenizer (defaults to model name)")
    parser.add_argument("--num-requests",
                        type=int,
                        default=100,
                        help="Number of requests to process")
    parser.add_argument("--input-len",
                        type=int,
                        default=128,
                        help="Input len for generated prompts")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Override output len for requests")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--trust-remote-code",
                        action="store_true",
                        help="Trust remote code from HuggingFace")

    args = parser.parse_args()
    main(args)
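
For context, here is a minimal usage sketch of the sequential benchmark script added above. The file name benchmark_one_concurrent.py is a placeholder (the script's path in the commit is not shown here), the model name is borrowed from the accuracy test below, and the example assumes a vLLM-compatible server is already serving that model on 127.0.0.1:8000 and that backend_request_func.py and benchmark_dataset.py (imported at the top of the script) are importable from the working directory. Only --model is required; the other flags override the script's defaults:

# Hypothetical invocation; script name and model are placeholders.
python benchmark_one_concurrent.py \
    --model Qwen/Qwen3-0.6B \
    --host 127.0.0.1 --port 8000 \
    --endpoint /v1/completions \
    --num-requests 50 \
    --input-len 512 \
    --output-len 128

# The script resets the prefix cache between requests by POSTing to
# /reset_prefix_cache; the same endpoint can be exercised manually,
# assuming the server exposes it:
curl -X POST http://127.0.0.1:8000/reset_prefix_cache

Because requests are issued strictly one at a time, the reported TTFT, ITL, and E2EL reflect a single in-flight request across the prefill/decode pair rather than server behavior under concurrent load.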

tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ set -xe
 # Models to run
 MODELS=(
     "Qwen/Qwen3-0.6B"
+    "deepseek-ai/deepseek-vl2-tiny"
 )

 # Number of prefill and decode instances to create
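
A hedged sketch of how this change would be exercised, assuming the accuracy test script is configured entirely through the variables shown above (the MODELS array and the prefill/decode instance counts) and takes no required command-line arguments:

# Run the NIXL P/D accuracy test, which now also covers
# deepseek-ai/deepseek-vl2-tiny in addition to Qwen/Qwen3-0.6B.
bash tests/v1/kv_connector/nixl_integration/run_accuracy_test.sh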
