fix: Fix evaluation logic (#532)
# Motivation

Evaluation results were not being collected into the logfile.
robinholzi authored Jun 20, 2024
1 parent 55fef7c commit cb0be37
Showing 3 changed files with 16 additions and 7 deletions.
@@ -103,12 +103,14 @@ def init_from_path(cls, pipeline_logdir: Path) -> "EvaluationExecutor":
         eval_state_config = EvalStateConfig.model_validate_json((snapshot_dir / "eval_state.yaml").read_text())
         context = pickle.loads((snapshot_dir / "context.pcl").read_bytes())
 
+        grpc_handler = GRPCHandler(eval_state_config.config.model_dump(by_alias=True))
+        grpc_handler.init_cluster_connection()
         executor = EvaluationExecutor(
             eval_state_config.pipeline_id,
             eval_state_config.eval_dir,
             eval_state_config.config,
             eval_state_config.pipeline,
-            GRPCHandler(eval_state_config.config.model_dump(by_alias=True)),
+            grpc_handler,
         )
         executor.context = context
         return executor
@@ -164,7 +166,7 @@ def run_pipeline_evaluations(
         logs = self._launch_evaluations_async(eval_requests, log, eval_status_queue, num_workers)
         return logs
 
-    def run_post_pipeline_evaluations(self, eval_status_queue: Queue) -> SupervisorLogs:
+    def run_post_pipeline_evaluations(self, eval_status_queue: Queue, manual_run: bool = False) -> SupervisorLogs:
         """Evaluate the trained models after the core pipeline and store the results."""
         if not self.pipeline.evaluation:
             return SupervisorLogs(stage_runs=[])
@@ -182,7 +184,9 @@ def run_post_pipeline_evaluations(self, eval_status_queue: Queue) -> SupervisorLogs:
 
         eval_requests: list[EvalRequest] = []
         for eval_handler in self.eval_handlers:
-            if eval_handler.config.execution_time != "after_pipeline":
+            if (eval_handler.config.execution_time not in ("after_pipeline", "manual")) or (
+                eval_handler.config.execution_time == "manual" and not manual_run
+            ):
                 continue
 
             handler_eval_requests = eval_handler.get_eval_requests_after_pipeline(df_trainings=df_trainings)
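The change to run_post_pipeline_evaluations gates each handler on its execution_time: "after_pipeline" handlers always run, while "manual" handlers only run when the executor is invoked with manual_run=True. The added condition expresses this as a skip rule; below is a minimal standalone sketch of the same rule as a positive predicate (the helper name should_run_after_pipeline is illustrative, not part of the codebase):

```python
# Hypothetical helper mirroring the gating condition added to
# run_post_pipeline_evaluations: a handler is evaluated after the pipeline
# if it is configured with execution_time == "after_pipeline", or with
# "manual" when the executor was explicitly invoked as a manual run.
def should_run_after_pipeline(execution_time: str, manual_run: bool) -> bool:
    if execution_time == "after_pipeline":
        return True
    if execution_time == "manual":
        return manual_run
    return False


# The three cases the new branch distinguishes:
assert should_run_after_pipeline("after_pipeline", manual_run=False)
assert should_run_after_pipeline("manual", manual_run=True)
assert not should_run_after_pipeline("manual", manual_run=False)
assert not should_run_after_pipeline("after_training", manual_run=True)
```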
@@ -734,7 +734,6 @@ def _evaluate_and_store_results(
     def _done(self, s: ExecutionState, log: StageLog) -> None:
         s.pipeline_status_queue.put(pipeline_stage_msg(PipelineStage.DONE, MsgType.GENERAL))
         self.logs.pipeline_stages = _pipeline_stage_parents  # now includes chronology info
-        self.logs.materialize(s.log_directory, mode="final")
 
     @pipeline_stage(PipelineStage.POST_EVALUATION_CHECKPOINT, parent=PipelineStage.MAIN, log=False, track=False)
     def _post_pipeline_evaluation_checkpoint(self, s: ExecutionState, log: StageLog) -> None:
@@ -758,7 +757,7 @@ def _post_pipeline_evaluation(self, s: ExecutionState, log: StageLog) -> None:
 
     @pipeline_stage(PipelineStage.EXIT, parent=PipelineStage.MAIN)
     def _exit(self, s: ExecutionState, log: StageLog) -> None:
-        return None  # end of pipeline
+        self.logs.materialize(s.log_directory, mode="final")
 
     # ---------------------------------------------------- Helpers --------------------------------------------------- #
 
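This second change is the heart of the fix described in the motivation: _done used to write the final log snapshot, but it runs before the post-pipeline evaluation stages, so their results never reached the logfile. Moving materialize(..., mode="final") into _exit, the last stage, snapshots the logs after evaluation has finished. A minimal, self-contained sketch of the ordering problem (toy stage functions, not the project's code):

```python
# Simplified stand-in for the supervisor's stage sequence. Materializing in
# the DONE stage (old behavior) would snapshot the logs before
# POST_EVALUATION appends its results; materializing in EXIT (new behavior)
# captures them.
logs: list[str] = []
materialized: list[str] = []


def materialize() -> None:
    materialized.extend(logs)


def done() -> None:
    pass  # old behavior called materialize() here, too early


def post_evaluation() -> None:
    logs.append("evaluation results")


def exit_stage() -> None:
    materialize()  # new behavior: snapshot at the very end


for stage in (done, post_evaluation, exit_stage):
    stage()
assert "evaluation results" in materialized
```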
@@ -4,7 +4,7 @@
 from tempfile import TemporaryDirectory
 from typing import Any, Iterator
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pandas as pd
 import pytest
@@ -79,8 +79,12 @@ def tracking_df() -> pd.DataFrame:
     )
 
 
+@patch.object(GRPCHandler, "init_cluster_connection", return_value=None)
 def test_evaluation_executor_state_management(
-    evaluation_executor: EvaluationExecutor, tracking_df: pd.DataFrame, tmp_dir_tests: Path
+    test_init_cluster_connection: MagicMock,
+    evaluation_executor: EvaluationExecutor,
+    tracking_df: pd.DataFrame,
+    tmp_dir_tests: Path,
 ) -> None:
     evaluation_executor.register_tracking_info(
         {
@@ -95,7 +99,9 @@ def test_evaluation_executor_state_management(
     assert (tmp_dir_tests / "snapshot" / "eval_state.yaml").exists()
     assert (tmp_dir_tests / "snapshot" / "context.pcl").exists()
 
+    test_init_cluster_connection.assert_not_called()
     loaded_eval_executor = EvaluationExecutor.init_from_path(tmp_dir_tests)
+    test_init_cluster_connection.assert_called_once()
 
     assert loaded_eval_executor.pipeline_id == evaluation_executor.pipeline_id
     assert loaded_eval_executor.pipeline_logdir == evaluation_executor.pipeline_logdir
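The test change stubs the gRPC cluster connection so that init_from_path can be exercised without a live cluster, and verifies the connection is established exactly once during restore. As a generic illustration of the pattern (a standalone sketch, not this repository's test): patch.object used as a decorator replaces the method for the duration of the call and injects the replacing MagicMock as the first positional argument, which is why test_init_cluster_connection becomes the first test parameter, ahead of the pytest fixtures.

```python
from unittest.mock import MagicMock, patch


class Service:
    def connect(self) -> None:
        raise RuntimeError("would touch the network")


# patch.object swaps Service.connect for a MagicMock during the call and
# injects that mock as the first positional argument of the function.
@patch.object(Service, "connect", return_value=None)
def run_demo(mock_connect: MagicMock) -> None:
    mock_connect.assert_not_called()
    Service().connect()  # dispatches to the mock, not the real method
    mock_connect.assert_called_once()


run_demo()
```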
