@@ -199,6 +199,7 @@ def split_json_by_tp_pp(
199199
200200 return saved_paths
201201
202+
202203def _add_limit_line (fig , y_value , label ):
203204 # Visible dashed line + annotation
204205 fig .add_hline (
@@ -211,18 +212,28 @@ def _add_limit_line(fig, y_value, label):
211212 # Optional: add a legend item (as a transparent helper trace)
212213 if plot and plotly_found :
213214 import plotly .graph_objects as go
214- fig .add_trace (go .Scatter (
215- x = [None ], y = [None ],
216- mode = "lines" ,
217- line = dict (dash = "dash" ,
218- color = "red" if "ttft" in label .lower () else "blue" ),
219- name = f"{ label } "
220- ))
215+
216+ fig .add_trace (
217+ go .Scatter (
218+ x = [None ],
219+ y = [None ],
220+ mode = "lines" ,
221+ line = dict (
222+ dash = "dash" , color = "red" if "ttft" in label .lower () else "blue"
223+ ),
224+ name = f"{ label } " ,
225+ )
226+ )
221227
222228
223229def _find_concurrency_col (df : pd .DataFrame ) -> str :
224- for c in ["# of max concurrency." , "# of max concurrency" , "Max Concurrency" ,
225- "max_concurrency" , "Concurrency" ]:
230+ for c in [
231+ "# of max concurrency." ,
232+ "# of max concurrency" ,
233+ "Max Concurrency" ,
234+ "max_concurrency" ,
235+ "Concurrency" ,
236+ ]:
226237 if c in df .columns :
227238 return c
228239 # Fallback: guess an integer-like column (harmless if unused)
@@ -231,15 +242,26 @@ def _find_concurrency_col(df: pd.DataFrame) -> str:
231242 return c
232243 return "# of max concurrency."
233244
234- def _highlight_threshold (df : pd .DataFrame , threshold : float ) -> "pd.io.formats.style.Styler" :
245+
246+ def _highlight_threshold (
247+ df : pd .DataFrame , threshold : float
248+ ) -> "pd.io.formats.style.Styler" :
235249 """Highlight numeric per-configuration columns with value <= threshold."""
236250 conc_col = _find_concurrency_col (df )
237- key_cols = [c for c in ["Model" , "Dataset Name" , "Input Len" , "Output Len" , conc_col ] if c in df .columns ]
238- conf_cols = [c for c in df .columns if c not in key_cols and not str (c ).startswith ("Ratio" )]
251+ key_cols = [
252+ c
253+ for c in ["Model" , "Dataset Name" , "Input Len" , "Output Len" , conc_col ]
254+ if c in df .columns
255+ ]
256+ conf_cols = [
257+ c for c in df .columns if c not in key_cols and not str (c ).startswith ("Ratio" )
258+ ]
239259 conf_cols = [c for c in conf_cols if pd .api .types .is_numeric_dtype (df [c ])]
240260 return df .style .map (
241- lambda v : "background-color:#e6ffe6;font-weight:bold;" if pd .notna (v ) and v <= threshold else "" ,
242- subset = conf_cols
261+ lambda v : "background-color:#e6ffe6;font-weight:bold;"
262+ if pd .notna (v ) and v <= threshold
263+ else "" ,
264+ subset = conf_cols ,
243265 )
244266
245267
@@ -271,11 +293,18 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
271293 default = "p99" ,
272294 help = "take median|p99 for latency like TTFT/TPOT" ,
273295 )
274- parser .add_argument ("--ttft-max-ms" , type = float , default = 3000.0 ,
275- help = "Reference limit for TTFT plots (ms)" )
276- parser .add_argument ("--tpot-max-ms" , type = float , default = 100.0 ,
277- help = "Reference limit for TPOT plots (ms)" )
278-
296+ parser .add_argument (
297+ "--ttft-max-ms" ,
298+ type = float ,
299+ default = 3000.0 ,
300+ help = "Reference limit for TTFT plots (ms)" ,
301+ )
302+ parser .add_argument (
303+ "--tpot-max-ms" ,
304+ type = float ,
305+ default = 100.0 ,
306+ help = "Reference limit for TPOT plots (ms)" ,
307+ )
279308
280309 args = parser .parse_args ()
281310
@@ -342,29 +371,39 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
342371 f"Expected subset: { filtered_info_cols } , "
343372 f"but DataFrame has: { list (output_df .columns )} "
344373 )
345- #output_df_sorted = output_df.sort_values(by=existing_group_cols)
374+ # output_df_sorted = output_df.sort_values(by=existing_group_cols)
346375 output_df_sorted = output_df .sort_values (by = args .xaxis )
347376 output_groups = output_df_sorted .groupby (existing_group_cols , dropna = False )
348377 for name , group in output_groups :
349- group_name = "," .join (map (str , name )).replace ("," , "_" ).replace ("/" ,"-" )
378+ group_name = (
379+ "," .join (map (str , name )).replace ("," , "_" ).replace ("/" , "-" )
380+ )
350381 group_html_name = "perf_comparison_" + group_name + ".html"
351382
352383 metric_name = str (data_cols_to_compare [i ]).lower ()
353384 if "tok/s" in metric_name :
354385 html = group .to_html ()
355386 elif "ttft" in metric_name :
356- styler = (
357- _highlight_threshold (group , args .ttft_max_ms )
358- .format ({c : "{:.2f}" for c in group .select_dtypes ("number" ).columns }, na_rep = "—" )
387+ styler = _highlight_threshold (group , args .ttft_max_ms ).format (
388+ {c : "{:.2f}" for c in group .select_dtypes ("number" ).columns },
389+ na_rep = "—" ,
390+ )
391+ html = styler .to_html (
392+ table_attributes = 'border="1" class="dataframe"'
359393 )
360- html = styler .to_html (table_attributes = 'border="1" class="dataframe"' )
361- elif "tpot" in metric_name or "median" in metric_name or "p99" in metric_name :
362- styler = (
363- _highlight_threshold (group , args .tpot_max_ms )
364- .format ({c : "{:.2f}" for c in group .select_dtypes ("number" ).columns }, na_rep = "—" )
394+ elif (
395+ "tpot" in metric_name
396+ or "median" in metric_name
397+ or "p99" in metric_name
398+ ):
399+ styler = _highlight_threshold (group , args .tpot_max_ms ).format (
400+ {c : "{:.2f}" for c in group .select_dtypes ("number" ).columns },
401+ na_rep = "—" ,
365402 )
366- html = styler .to_html (table_attributes = 'border="1" class="dataframe"' )
367-
403+ html = styler .to_html (
404+ table_attributes = 'border="1" class="dataframe"'
405+ )
406+
368407 text_file .write (html_msgs_for_data_cols [i ])
369408 text_file .write (html )
370409 with open (group_html_name , "a" ) as sub_text_file :
@@ -382,7 +421,9 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
382421 var_name = "Configuration" ,
383422 value_name = data_cols_to_compare [i ],
384423 )
385- title = data_cols_to_compare [i ] + " vs " + info_cols [y_axis_index ]
424+ title = (
425+ data_cols_to_compare [i ] + " vs " + info_cols [y_axis_index ]
426+ )
386427 # Create Plotly line chart
387428 fig = px .line (
388429 df_melted ,
@@ -396,9 +437,17 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
396437 # ---- Add threshold lines based on metric name ----
397438 if "ttft" in metric_name :
398439 _add_limit_line (fig , args .ttft_max_ms , "TTFT limit" )
399- elif "tpot" in metric_name or "median" in metric_name or "p99" in metric_name :
440+ elif (
441+ "tpot" in metric_name
442+ or "median" in metric_name
443+ or "p99" in metric_name
444+ ):
400445 _add_limit_line (fig , args .tpot_max_ms , "TPOT limit" )
401446
402447 # Export to HTML
403- text_file .write (fig .to_html (full_html = True , include_plotlyjs = "cdn" ))
404- sub_text_file .write (fig .to_html (full_html = True , include_plotlyjs = "cdn" ))
448+ text_file .write (
449+ fig .to_html (full_html = True , include_plotlyjs = "cdn" )
450+ )
451+ sub_text_file .write (
452+ fig .to_html (full_html = True , include_plotlyjs = "cdn" )
453+ )
0 commit comments