fix: Allow to ignore certain assertion errors #545

Merged
merged 1 commit on Jun 24, 2024
27 changes: 21 additions & 6 deletions analytics/tools/aggregate_runs/core_aggregation.py
@@ -1,7 +1,9 @@
+import logging
 from copy import deepcopy
 from pathlib import Path
 
 import pandas as pd
 
 from analytics.app.data.transform import dfs_models_and_evals, logs_dataframe
 from analytics.tools.aggregate_runs.dir_utils import load_multiple_logfiles
 from analytics.tools.aggregate_runs.pipeline_equivalence import assert_pipeline_equivalence
@@ -73,18 +75,31 @@ def aggregate_eval_metrics(df_eval_single: pd.DataFrame, logs: list[PipelineLogs
["model_idx", "eval_handler", "dataset_id", "interval_start", "interval_end", "metric"]
)

for size in groups.size():
assert size == len(logs), "Wrong primary key"
sizes = groups.agg(size=("model_idx", "size")).reset_index()
if len(sizes["size"].unique()) != 1 or int(sizes[0]) != len(logs):
logging.warning(f"\n{sizes[sizes['size'] != len(logs)]}")
logging.warning(
"The number of records in every group is not equal to the number of logs. "
"This might be due to missing records in the logs or a wrong grouping primary key. "
"If only a few records show less than the expected number of logs, you might want to "
"ignore and continue by pressing any key."
)
breakpoint()

aggregated_metrics = groups.agg(
agg_value=("value", "mean"), id_model_list=("id_model", lambda x: list(x))
).reset_index()

# sanity check: per aggregated row we find len(logs) unique id_model
assert all(
len(row[1]["id_model_list"]) == len(logs)
for row in aggregated_metrics[["model_idx", "id_model_list"]].iterrows()
)
agg = aggregated_metrics[["model_idx", "id_model_list"]]
agg["num_models"] = agg["id_model_list"].apply(len)
breaking_rows = agg[agg["num_models"] != len(logs)]
if breaking_rows.shape[0] > 0:
logging.warning(f"\n{breaking_rows}")
logging.warning(
"The number of unique id_model in the aggregated metrics is not equal to the number of logs. Please verify."
)
breakpoint()

if DEBUGGING_MODE:
# print(aggregated_metrics[["model_idx", "id_model_list"]])
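Both edits follow the same pattern: a hard assert is replaced by a logged warning followed by breakpoint(), so the operator can inspect the offending rows in pdb and resume with 'c' instead of losing the whole aggregation run. A minimal standalone sketch of that pattern, using a toy DataFrame (check_group_sizes is an illustrative helper, not part of this PR):

    import logging

    import pandas as pd

    def check_group_sizes(df: pd.DataFrame, expected: int) -> None:
        # Soft replacement for an assert: warn and drop into the debugger
        # instead of raising, so the operator can decide whether to continue.
        sizes = df.groupby("key").size()
        bad = sizes[sizes != expected]
        if not bad.empty:
            logging.warning("Groups with unexpected sizes:\n%s", bad)
            breakpoint()  # pdb prompt: 'c' continues past the check, 'q' aborts

    df = pd.DataFrame({"key": ["a", "a", "b"], "value": [1, 2, 3]})
    check_group_sizes(df, expected=2)  # group "b" has only 1 row -> warning + pdb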
4 changes: 3 additions & 1 deletion analytics/tools/aggregate_runs/dir_utils.py
@@ -11,7 +11,9 @@ def group_pipelines_by_name(pipeline_logs_directory: Path) -> dict[str, list[Pat
         pipeline_logs_directory / d for d in os.listdir(pipeline_logs_directory) if str(d).startswith("pipeline_")
     ]
 
-    pipeline_names: list[tuple[Path, str]] = [(d, (d / ".name").read_text()) for d in pipeline_directories if (d / "pipeline.log").exists()]
+    pipeline_names: list[tuple[Path, str]] = [
+        (d, (d / ".name").read_text()) for d in pipeline_directories if (d / "pipeline.log").exists()
+    ]
 
     pipeline_groups = {name: [d for d, n in pipeline_names if n == name] for name in set(n for _, n in pipeline_names)}
     return pipeline_groups
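For context, group_pipelines_by_name buckets pipeline_* run directories by the contents of their .name marker file, skipping directories without a pipeline.log. A hypothetical call, assuming the logs/ layout described in the comments:

    from pathlib import Path

    # Assumed layout (illustrative): logs/pipeline_0 and logs/pipeline_1 each
    # contain a pipeline.log plus a .name file whose text is "mnist_run".
    groups = group_pipelines_by_name(Path("logs"))
    # -> {"mnist_run": [Path("logs/pipeline_0"), Path("logs/pipeline_1")]}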