Skip to content

Commit e515299

Browse files
committed
fix for pre-commit
1 parent 1008e23 commit e515299

File tree

1 file changed

+84
-35
lines changed

1 file changed

+84
-35
lines changed

.buildkite/nightly-benchmarks/scripts/compare-json-results.py

Lines changed: 84 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -199,6 +199,7 @@ def split_json_by_tp_pp(
199199

200200
return saved_paths
201201

202+
202203
def _add_limit_line(fig, y_value, label):
203204
# Visible dashed line + annotation
204205
fig.add_hline(
@@ -211,18 +212,28 @@ def _add_limit_line(fig, y_value, label):
211212
# Optional: add a legend item (as a transparent helper trace)
212213
if plot and plotly_found:
213214
import plotly.graph_objects as go
214-
fig.add_trace(go.Scatter(
215-
x=[None], y=[None],
216-
mode="lines",
217-
line=dict(dash="dash",
218-
color="red" if "ttft" in label.lower() else "blue"),
219-
name=f"{label}"
220-
))
215+
216+
fig.add_trace(
217+
go.Scatter(
218+
x=[None],
219+
y=[None],
220+
mode="lines",
221+
line=dict(
222+
dash="dash", color="red" if "ttft" in label.lower() else "blue"
223+
),
224+
name=f"{label}",
225+
)
226+
)
221227

222228

223229
def _find_concurrency_col(df: pd.DataFrame) -> str:
224-
for c in ["# of max concurrency.", "# of max concurrency", "Max Concurrency",
225-
"max_concurrency", "Concurrency"]:
230+
for c in [
231+
"# of max concurrency.",
232+
"# of max concurrency",
233+
"Max Concurrency",
234+
"max_concurrency",
235+
"Concurrency",
236+
]:
226237
if c in df.columns:
227238
return c
228239
# Fallback: guess an integer-like column (harmless if unused)
@@ -231,15 +242,26 @@ def _find_concurrency_col(df: pd.DataFrame) -> str:
231242
return c
232243
return "# of max concurrency."
233244

234-
def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.style.Styler":
245+
246+
def _highlight_threshold(
247+
df: pd.DataFrame, threshold: float
248+
) -> "pd.io.formats.style.Styler":
235249
"""Highlight numeric per-configuration columns with value <= threshold."""
236250
conc_col = _find_concurrency_col(df)
237-
key_cols = [c for c in ["Model", "Dataset Name", "Input Len", "Output Len", conc_col] if c in df.columns]
238-
conf_cols = [c for c in df.columns if c not in key_cols and not str(c).startswith("Ratio")]
251+
key_cols = [
252+
c
253+
for c in ["Model", "Dataset Name", "Input Len", "Output Len", conc_col]
254+
if c in df.columns
255+
]
256+
conf_cols = [
257+
c for c in df.columns if c not in key_cols and not str(c).startswith("Ratio")
258+
]
239259
conf_cols = [c for c in conf_cols if pd.api.types.is_numeric_dtype(df[c])]
240260
return df.style.map(
241-
lambda v: "background-color:#e6ffe6;font-weight:bold;" if pd.notna(v) and v <= threshold else "",
242-
subset=conf_cols
261+
lambda v: "background-color:#e6ffe6;font-weight:bold;"
262+
if pd.notna(v) and v <= threshold
263+
else "",
264+
subset=conf_cols,
243265
)
244266

245267

@@ -271,11 +293,18 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
271293
default="p99",
272294
help="take median|p99 for latency like TTFT/TPOT",
273295
)
274-
parser.add_argument("--ttft-max-ms", type=float, default=3000.0,
275-
help="Reference limit for TTFT plots (ms)")
276-
parser.add_argument("--tpot-max-ms", type=float, default=100.0,
277-
help="Reference limit for TPOT plots (ms)")
278-
296+
parser.add_argument(
297+
"--ttft-max-ms",
298+
type=float,
299+
default=3000.0,
300+
help="Reference limit for TTFT plots (ms)",
301+
)
302+
parser.add_argument(
303+
"--tpot-max-ms",
304+
type=float,
305+
default=100.0,
306+
help="Reference limit for TPOT plots (ms)",
307+
)
279308

280309
args = parser.parse_args()
281310

@@ -342,29 +371,39 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
342371
f"Expected subset: {filtered_info_cols}, "
343372
f"but DataFrame has: {list(output_df.columns)}"
344373
)
345-
#output_df_sorted = output_df.sort_values(by=existing_group_cols)
374+
# output_df_sorted = output_df.sort_values(by=existing_group_cols)
346375
output_df_sorted = output_df.sort_values(by=args.xaxis)
347376
output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False)
348377
for name, group in output_groups:
349-
group_name = ",".join(map(str, name)).replace(",", "_").replace("/","-")
378+
group_name = (
379+
",".join(map(str, name)).replace(",", "_").replace("/", "-")
380+
)
350381
group_html_name = "perf_comparison_" + group_name + ".html"
351382

