
Commit f6b84d4

committed
Adapted to pytest framework
1 parent bfa91da commit f6b84d4

File tree

3 files changed: +44 -33 lines changed

test/common/llmperf/run_inference.py

Lines changed: 32 additions & 24 deletions

@@ -23,7 +23,7 @@ def run_test_cases(
     timestamp_dir,
     server_url,
     tokenizer_path,
-    hit_rate
+    hit_rate,
 ):
     print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed")
     all_summaries = []
@@ -35,20 +35,23 @@ def run_test_cases(
     env.pop("https_proxy", None)
 
     for i, (
-            mean_input,
-            mean_output,
-            max_completed,
-            concurrent,
-            additional_sampling_params,
-            hit_rate_val
-    ) in enumerate(zip(
-        mean_input_tokens,
-        mean_output_tokens,
-        max_num_completed_requests,
-        concurrent_requests,
+        mean_input,
+        mean_output,
+        max_completed,
+        concurrent,
         additional_sampling_params,
-        hit_rate
-    ), start=1):
+        hit_rate_val,
+    ) in enumerate(
+        zip(
+            mean_input_tokens,
+            mean_output_tokens,
+            max_num_completed_requests,
+            concurrent_requests,
+            additional_sampling_params,
+            hit_rate,
+        ),
+        start=1,
+    ):
         # for i, case in enumerate(mean_input_tokens):
         print(f"\n>>> Executing test case {i} <<<")
         reset_prefill_cache(env, server_url)
@@ -130,12 +133,13 @@ def run_test_cases(
 
 
 def inference_results(
-        mean_input_tokens,
-        mean_output_tokens,
-        max_num_completed_requests,
-        concurrent_requests,
-        additional_sampling_params,
-        hit_rate):
+    mean_input_tokens,
+    mean_output_tokens,
+    max_num_completed_requests,
+    concurrent_requests,
+    additional_sampling_params,
+    hit_rate,
+):
     config_file = Path(__file__).parent.parent.parent / "config.yaml"
     print("[INFO] Initialization complete, starting main process")
     print(f"[INFO] Reading configuration file: {config_file}")
@@ -144,8 +148,12 @@ def inference_results(
     llm_api = config.get("llm_connection", {}).get("llm_api", "openai")
     model = config.get("llm_connection", {}).get("model", "")
     test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000)
-    stddev_input_tokens = config.get("llm_connection", {}).get("stddev_input_tokens", 0)
-    stddev_output_tokens = config.get("llm_connection", {}).get("stddev_output_tokens", 0)
+    stddev_input_tokens = config.get("llm_connection", {}).get(
+        "stddev_input_tokens", 0
+    )
+    stddev_output_tokens = config.get("llm_connection", {}).get(
+        "stddev_output_tokens", 0
+    )
     timestamp_dir = Path("results")
     timestamp_dir.mkdir(parents=True, exist_ok=True)
     server_url = config.get("llm_connection", {}).get("server_url", "")
@@ -166,12 +174,12 @@ def inference_results(
         timestamp_dir,
         server_url,
         tokenizer_path,
-        hit_rate
+        hit_rate,
     )
     total = len(mean_input_tokens)
     print(
         f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}"
     )
     if failed_cases:
         print(f"[WARN] Failed case indices: {failed_cases}")
-    return all_summaries
+    return all_summaries
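
[Note] The reformatting above keeps the existing behaviour: run_test_cases still walks the parallel parameter lists with enumerate(zip(...), start=1), now with hit_rate zipped in and trailing commas added. A minimal standalone sketch of that iteration pattern follows; the concurrent_requests and additional_sampling_params values are hypothetical (those lists are defined outside this diff), and the loop body is only illustrative.

# Sketch (not part of the commit): the enumerate/zip pattern used by run_test_cases.
# The first four lists mirror the pytest parameters in this commit;
# concurrent_requests and additional_sampling_params are hypothetical values.
mean_input_tokens = [2000, 3000]
mean_output_tokens = [200, 500]
max_num_completed_requests = [8, 4]
hit_rate = [0, 50]
concurrent_requests = [1, 1]            # hypothetical
additional_sampling_params = ["{}", "{}"]  # hypothetical

for i, (
    mean_input,
    mean_output,
    max_completed,
    concurrent,
    sampling_params,
    hit_rate_val,
) in enumerate(
    zip(
        mean_input_tokens,
        mean_output_tokens,
        max_num_completed_requests,
        concurrent_requests,
        additional_sampling_params,
        hit_rate,
    ),
    start=1,
):
    # Each position across the parallel lists is one test case; i is 1-based.
    print(f">>> test case {i}: in={mean_input}, out={mean_output}, "
          f"max={max_completed}, conc={concurrent}, hit_rate={hit_rate_val}")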

test/common/llmperf/utils/openai_chat_completions_client.py

Lines changed: 2 additions & 1 deletion

@@ -11,12 +11,13 @@
 from common.llmperf.utils.models import RequestConfig
 
 config_file = Path(__file__).parent.parent.parent.parent / "config.yaml"
-with open(config_file, 'r', encoding='utf-8') as f:
+with open(config_file, "r", encoding="utf-8") as f:
     config = yaml.safe_load(f)
 stream = config.get("llm_connection", {}).get("stream", True)
 ignore_eos = config.get("llm_connection", {}).get("ignore_eos", True)
 timeout = config.get("llm_connection", {}).get("timeout", 180)
 
+
 class OpenAIChatCompletionsClient:
     """
     used for sending HTTP requests, receiving token streams, measuring latency, etc.
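
[Note] The quote-style change above is part of the same module-level config load that run_inference.py uses: every setting comes from config.yaml via config.get("llm_connection", {}).get(key, default). A hypothetical llm_connection section covering the keys referenced in this commit is sketched below; the key names come from the diff, while the values are only the coded defaults or placeholders, not the repository's actual configuration.

# Sketch (not part of the commit): hypothetical config.yaml shape read by both files.
import yaml

example_config = """
llm_connection:
  llm_api: openai
  model: ""        # placeholder
  server_url: ""   # placeholder
  stream: true
  ignore_eos: true
  timeout: 180
  test_timeout_s: 60000
  stddev_input_tokens: 0
  stddev_output_tokens: 0
"""

config = yaml.safe_load(example_config)
llm = config.get("llm_connection", {})
# Missing keys fall back to the same defaults used in the code above.
assert llm.get("timeout", 180) == 180
assert llm.get("stddev_input_tokens", 0) == 0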

test/suites/E2E/test_uc_performance.py

Lines changed: 10 additions & 8 deletions

@@ -2,6 +2,7 @@
 from common.capture_utils import export_vars
 from common.llmperf.run_inference import inference_results
 
+
 @pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]])
 @pytest.mark.parametrize("mean_output_tokens", [[200, 500]])
 @pytest.mark.parametrize("max_num_completed_requests", [[8, 4]])
@@ -10,21 +11,22 @@
 @pytest.mark.parametrize("hit_rate", [[0, 50]])
 @pytest.mark.feature("uc_performance_test")
 @export_vars
-
 def test_performance(
-        mean_input_tokens,
-        mean_output_tokens,
-        max_num_completed_requests,
-        concurrent_requests,
-        additional_sampling_params,
-        hit_rate):
+    mean_input_tokens,
+    mean_output_tokens,
+    max_num_completed_requests,
+    concurrent_requests,
+    additional_sampling_params,
+    hit_rate,
+):
     all_summaries = inference_results(
         mean_input_tokens,
         mean_output_tokens,
         max_num_completed_requests,
         concurrent_requests,
         additional_sampling_params,
-        hit_rate)
+        hit_rate,
+    )
     failed_cases = []
 
     value_lists = {
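
[Note] Each @pytest.mark.parametrize above supplies a single value (a whole list), so pytest generates exactly one test_performance call whose arguments are the full lists, which run_test_cases later zips case by case. A minimal self-contained illustration of that pattern follows; the test name and assertion are invented for the example, and @pytest.mark.feature / @export_vars are omitted because they are project-specific decorators.

# Sketch (not part of the commit): single-value parametrize passes whole lists.
import pytest


@pytest.mark.parametrize("mean_input_tokens", [[2000, 3000]])
@pytest.mark.parametrize("mean_output_tokens", [[200, 500]])
@pytest.mark.parametrize("hit_rate", [[0, 50]])
def test_parametrize_shape(mean_input_tokens, mean_output_tokens, hit_rate):
    # All lists must align, since each index describes one inference test case.
    assert len(mean_input_tokens) == len(mean_output_tokens) == len(hit_rate)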
