fix: Allow ignoring certain assertion errors (#545)
robinholzi authored Jun 24, 2024
1 parent 97a2b5f commit 42354cf
Showing 2 changed files with 24 additions and 7 deletions.
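The core change replaces hard `assert` statements in the aggregation sanity checks with logged warnings followed by a `breakpoint()`, so an operator can inspect the offending rows and decide whether to continue. A minimal sketch of that soft-assertion pattern, assuming a hypothetical `check_group_sizes` helper (names and data are illustrative, not from the repository):

```python
import logging

import pandas as pd


def check_group_sizes(df: pd.DataFrame, expected: int) -> None:
    """Soft assertion: warn and break instead of raising when group sizes deviate."""
    sizes = df.groupby("key").size()
    offending = sizes[sizes != expected]
    if not offending.empty:
        logging.warning("\n%s", offending)
        logging.warning(
            "Some groups do not contain %d records. Inspect the rows above and "
            "continue from the debugger prompt to ignore.",
            expected,
        )
        breakpoint()  # drops into pdb; `c` continues, `q` aborts


# Group "b" has one record instead of two, so this warns and breaks.
check_group_sizes(pd.DataFrame({"key": ["a", "a", "b"], "value": [1, 2, 3]}), expected=2)
```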
27 changes: 21 additions & 6 deletions analytics/tools/aggregate_runs/core_aggregation.py
@@ -1,7 +1,9 @@
import logging
from copy import deepcopy
from pathlib import Path

import pandas as pd

from analytics.app.data.transform import dfs_models_and_evals, logs_dataframe
from analytics.tools.aggregate_runs.dir_utils import load_multiple_logfiles
from analytics.tools.aggregate_runs.pipeline_equivalence import assert_pipeline_equivalence
@@ -73,18 +75,31 @@ def aggregate_eval_metrics(df_eval_single: pd.DataFrame, logs: list[PipelineLogs
["model_idx", "eval_handler", "dataset_id", "interval_start", "interval_end", "metric"]
)

for size in groups.size():
assert size == len(logs), "Wrong primary key"
sizes = groups.agg(size=("model_idx", "size")).reset_index()
if len(sizes["size"].unique()) != 1 or int(sizes["size"].iloc[0]) != len(logs):
logging.warning(f"\n{sizes[sizes['size'] != len(logs)]}")
logging.warning(
"The number of records in every group is not equal to the number of logs. "
"This might be due to missing records in the logs or a wrong grouping primary key. "
"If only a few records show less than the expected number of logs, you might want to "
"ignore and continue by pressing any key."
)
breakpoint()

aggregated_metrics = groups.agg(
agg_value=("value", "mean"), id_model_list=("id_model", lambda x: list(x))
).reset_index()

# sanity check: per aggregated row we find len(logs) unique id_model
assert all(
len(row[1]["id_model_list"]) == len(logs)
for row in aggregated_metrics[["model_idx", "id_model_list"]].iterrows()
)
agg = aggregated_metrics[["model_idx", "id_model_list"]].copy()
agg["num_models"] = agg["id_model_list"].apply(len)
breaking_rows = agg[agg["num_models"] != len(logs)]
if breaking_rows.shape[0] > 0:
logging.warning(f"\n{breaking_rows}")
logging.warning(
"The number of unique id_model in the aggregated metrics is not equal to the number of logs. Please verify."
)
breakpoint()

if DEBUGGING_MODE:
# print(aggregated_metrics[["model_idx", "id_model_list"]])
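For context, a hedged, self-contained illustration of the size check introduced above: group evaluation records by their primary key and flag any group that does not contain exactly one record per log. The column names mirror the diff; the data frame itself is made up:

```python
import pandas as pd

# Two logs -> every (model_idx, metric) group should contain two records.
df_eval = pd.DataFrame(
    {
        "model_idx": [1, 1, 2],  # model 2 is missing one record
        "metric": ["acc", "acc", "acc"],
        "value": [0.9, 0.8, 0.7],
        "id_model": [10, 20, 11],
    }
)
num_logs = 2

groups = df_eval.groupby(["model_idx", "metric"])
sizes = groups.agg(size=("model_idx", "size")).reset_index()
breaking = sizes[sizes["size"] != num_logs]
print(breaking)  # incomplete groups -> warn / breakpoint in the real code
```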
4 changes: 3 additions & 1 deletion analytics/tools/aggregate_runs/dir_utils.py
@@ -11,7 +11,9 @@ def group_pipelines_by_name(pipeline_logs_directory: Path) -> dict[str, list[Pat
pipeline_logs_directory / d for d in os.listdir(pipeline_logs_directory) if str(d).startswith("pipeline_")
]

pipeline_names: list[tuple[Path, str]] = [(d, (d / ".name").read_text()) for d in pipeline_directories if (d / "pipeline.log").exists()]
pipeline_names: list[tuple[Path, str]] = [
(d, (d / ".name").read_text()) for d in pipeline_directories if (d / "pipeline.log").exists()
]

pipeline_groups = {name: [d for d, n in pipeline_names if n == name] for name in set(n for _, n in pipeline_names)}
return pipeline_groups
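A small sketch of the grouping logic in `group_pipelines_by_name`: each `pipeline_*` directory stores its human-readable name in a `.name` file, and directories sharing a name end up in the same group. The (path, name) pairs below are hypothetical stand-ins for what the directory scan would produce:

```python
from pathlib import Path

# Hypothetical (path, name) pairs as read from each directory's .name file.
pipeline_names: list[tuple[Path, str]] = [
    (Path("pipeline_0"), "mnist"),
    (Path("pipeline_1"), "mnist"),
    (Path("pipeline_2"), "cifar"),
]

pipeline_groups = {
    name: [d for d, n in pipeline_names if n == name]
    for name in {n for _, n in pipeline_names}
}
print(pipeline_groups)  # {'mnist': [<two paths>], 'cifar': [<one path>]}
```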
