
Commit 6fbed91

Performance testing tool based on the PyTest testing framework.
1 parent de63b7c commit 6fbed91

File tree

13 files changed: +1172 −67 lines changed


test/common/capture_utils.py

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,4 @@
+import functools
 from typing import Any, Dict, List
 
 from common.db_utils import write_to_db
@@ -44,6 +45,7 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
 
 # ---------------- decorator ----------------
 def export_vars(func):
+    @functools.wraps(func)
     def wrapper(*args, **kwargs):
         result = func(*args, **kwargs)
         # If the function returns a dict containing '_data' or 'data', post-process it
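
The only functional change here is decorating wrapper with functools.wraps(func). A minimal sketch of why that matters, using a simplified stand-in for export_vars (the real decorator also post-processes the returned data):

# Simplified stand-in, not the real export_vars: shows only the effect of functools.wraps.
import functools

def export_vars(func):
    @functools.wraps(func)          # copies __name__, __doc__, __module__, __wrapped__ onto wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@export_vars
def test_throughput():
    """Example test body."""
    return {"data": []}

print(test_throughput.__name__)     # "test_throughput" (would be "wrapper" without functools.wraps)
print(test_throughput.__doc__)      # "Example test body."

Preserving the wrapped function's identity keeps introspection and test reporting readable when the decorator is applied to test functions.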

test/common/llmperf/__init__.py

Whitespace-only changes.
Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
import json
import os
import random
from pathlib import Path
from typing import Any, Dict, List

import yaml
from common.llmperf.utils.token_benchmark import run_token_benchmark
from common.llmperf.utils.utils import reset_prefill_cache


def run_test_cases(
    llm_api,
    model,
    timeout,
    max_num_completed_requests,
    concurrent_requests,
    mean_input_tokens,
    stddev_input,
    mean_output_tokens,
    stddev_output,
    additional_sampling_params,
    timestamp_dir,
    server_url,
    tokenizer_path,
    hit_rate,
):
    print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed")
    all_summaries = []
    failed_case = []

    # Clear proxy environment variables
    env = os.environ.copy()
    env.pop("http_proxy", None)
    env.pop("https_proxy", None)

    for i, (
        mean_input,
        mean_output,
        max_completed,
        concurrent,
        additional_sampling_params,
        hit_rate_val,
    ) in enumerate(
        zip(
            mean_input_tokens,
            mean_output_tokens,
            max_num_completed_requests,
            concurrent_requests,
            additional_sampling_params,
            hit_rate,
        ),
        start=1,
    ):
        # for i, case in enumerate(mean_input_tokens):
        print(f"\n>>> Executing test case {i} <<<")
        reset_prefill_cache(env, server_url)
        # Use a fixed random_seed for each test to control PC hit_rate
        random_seed = random.randint(1, 100000)

        try:
            # Determine if two runs are needed (PC hit_rate test)
            if hit_rate_val == 0:
                summary = run_token_benchmark(
                    llm_api=llm_api,
                    model=model,
                    test_timeout_s=timeout,
                    max_num_completed_requests=max_completed,
                    concurrent_requests=concurrent,
                    mean_input_tokens=mean_input,
                    stddev_input_tokens=stddev_input,
                    mean_output_tokens=mean_output,
                    stddev_output_tokens=stddev_output,
                    additional_sampling_params=additional_sampling_params,
                    results_dir=str(timestamp_dir),
                    random_seed=random_seed,
                    openai_api_base=server_url + "/v1",
                    tokenizer_path=tokenizer_path,
                    user_metadata={"case_idx": i, "phase": "normal"},
                )
            else:
                print(
                    f"[INFO] hit_rate > 0 detected, entering prefill mode, PC hit rate: {hit_rate_val} %"
                )
                # hit_rate > 0: first run in prefill mode
                prefill_mean_input = int(mean_input * hit_rate_val / 100)
                print(
                    f"[INFO] Prefill execution: mean_input_tokens={prefill_mean_input}"
                )
                run_token_benchmark(
                    llm_api=llm_api,
                    model=model,
                    test_timeout_s=timeout,
                    max_num_completed_requests=max_completed,
                    concurrent_requests=concurrent,
                    mean_input_tokens=prefill_mean_input,
                    stddev_input_tokens=stddev_input,
                    mean_output_tokens=2,
                    stddev_output_tokens=stddev_output,
                    additional_sampling_params=additional_sampling_params,
                    results_dir=str(timestamp_dir),
                    random_seed=random_seed,
                    openai_api_base=server_url + "/v1",
                    tokenizer_path=tokenizer_path,
                    user_metadata={"case_idx": i, "phase": "prefill"},
                )
                reset_prefill_cache(env, server_url)
                # Then run normal mode
                print("[INFO] Prefill completed, switching to normal mode execution")
                summary = run_token_benchmark(
                    llm_api=llm_api,
                    model=model,
                    test_timeout_s=timeout,
                    max_num_completed_requests=max_completed,
                    concurrent_requests=concurrent,
                    mean_input_tokens=mean_input,
                    stddev_input_tokens=stddev_input,
                    mean_output_tokens=mean_output,
                    stddev_output_tokens=stddev_output,
                    additional_sampling_params=additional_sampling_params,
                    results_dir=str(timestamp_dir),
                    random_seed=random_seed,
                    openai_api_base=server_url + "/v1",
                    tokenizer_path=tokenizer_path,
                    user_metadata={"case_idx": i, "phase": "normal"},
                )
            all_summaries.append(summary)
        except Exception as e:
            print(f"[Warning] {e}")
            failed_case.append(i)

    return all_summaries, failed_case


def inference_results(
    mean_input_tokens,
    mean_output_tokens,
    max_num_completed_requests,
    concurrent_requests,
    additional_sampling_params,
    hit_rate,
):
    config_file = Path(__file__).parent.parent.parent / "config.yaml"
    print("[INFO] Initialization complete, starting main process")
    print(f"[INFO] Reading configuration file: {config_file}")
    with open(config_file, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    llm_api = config.get("llm_connection", {}).get("llm_api", "openai")
    model = config.get("llm_connection", {}).get("model", "")
    test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000)
    stddev_input_tokens = config.get("llm_connection", {}).get(
        "stddev_input_tokens", 0
    )
    stddev_output_tokens = config.get("llm_connection", {}).get(
        "stddev_output_tokens", 0
    )
    timestamp_dir = Path("results")
    timestamp_dir.mkdir(parents=True, exist_ok=True)
    server_url = config.get("llm_connection", {}).get("server_url", "")
    tokenizer_path = config.get("llm_connection", {}).get("tokenizer_path", "")
    print(f"[INFO] Created results directory: {timestamp_dir}")

    all_summaries, failed_cases = run_test_cases(
        llm_api,
        model,
        test_timeout_s,
        max_num_completed_requests,
        concurrent_requests,
        mean_input_tokens,
        stddev_input_tokens,
        mean_output_tokens,
        stddev_output_tokens,
        additional_sampling_params,
        timestamp_dir,
        server_url,
        tokenizer_path,
        hit_rate,
    )
    total = len(mean_input_tokens)
    print(
        f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}"
    )
    if failed_cases:
        print(f"[WARN] Failed case indices: {failed_cases}")
    return all_summaries
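
For orientation, a hedged sketch of how inference_results might be driven. The parameter names and the llm_connection config keys come from the code above; the concrete values, and the convention that each list holds one entry per test case, are illustrative assumptions, not part of the commit.

# Illustrative sketch only. Assumes inference_results is imported from the module
# added above (its filename is not shown in this diff) and that config.yaml provides
# an llm_connection section with model, server_url, tokenizer_path, etc.
summaries = inference_results(
    mean_input_tokens=[1024, 2048],              # one entry per test case
    mean_output_tokens=[128, 256],
    max_num_completed_requests=[32, 32],
    concurrent_requests=[8, 16],
    additional_sampling_params=[{}, {"temperature": 0.0}],
    hit_rate=[0, 50],                            # 50 -> prefill pass over ~half the prompt first, then the normal pass
)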

test/common/llmperf/utils/__init__.py

Whitespace-only changes.
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
# TODO (Avnishn): compute metrics in class
INTER_TOKEN_LAT = "inter_token_latency_s"
TTFT = "ttft_s"
E2E_LAT = "end_to_end_latency_s"
NUM_INPUT_TOKENS = "number_input_tokens"
NUM_OUTPUT_TOKENS = "number_output_tokens"
NUM_TOTAL_TOKENS = "number_total_tokens"
REQ_OUTPUT_THROUGHPUT = "request_output_throughput_token_per_s"
ERROR_MSG = "error_msg"
ERROR_CODE = "error_code"
ERROR_CODE_FREQ = "error_code_frequency"
NUM_ERRORS = "number_errors"
OUTPUT_THROUGHPUT = "mean_output_throughput_token_per_s"
NUM_COMPLETED_REQUESTS = "num_completed_requests"
COMPLETED_REQUESTS_PER_MIN = "num_completed_requests_per_min"
ERROR_RATE = "error_rate"
NUM_REQ_STARTED = "num_requests_started"
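
These constants are the dictionary keys shared between the request client and whatever aggregates its results. A hedged sketch of an aggregation helper built on them (the helper itself is illustrative, not part of the commit):

# Illustrative helper, not part of the commit: summarizes per-request metric dicts
# that use the key constants above.
from statistics import mean
from common.llmperf.utils import common_metrics

def summarize(per_request_metrics):
    ok = [m for m in per_request_metrics if m.get(common_metrics.ERROR_CODE) is None]
    return {
        "mean_ttft_s": mean(m[common_metrics.TTFT] for m in ok) if ok else None,
        "mean_e2e_latency_s": mean(m[common_metrics.E2E_LAT] for m in ok) if ok else None,
        common_metrics.NUM_ERRORS: len(per_request_metrics) - len(ok),
        common_metrics.ERROR_RATE: (len(per_request_metrics) - len(ok)) / max(len(per_request_metrics), 1),
    }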
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
from typing import Any, Dict, Optional, Tuple

from pydantic import BaseModel


class RequestConfig(BaseModel):
    """The configuration for a request to the LLM API.

    Args:
        model: The model to use.
        prompt: The prompt to provide to the LLM API.
        sampling_params: Additional sampling parameters to send with the request.
            For more information see the Router app's documentation for the completions.
        llm_api: The name of the LLM API to send the request to.
        metadata: Additional metadata to attach to the request for logging or validation purposes.
    """

    model: str
    prompt: Tuple[str, int]
    sampling_params: Optional[Dict[str, Any]] = None
    llm_api: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    openai_api_base: Optional[str] = ""
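
A hedged example of constructing a RequestConfig the way the client below consumes it; note that prompt is a (text, token_count) tuple, and all concrete values here are illustrative assumptions:

# Illustrative values only; prompt is a (text, estimated_token_count) tuple as typed above.
from common.llmperf.utils.models import RequestConfig

request_config = RequestConfig(
    model="my-model",                               # assumption: any name the target server accepts
    prompt=("Explain KV-cache reuse in one paragraph.", 9),
    sampling_params={"max_tokens": 128, "temperature": 0.0},
    llm_api="openai",
    metadata={"case_idx": 1, "phase": "normal"},
    openai_api_base="http://localhost:8000/v1",     # assumption: local OpenAI-compatible endpoint
)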
Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
import json
import os
import time
from asyncio import timeout
from pathlib import Path
from typing import Any, Dict, Tuple

import requests
import yaml
from common.llmperf.utils import common_metrics
from common.llmperf.utils.models import RequestConfig

config_file = Path(__file__).parent.parent.parent.parent / "config.yaml"
with open(config_file, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
stream = config.get("llm_connection", {}).get("stream", True)
ignore_eos = config.get("llm_connection", {}).get("ignore_eos", True)
timeout = config.get("llm_connection", {}).get("timeout", 180)


class OpenAIChatCompletionsClient:
    """
    used for sending HTTP requests, receiving token streams, measuring latency, etc.
    """

    def llm_request(
        self, request_config: RequestConfig
    ) -> Tuple[Dict[str, Any], str, RequestConfig]:
        prompt, prompt_len = request_config.prompt

        message = [
            {"role": "user", "content": prompt},
        ]
        model = request_config.model
        body = {
            "model": model,
            "messages": message,
            "stream": stream,
            "ignore_eos": ignore_eos,
        }
        sampling_params = request_config.sampling_params
        body.update(sampling_params or {})

        time_to_next_token = []
        tokens_received = 0
        ttft = 0.0
        error_response_code = None
        generated_text = ""
        error_msg = ""
        output_throughput = 0.0
        total_request_time = 0.0
        flag = False

        metrics: Dict[str, Any] = {}

        metrics[common_metrics.ERROR_CODE] = None
        metrics[common_metrics.ERROR_MSG] = ""

        start_time = time.monotonic()
        most_recent_received_token_time = start_time

        address = request_config.openai_api_base

        if not address:
            raise ValueError("the environment variable OPENAI_API_BASE must be set.")
        key = os.environ.get("OPENAI_API_KEY", "secret_abcdefg")
        if not key:
            raise ValueError("the environment variable OPENAI_API_KEY must be set.")
        headers = {"Authorization": f"Bearer {key}"}
        if not address.endswith("/"):
            address = address + "/"
        address += "chat/completions"
        try:
            with requests.post(
                address,
                json=body,
                stream=stream,
                timeout=timeout,
                headers=headers,
            ) as response:
                if response.status_code != 200:
                    error_msg = response.text
                    error_response_code = response.status_code
                    response.raise_for_status()

                for chunk in response.iter_lines(chunk_size=None):
                    if not chunk:
                        continue
                    stem = b"data: "
                    if chunk.startswith(stem):
                        chunk = chunk[len(stem) :]
                    # Data might already be bytes or str
                    if isinstance(chunk, bytes):
                        chunk = chunk.decode("utf-8", errors="ignore")
                    if chunk.strip() == "[DONE]":
                        continue
                    tokens_received += 1
                    data = json.loads(chunk)
                    if "error" in data:
                        error_msg = data["error"]["message"]
                        error_response_code = data["error"]["code"]
                        raise RuntimeError(error_msg)
                    delta = data["choices"][0]["delta"]
                    content = delta.get("content", None) or delta.get(
                        "reasoning_content", ""
                    )
                    if content:
                        if tokens_received != 0 and flag == False:
                            ttft = time.monotonic() - start_time
                            flag = True
                        else:
                            time_to_next_token.append(
                                time.monotonic() - most_recent_received_token_time
                            )
                        most_recent_received_token_time = time.monotonic()
                        generated_text += content

            total_request_time = time.monotonic() - start_time
            if total_request_time > 0:
                output_throughput = tokens_received / total_request_time

        except Exception as e:
            metrics[common_metrics.ERROR_MSG] = error_msg
            metrics[common_metrics.ERROR_CODE] = error_response_code
            print(f"Warning Or Error: {e}")
            print(error_response_code)

        metrics[common_metrics.INTER_TOKEN_LAT] = sum(time_to_next_token)
        metrics[common_metrics.TTFT] = ttft
        metrics[common_metrics.E2E_LAT] = total_request_time
        metrics[common_metrics.REQ_OUTPUT_THROUGHPUT] = output_throughput
        metrics[common_metrics.NUM_TOTAL_TOKENS] = tokens_received + prompt_len
        metrics[common_metrics.NUM_OUTPUT_TOKENS] = tokens_received
        metrics[common_metrics.NUM_INPUT_TOKENS] = prompt_len

        return metrics, generated_text, request_config
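
Finally, a hedged usage sketch that ties the pieces together: one request through OpenAIChatCompletionsClient and a look at the metrics it returns. The endpoint URL and model name are assumptions, and the client's module path is not shown in this diff:

# Illustrative only; assumes an OpenAI-compatible server at openai_api_base and that
# OpenAIChatCompletionsClient / RequestConfig are importable from the modules added above.
from common.llmperf.utils import common_metrics
from common.llmperf.utils.models import RequestConfig

client = OpenAIChatCompletionsClient()
metrics, generated_text, _ = client.llm_request(
    RequestConfig(
        model="my-model",                           # assumption
        prompt=("Say hello in one sentence.", 6),   # (text, estimated token count)
        sampling_params={"max_tokens": 32},
        openai_api_base="http://localhost:8000/v1", # assumption
    )
)
print(metrics[common_metrics.TTFT], metrics[common_metrics.NUM_OUTPUT_TOKENS])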
