Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from typing import Any, Callable, cast, NamedTuple, Optional

import ax.service.utils.early_stopping as early_stopping_utils
from ax.analysis.analysis import Analysis, AnalysisCard
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
ParallelCoordinatesPlot,
)
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
Expand Down Expand Up @@ -1768,6 +1772,36 @@ def generate_candidates(
)
return new_trials

def compute_analyses(
    self, analyses: Optional[Iterable[Analysis]] = None
) -> list[AnalysisCard]:
    """Compute AnalysisCards for this Scheduler's experiment.

    Args:
        analyses: The Analyses to compute. If None, a default collection
            is selected via ``self._choose_analyses()``.

    Returns:
        The AnalysisCards that computed successfully. Failed computations
        are currently dropped (see TODO below); they are not raised.
    """
    if analyses is None:
        analyses = self._choose_analyses()

    # Compute each analysis against the current experiment state and
    # generation strategy; each call yields a Result rather than raising.
    results = [
        analysis.compute_result(
            experiment=self.experiment,
            generation_strategy=self.generation_strategy,
        )
        for analysis in analyses
    ]

    # TODO Accumulate Es into their own card, perhaps via unwrap_or_else
    cards = [result.unwrap() for result in results if result.is_ok()]

    # Best-effort persistence: no-op when DB settings are not configured.
    self._save_analysis_cards_to_db_if_possible(
        analysis_cards=cards,
        experiment=self.experiment,
    )

    return cards

def _choose_analyses(self) -> list[Analysis]:
    """Select a default set of Analyses to compute for this Scheduler,
    based on the Experiment, GenerationStrategy, etc.
    """
    # TODO Create a useful heuristic for choosing analyses
    default_analyses: list[Analysis] = [ParallelCoordinatesPlot()]
    return default_analyses

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
Expand Down
46 changes: 46 additions & 0 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from unittest.mock import call, Mock, patch, PropertyMock

import pandas as pd
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
ParallelCoordinatesPlot,
)

from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
Expand Down Expand Up @@ -84,6 +87,7 @@
SpecialGenerationStrategy,
)
from ax.utils.testing.mock import fast_botorch_optimize
from ax.utils.testing.modeling_stubs import get_generation_strategy
from pyre_extensions import none_throws

from sqlalchemy.orm.exc import StaleDataError
Expand Down Expand Up @@ -2502,3 +2506,45 @@ def test_generate_candidates_does_not_generate_if_overconstrained(self) -> None:
1,
str(scheduler.experiment.trials),
)

def test_compute_analyses(self) -> None:
    """compute_analyses returns cards on success and logs + drops failures."""
    # Case 1: experiment with a completed trial (data present) — the
    # analysis succeeds and yields exactly one card.
    scheduler = Scheduler(
        experiment=get_branin_experiment(with_completed_trial=True),
        generation_strategy=get_generation_strategy(),
        options=SchedulerOptions(
            total_trials=0,
            tolerated_trial_failure_rate=0.2,
            init_seconds_between_polls=10,
        ),
    )

    cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])

    self.assertEqual(len(cards), 1)
    self.assertEqual(cards[0].name, "ParallelCoordinatesPlot(metric_name=None)")

    # Case 2: experiment without any completed trial (no data) — the
    # analysis fails, no card is produced, and the failure is logged at
    # ERROR level rather than raised.
    scheduler = Scheduler(
        experiment=get_branin_experiment(with_completed_trial=False),
        generation_strategy=get_generation_strategy(),
        options=SchedulerOptions(
            total_trials=0,
            tolerated_trial_failure_rate=0.2,
            init_seconds_between_polls=10,
        ),
    )

    with self.assertLogs(logger="ax.analysis", level="ERROR") as logs:
        cards = scheduler.compute_analyses(analyses=[ParallelCoordinatesPlot()])

    self.assertEqual(len(cards), 0)
    expected_fragment = (
        "Failed to compute ParallelCoordinatesPlot(metric_name=None): "
        "No data found for metric branin"
    )
    self.assertTrue(any(expected_fragment in msg for msg in logs.output))
38 changes: 37 additions & 1 deletion ax/service/utils/with_db_settings_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import time

from logging import INFO, Logger
from typing import Any, Optional
from typing import Any, Iterable, Optional

from ax.analysis.analysis import AnalysisCard

from ax.core.base_trial import BaseTrial
from ax.core.experiment import Experiment
Expand All @@ -22,6 +24,7 @@
UnsupportedError,
)
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.storage.sqa_store.save import save_analysis_cards
from ax.utils.common.executils import retry_on_exception
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import not_none
Expand Down Expand Up @@ -468,6 +471,21 @@ def _update_experiment_properties_in_db(
return True
return False

def _save_analysis_cards_to_db_if_possible(
    self,
    experiment: Experiment,
    analysis_cards: Iterable[AnalysisCard],
) -> bool:
    """Save the given AnalysisCards to the DB, if DB settings are set.

    Args:
        experiment: The Experiment the cards belong to; presumably it is
            already persisted in the DB (TODO confirm against storage layer).
        analysis_cards: The AnalysisCards to persist.

    Returns:
        True if a save was attempted (DB settings are configured),
        False otherwise.
    """
    if not self.db_settings_set:
        return False

    # NOTE: the bare name below resolves to the module-level helper of the
    # same name (decorated with retry_on_exception), not to this method.
    _save_analysis_cards_to_db_if_possible(
        experiment=experiment,
        analysis_cards=analysis_cards,
        # Fixed: the helper's parameter is named ``sqa_config``, not
        # ``config`` — the original keyword raised TypeError at call time.
        sqa_config=self.db_settings.encoder.config,
        # Required by the helper's signature; read by the retry decorator.
        suppress_all_errors=False,
    )
    return True


# ------------- Utils for storage that assume `DBSettings` are provided --------

Expand Down Expand Up @@ -590,3 +608,21 @@ def _update_experiment_properties_in_db(
experiment_with_updated_properties=experiment_with_updated_properties,
config=sqa_config,
)


@retry_on_exception(
    retries=3,
    default_return_on_suppression=False,
    exception_types=RETRY_EXCEPTION_TYPES,
)
def _save_analysis_cards_to_db_if_possible(
    experiment: Experiment,
    analysis_cards: Iterable[AnalysisCard],
    sqa_config: SQAConfig,
    suppress_all_errors: bool = False,  # Used by the decorator.
) -> None:
    """Persist AnalysisCards to the DB, retrying on transient failures.

    Args:
        experiment: The Experiment the cards belong to.
        analysis_cards: The AnalysisCards to persist; any iterable is
            accepted and is materialized once before saving.
        sqa_config: SQAConfig used when encoding the cards for storage.
        suppress_all_errors: Not read in this body; inspected via kwargs by
            the ``retry_on_exception`` decorator to decide whether to
            swallow exceptions once retries are exhausted. Defaulted to
            False so callers that omit it (as the Scheduler mixin does)
            do not fail with a missing-argument TypeError.
    """
    # Materialize the iterable so the storage layer receives a concrete list
    # even when a generator is passed in.
    save_analysis_cards(
        experiment=experiment,
        analysis_cards=[*analysis_cards],
        config=sqa_config,
    )