@@ -23,7 +23,7 @@ def run_test_cases(
     timestamp_dir,
     server_url,
     tokenizer_path,
-    hit_rate
+    hit_rate,
 ):
     print(f"[INFO] Total {len(mean_input_tokens)} test cases to be executed")
     all_summaries = []
@@ -35,20 +35,23 @@ def run_test_cases(
     env.pop("https_proxy", None)
 
     for i, (
-            mean_input,
-            mean_output,
-            max_completed,
-            concurrent,
-            additional_sampling_params,
-            hit_rate_val
-    ) in enumerate(zip(
-        mean_input_tokens,
-        mean_output_tokens,
-        max_num_completed_requests,
-        concurrent_requests,
+        mean_input,
+        mean_output,
+        max_completed,
+        concurrent,
         additional_sampling_params,
-        hit_rate
-    ), start=1):
+        hit_rate_val,
+    ) in enumerate(
+        zip(
+            mean_input_tokens,
+            mean_output_tokens,
+            max_num_completed_requests,
+            concurrent_requests,
+            additional_sampling_params,
+            hit_rate,
+        ),
+        start=1,
+    ):
         # for i, case in enumerate(mean_input_tokens):
         print(f"\n>>> Executing test case {i} <<<")
         reset_prefill_cache(env, server_url)
@@ -130,12 +133,13 @@ def run_test_cases(
 
 
 def inference_results(
-        mean_input_tokens,
-        mean_output_tokens,
-        max_num_completed_requests,
-        concurrent_requests,
-        additional_sampling_params,
-        hit_rate):
+    mean_input_tokens,
+    mean_output_tokens,
+    max_num_completed_requests,
+    concurrent_requests,
+    additional_sampling_params,
+    hit_rate,
+):
     config_file = Path(__file__).parent.parent.parent / "config.yaml"
     print("[INFO] Initialization complete, starting main process")
     print(f"[INFO] Reading configuration file: {config_file}")
@@ -144,8 +148,12 @@ def inference_results(
     llm_api = config.get("llm_connection", {}).get("llm_api", "openai")
     model = config.get("llm_connection", {}).get("model", "")
     test_timeout_s = config.get("llm_connection", {}).get("test_timeout_s", 60000)
-    stddev_input_tokens = config.get("llm_connection", {}).get("stddev_input_tokens", 0)
-    stddev_output_tokens = config.get("llm_connection", {}).get("stddev_output_tokens", 0)
+    stddev_input_tokens = config.get("llm_connection", {}).get(
+        "stddev_input_tokens", 0
+    )
+    stddev_output_tokens = config.get("llm_connection", {}).get(
+        "stddev_output_tokens", 0
+    )
     timestamp_dir = Path("results")
     timestamp_dir.mkdir(parents=True, exist_ok=True)
     server_url = config.get("llm_connection", {}).get("server_url", "")
@@ -166,12 +174,12 @@ def inference_results(
         timestamp_dir,
         server_url,
         tokenizer_path,
-        hit_rate
+        hit_rate,
     )
     total = len(mean_input_tokens)
     print(
         f"\n[INFO] All tests completed! Success: {total - len(failed_cases)}/{total}"
     )
     if failed_cases:
         print(f"[WARN] Failed case indices: {failed_cases}")
-    return all_summaries
+    return all_summaries