77
88import pandas as pd
99
10+ pd .options .display .float_format = "{:.2f}" .format
1011plotly_found = util .find_spec ("plotly.express" ) is not None
1112
1213
@@ -109,7 +110,10 @@ def compare_data_columns(
109110 if len (compare_frames ) >= 2 :
110111 base = compare_frames [0 ]
111112 current = compare_frames [- 1 ]
112- ratio = current / base
113+ if "P99" in data_column or "Median" in data_column :
114+ ratio = base / current # for latency
115+ else :
116+ ratio = current / base
113117 ratio = ratio .mask (base == 0 ) # avoid inf when baseline is 0
114118 ratio .name = f"Ratio 1 vs { len (compare_frames )} "
115119 frames .append (ratio )
@@ -199,6 +203,71 @@ def split_json_by_tp_pp(
199203 return saved_paths
200204
201205
206+ def _add_limit_line (fig , y_value , label ):
207+ # Visible dashed line + annotation
208+ fig .add_hline (
209+ y = y_value ,
210+ line_dash = "dash" ,
211+ line_color = "red" if "ttft" in label .lower () else "blue" ,
212+ annotation_text = f"{ label } : { y_value } ms" ,
213+ annotation_position = "top left" ,
214+ )
215+ # Optional: add a legend item (as a transparent helper trace)
216+ if plot and plotly_found :
217+ import plotly .graph_objects as go
218+
219+ fig .add_trace (
220+ go .Scatter (
221+ x = [None ],
222+ y = [None ],
223+ mode = "lines" ,
224+ line = dict (
225+ dash = "dash" , color = "red" if "ttft" in label .lower () else "blue"
226+ ),
227+ name = f"{ label } " ,
228+ )
229+ )
230+
231+
232+ def _find_concurrency_col (df : pd .DataFrame ) -> str :
233+ for c in [
234+ "# of max concurrency." ,
235+ "# of max concurrency" ,
236+ "Max Concurrency" ,
237+ "max_concurrency" ,
238+ "Concurrency" ,
239+ ]:
240+ if c in df .columns :
241+ return c
242+ # Fallback: guess an integer-like column (harmless if unused)
243+ for c in df .columns :
244+ if df [c ].dtype .kind in "iu" and df [c ].nunique () > 1 and df [c ].min () >= 1 :
245+ return c
246+ return "# of max concurrency."
247+
248+
249+ def _highlight_threshold (
250+ df : pd .DataFrame , threshold : float
251+ ) -> "pd.io.formats.style.Styler" :
252+ """Highlight numeric per-configuration columns with value <= threshold."""
253+ conc_col = _find_concurrency_col (df )
254+ key_cols = [
255+ c
256+ for c in ["Model" , "Dataset Name" , "Input Len" , "Output Len" , conc_col ]
257+ if c in df .columns
258+ ]
259+ conf_cols = [
260+ c for c in df .columns if c not in key_cols and not str (c ).startswith ("Ratio" )
261+ ]
262+ conf_cols = [c for c in conf_cols if pd .api .types .is_numeric_dtype (df [c ])]
263+ return df .style .map (
264+ lambda v : "background-color:#e6ffe6;font-weight:bold;"
265+ if pd .notna (v ) and v <= threshold
266+ else "" ,
267+ subset = conf_cols ,
268+ )
269+
270+
202271if __name__ == "__main__" :
203272 parser = argparse .ArgumentParser ()
204273 parser .add_argument (
@@ -220,6 +289,26 @@ def split_json_by_tp_pp(
220289 default = "# of max concurrency." ,
221290 help = "column name to use as X Axis in comparison graph" ,
222291 )
292+ parser .add_argument (
293+ "-l" ,
294+ "--latency" ,
295+ type = str ,
296+ default = "p99" ,
297+ help = "take median|p99 for latency like TTFT/TPOT" ,
298+ )
299+ parser .add_argument (
300+ "--ttft-max-ms" ,
301+ type = float ,
302+ default = 3000.0 ,
303+ help = "Reference limit for TTFT plots (ms)" ,
304+ )
305+ parser .add_argument (
306+ "--tpot-max-ms" ,
307+ type = float ,
308+ default = 100.0 ,
309+ help = "Reference limit for TPOT plots (ms)" ,
310+ )
311+
223312 args = parser .parse_args ()
224313
225314 drop_column = "P99"
@@ -234,12 +323,22 @@ def split_json_by_tp_pp(
234323 "# of max concurrency." ,
235324 "qps" ,
236325 ]
237- data_cols_to_compare = ["Output Tput (tok/s)" , "Median TTFT (ms)" , "Median" ]
238- html_msgs_for_data_cols = [
239- "Compare Output Tokens /n" ,
240- "Median TTFT /n" ,
241- "Median TPOT /n" ,
242- ]
326+
327+ if "median" in args .latency :
328+ data_cols_to_compare = ["Output Tput (tok/s)" , "Median TTFT (ms)" , "Median" ]
329+ html_msgs_for_data_cols = [
330+ "Compare Output Tokens /n" ,
331+ "Median TTFT /n" ,
332+ "Median TPOT /n" ,
333+ ]
334+ drop_column = "P99"
335+ elif "p99" in args .latency :
336+ data_cols_to_compare = ["Output Tput (tok/s)" , "P99 TTFT (ms)" , "P99" ]
337+ html_msgs_for_data_cols = [
338+ "Compare Output Tokens /n" ,
339+ "P99 TTFT /n" ,
340+ "P99 TPOT /n" ,
341+ ]
243342
244343 if len (args .file ) == 1 :
245344 files = split_json_by_tp_pp (args .file [0 ], output_root = "splits" )
@@ -275,33 +374,83 @@ def split_json_by_tp_pp(
275374 f"Expected subset: { filtered_info_cols } , "
276375 f"but DataFrame has: { list (output_df .columns )} "
277376 )
278- output_df_sorted = output_df .sort_values (by = existing_group_cols )
377+ # output_df_sorted = output_df.sort_values(by=existing_group_cols)
378+ output_df_sorted = output_df .sort_values (by = args .xaxis )
279379 output_groups = output_df_sorted .groupby (existing_group_cols , dropna = False )
280380 for name , group in output_groups :
281- html = group .to_html ()
282- text_file .write (html_msgs_for_data_cols [i ])
283- text_file .write (html )
284-
285- if plot and plotly_found :
286- import plotly .express as px
287-
288- df = group [raw_data_cols ]
289- df_sorted = df .sort_values (by = info_cols [y_axis_index ])
290- # Melt DataFrame for plotting
291- df_melted = df_sorted .melt (
292- id_vars = info_cols [y_axis_index ],
293- var_name = "Configuration" ,
294- value_name = data_cols_to_compare [i ],
381+ group_name = (
382+ "," .join (map (str , name )).replace ("," , "_" ).replace ("/" , "-" )
383+ )
384+ group_html_name = "perf_comparison_" + group_name + ".html"
385+
386+ metric_name = str (data_cols_to_compare [i ]).lower ()
387+ if "tok/s" in metric_name :
388+ html = group .to_html ()
389+ elif "ttft" in metric_name :
390+ styler = _highlight_threshold (group , args .ttft_max_ms ).format (
391+ {c : "{:.2f}" for c in group .select_dtypes ("number" ).columns },
392+ na_rep = "—" ,
393+ )
394+ html = styler .to_html (
395+ table_attributes = 'border="1" class="dataframe"'
396+ )
397+ elif (
398+ "tpot" in metric_name
399+ or "median" in metric_name
400+ or "p99" in metric_name
401+ ):
402+ styler = _highlight_threshold (group , args .tpot_max_ms ).format (
403+ {c : "{:.2f}" for c in group .select_dtypes ("number" ).columns },
404+ na_rep = "—" ,
295405 )
296- title = data_cols_to_compare [i ] + " vs " + info_cols [y_axis_index ]
297- # Create Plotly line chart
298- fig = px .line (
299- df_melted ,
300- x = info_cols [y_axis_index ],
301- y = data_cols_to_compare [i ],
302- color = "Configuration" ,
303- title = title ,
304- markers = True ,
406+ html = styler .to_html (
407+ table_attributes = 'border="1" class="dataframe"'
305408 )
306- # Export to HTML
307- text_file .write (fig .to_html (full_html = True , include_plotlyjs = "cdn" ))
409+
410+ text_file .write (html_msgs_for_data_cols [i ])
411+ text_file .write (html )
412+ with open (group_html_name , "a+" ) as sub_text_file :
413+ sub_text_file .write (html_msgs_for_data_cols [i ])
414+ sub_text_file .write (html )
415+
416+ if plot and plotly_found :
417+ import plotly .express as px
418+
419+ df = group [raw_data_cols ]
420+ df_sorted = df .sort_values (by = info_cols [y_axis_index ])
421+ # Melt DataFrame for plotting
422+ df_melted = df_sorted .melt (
423+ id_vars = info_cols [y_axis_index ],
424+ var_name = "Configuration" ,
425+ value_name = data_cols_to_compare [i ],
426+ )
427+ title = (
428+ data_cols_to_compare [i ] + " vs " + info_cols [y_axis_index ]
429+ )
430+ # Create Plotly line chart
431+ fig = px .line (
432+ df_melted ,
433+ x = info_cols [y_axis_index ],
434+ y = data_cols_to_compare [i ],
435+ color = "Configuration" ,
436+ title = title ,
437+ markers = True ,
438+ )
439+
440+ # ---- Add threshold lines based on metric name ----
441+ if "ttft" in metric_name :
442+ _add_limit_line (fig , args .ttft_max_ms , "TTFT limit" )
443+ elif (
444+ "tpot" in metric_name
445+ or "median" in metric_name
446+ or "p99" in metric_name
447+ ):
448+ _add_limit_line (fig , args .tpot_max_ms , "TPOT limit" )
449+
450+ # Export to HTML
451+ text_file .write (
452+ fig .to_html (full_html = True , include_plotlyjs = "cdn" )
453+ )
454+ sub_text_file .write (
455+ fig .to_html (full_html = True , include_plotlyjs = "cdn" )
456+ )
0 commit comments