352383
metric_name = str(data_cols_to_compare[i]).lower()
353384
if "tok/s" in metric_name:
354385
html = group.to_html()
355386
elif "ttft" in metric_name:
356-
styler = (
357-
_highlight_threshold(group, args.ttft_max_ms)
358-
.format({c: "{:.2f}" for c in group.select_dtypes("number").columns}, na_rep="—")
387+
styler = _highlight_threshold(group, args.ttft_max_ms).format(
388+
{c: "{:.2f}" for c in group.select_dtypes("number").columns},
389+
na_rep="—",
390+
)
391+
html = styler.to_html(
392+
table_attributes='border="1" class="dataframe"'
359393
)
360-
html = styler.to_html(table_attributes='border="1" class="dataframe"')
361-
elif "tpot" in metric_name or "median" in metric_name or "p99" in metric_name:
362-
styler = (
363-
_highlight_threshold(group, args.tpot_max_ms)
364-
.format({c: "{:.2f}" for c in group.select_dtypes("number").columns}, na_rep="—")
394+
elif (
395+
"tpot" in metric_name
396+
or "median" in metric_name
397+
or "p99" in metric_name
398+
):
399+
styler = _highlight_threshold(group, args.tpot_max_ms).format(
400+
{c: "{:.2f}" for c in group.select_dtypes("number").columns},
401+
na_rep="—",
365402
)
366-
html = styler.to_html(table_attributes='border="1" class="dataframe"')
367-
403+
html = styler.to_html(
404+
table_attributes='border="1" class="dataframe"'
405+
)
406+
368407
text_file.write(html_msgs_for_data_cols[i])
369408
text_file.write(html)
370409
with open(group_html_name, "a") as sub_text_file:
@@ -382,7 +421,9 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
382421
var_name="Configuration",
383422
value_name=data_cols_to_compare[i],
384423
)
385-
title = data_cols_to_compare[i] + " vs " + info_cols[y_axis_index]
424+
title = (
425+
data_cols_to_compare[i] + " vs " + info_cols[y_axis_index]
426+
)
386427
# Create Plotly line chart
387428
fig = px.line(
388429
df_melted,
@@ -396,9 +437,17 @@ def _highlight_threshold(df: pd.DataFrame, threshold: float) -> "pd.io.formats.s
396437
# ---- Add threshold lines based on metric name ----
397438
if "ttft" in metric_name:
398439
_add_limit_line(fig, args.ttft_max_ms, "TTFT limit")
399-
elif "tpot" in metric_name or "median" in metric_name or "p99" in metric_name :
440+
elif (
441+
"tpot" in metric_name
442+
or "median" in metric_name
443+
or "p99" in metric_name
444+
):
400445
_add_limit_line(fig, args.tpot_max_ms, "TPOT limit")
401446

402447
# Export to HTML
403-
text_file.write(fig.to_html(full_html=True, include_plotlyjs="cdn"))
404-
sub_text_file.write(fig.to_html(full_html=True, include_plotlyjs="cdn"))
448+
text_file.write(
449+
fig.to_html(full_html=True, include_plotlyjs="cdn")
450+
)
451+
sub_text_file.write(
452+
fig.to_html(full_html=True, include_plotlyjs="cdn")
453+
)

0 commit comments

Comments (0)