82 changes: 64 additions & 18 deletions ax/early_stopping/strategies/base.py
@@ -17,7 +17,11 @@
from ax.core.map_data import MAP_KEY, MapData
from ax.core.objective import MultiObjective
from ax.core.trial_status import TrialStatus
from ax.early_stopping.utils import _interval_boundary, estimate_early_stopping_savings
from ax.early_stopping.utils import (
_interval_boundary,
align_partial_results,
estimate_early_stopping_savings,
)
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.generation_strategy.generation_node import GenerationNode
from ax.utils.common.base import Base
@@ -40,6 +44,7 @@ def __init__(
trial_indices_to_ignore: list[int] | None = None,
normalize_progressions: bool = False,
interval: float | None = None,
check_safe: bool = False,
) -> None:
"""A BaseEarlyStoppingStrategy class.

@@ -75,6 +80,9 @@
eligible on first check. If checked again at progression 18, it's not
eligible (both in interval [10, 20)). Once it reaches progression 21,
it's eligible again (crossed into interval [20, 30)).
check_safe: If True, applies the relevant safety checks to gate
early-stopping when it is likely to be harmful. If False (default),
bypasses the safety check and directly applies early-stopping decisions.
"""
# Validate interval
if interval is not None and not interval > 0:
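The interval-gating example in the docstring above (checks at progressions 12, 18, and 21 with `interval=10`) can be made concrete with a small sketch. The `interval_index` helper below is a local illustration of the assumed semantics behind `_interval_boundary`; it is not the Ax helper itself.

```python
def interval_index(progression: float, min_progression: float, interval: float) -> int:
    # Which half-open bucket [min_progression + k * interval,
    # min_progression + (k + 1) * interval) the progression falls into.
    return int((progression - min_progression) // interval)

# Checked at 12 and again at 18: same bucket of [10, 20), so the second check
# is skipped. At 21 the trial crosses into [20, 30) and is eligible again.
assert interval_index(12.0, min_progression=10.0, interval=10.0) == 0
assert interval_index(18.0, min_progression=10.0, interval=10.0) == 0
assert interval_index(21.0, min_progression=10.0, interval=10.0) == 1
```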
@@ -101,6 +109,7 @@ def __init__(
self.trial_indices_to_ignore = trial_indices_to_ignore
self.normalize_progressions = normalize_progressions
self.interval = interval
self.check_safe = check_safe
# Track the last progression value where each trial was checked
self._last_check_progressions: dict[int, float] = {}

@@ -170,19 +179,18 @@ def should_stop_trials_early(
Returns:
A dictionary mapping trial indices that should be early stopped to
(optional) messages with the associated reason. Returns an empty
dictionary if early stopping would be harmful.
dictionary if early stopping would be harmful (when safety check is
enabled).
"""
return (
self._should_stop_trials_early(
trial_indices=trial_indices,
experiment=experiment,
current_node=current_node,
)
if not self._is_harmful(
trial_indices=trial_indices,
experiment=experiment,
)
else {}
if self.check_safe and self._is_harmful(
trial_indices=trial_indices,
experiment=experiment,
):
return {}
return self._should_stop_trials_early(
trial_indices=trial_indices,
experiment=experiment,
current_node=current_node,
)

def estimate_early_stopping_savings(self, experiment: Experiment) -> float:
@@ -203,10 +211,10 @@ def estimate_early_stopping_savings(self, experiment: Experiment) -> float:

return estimate_early_stopping_savings(experiment=experiment)

def _check_validity_and_get_data(
def _lookup_and_validate_data(
self, experiment: Experiment, metric_signatures: list[str]
) -> MapData | None:
"""Validity checks and returns the `MapData` used for early stopping that
"""Looks up and validates the `MapData` used for early stopping that
is associated with `metric_signatures`. This function also handles normalizing
progressions.
"""
@@ -538,6 +546,42 @@ def _all_objectives_and_directions(self, experiment: Experiment) -> dict[str, bo

return directions

def _prepare_aligned_data(
self, experiment: Experiment, metric_signatures: list[str]
) -> tuple[pd.DataFrame, pd.DataFrame] | None:
"""Get raw experiment data and align it for early stopping evaluation.

Args:
experiment: Experiment that contains the trials and other contextual data.
metric_signatures: List of metric signatures to include in the aligned data.

Returns:
A tuple of (long_df, multilevel_wide_df) where:
- long_df: The raw MapData dataframe (long format) before interpolation
- multilevel_wide_df: Hierarchical wide dataframe (indexed by progression)
with first level ["mean", "sem"] and second level metric signatures
Returns None if data cannot be retrieved or aligned.
"""
data = self._lookup_and_validate_data(
experiment=experiment, metric_signatures=metric_signatures
)
if data is None:
return None

try:
multilevel_wide_df = align_partial_results(
df=(long_df := data.map_df),
metrics=metric_signatures,
)
except Exception as e:
logger.warning(
f"Encountered exception while aligning data: {e}. "
"Cannot proceed with early stopping."
)
return None

return long_df, multilevel_wide_df


class ModelBasedEarlyStoppingStrategy(BaseEarlyStoppingStrategy):
"""A base class for model based early stopping strategies. Includes
@@ -553,6 +597,7 @@ def __init__(
normalize_progressions: bool = False,
min_progression_modeling: float | None = None,
interval: float | None = None,
check_safe: bool = False,
) -> None:
"""A ModelBasedEarlyStoppingStrategy class.

@@ -600,17 +645,18 @@ def __init__(
trial_indices_to_ignore=trial_indices_to_ignore,
normalize_progressions=normalize_progressions,
interval=interval,
check_safe=check_safe,
)
self.min_progression_modeling = min_progression_modeling

def _check_validity_and_get_data(
def _lookup_and_validate_data(
self, experiment: Experiment, metric_signatures: list[str]
) -> MapData | None:
"""Validity checks and returns the `MapData` used for early stopping that
"""Looks up and validates the `MapData` used for early stopping that
is associated with `metric_signatures`. This function also handles normalizing
progressions.
"""
map_data = super()._check_validity_and_get_data(
map_data = super()._lookup_and_validate_data(
experiment=experiment, metric_signatures=metric_signatures
)
if map_data is not None and self.min_progression_modeling is not None:
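To make the return shape of the new `_prepare_aligned_data` concrete, here is a minimal pandas-only sketch of going from a long-format dataframe (in the spirit of `MapData.map_df`) to a hierarchical wide dataframe indexed by progression, with a first column level of `["mean", "sem"]` and a second level of metric signatures. The column names (`trial_index`, `step`, `metric_signature`) are assumptions for illustration, not the exact schema produced by `align_partial_results`.

```python
import pandas as pd

# Illustrative long-format data: one row per (trial, progression, metric) observation.
long_df = pd.DataFrame(
    {
        "trial_index": [0, 0, 1, 1],
        "step": [1.0, 2.0, 1.0, 2.0],
        "metric_signature": ["loss"] * 4,
        "mean": [0.9, 0.7, 0.95, 0.8],
        "sem": [0.01, 0.01, 0.02, 0.02],
    }
)

# Build a hierarchical wide frame indexed by progression, with column levels
# (statistic, metric signature, trial index), analogous to the structure the
# docstring above describes.
multilevel_wide_df = pd.concat(
    {
        stat: long_df.pivot_table(
            index="step",
            columns=["metric_signature", "trial_index"],
            values=stat,
        )
        for stat in ("mean", "sem")
    },
    axis=1,
)

# Per-trial objective curves, as the percentile strategy now selects them:
wide_df = multilevel_wide_df["mean"]["loss"]
# wide_df is indexed by progression ("step") with one column per trial index.
```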
57 changes: 23 additions & 34 deletions ax/early_stopping/strategies/percentile.py
@@ -13,7 +13,7 @@
import pandas as pd
from ax.core.experiment import Experiment
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
from ax.early_stopping.utils import align_partial_results
from ax.early_stopping.utils import _is_worse
from ax.exceptions.core import UnsupportedError
from ax.generation_strategy.generation_node import GenerationNode
from ax.utils.common.logger import get_logger
@@ -36,6 +36,7 @@ def __init__(
normalize_progressions: bool = False,
n_best_trials_to_complete: int | None = None,
interval: float | None = None,
check_safe: bool = False,
) -> None:
"""Construct a PercentileEarlyStoppingStrategy instance.

@@ -76,6 +77,9 @@
trials cross interval boundaries (at min_progression + k * interval,
k = 0, 1, 2...). Prevents premature stopping decisions when the
orchestrator (e.g., GAIN) polls frequently.
check_safe: If True, applies the relevant safety checks to gate
early-stopping when it is likely to be harmful. If False (default),
bypasses the safety check and directly applies early-stopping decisions.
"""
super().__init__(
metric_signatures=metric_signatures,
@@ -85,6 +89,7 @@
min_curves=min_curves,
normalize_progressions=normalize_progressions,
interval=interval,
check_safe=check_safe,
)

self.percentile_threshold = percentile_threshold
@@ -129,42 +134,28 @@ def _should_stop_trials_early(
metric_signature, minimize = self._default_objective_and_direction(
experiment=experiment
)
data = self._check_validity_and_get_data(
maybe_aligned_dataframes = self._prepare_aligned_data(
experiment=experiment, metric_signatures=[metric_signature]
)
if data is None:
# don't stop any trials if we don't get data back
if maybe_aligned_dataframes is None:
return {}

df = data.map_df
long_df, multilevel_wide_df = maybe_aligned_dataframes
wide_df = multilevel_wide_df["mean"][metric_signature]

# default checks on `min_progression` and `min_curves`; if not met, don't do
# early stopping at all and return {}
if not self.is_eligible_any(
trial_indices=trial_indices, experiment=experiment, df=df
trial_indices=trial_indices, experiment=experiment, df=long_df
):
return {}

try:
aligned_df = align_partial_results(
df=df,
metrics=[metric_signature],
)
except Exception as e:
logger.warning(
f"Encountered exception while aligning data: {e}. "
"Not early stopping any trials."
)
return {}

metric_to_aligned_means = aligned_df["mean"]
aligned_means = metric_to_aligned_means[metric_signature]
decisions = {
trial_index: self._should_stop_trial_early(
trial_index=trial_index,
experiment=experiment,
df=aligned_means,
df_raw=df,
wide_df=wide_df,
long_df=long_df,
minimize=minimize,
)
for trial_index in trial_indices
@@ -179,8 +170,8 @@ def _should_stop_trial_early(
self,
trial_index: int,
experiment: Experiment,
df: pd.DataFrame,
df_raw: pd.DataFrame,
wide_df: pd.DataFrame,
long_df: pd.DataFrame,
minimize: bool,
) -> tuple[bool, str | None]:
"""Stop a trial if its performance is in the bottom `percentile_threshold`
Expand All @@ -189,9 +180,9 @@ def _should_stop_trial_early(
Args:
trial_index: Index of the candidate trial to stop early.
experiment: Experiment that contains the trials and other contextual data.
df: Dataframe of partial results after applying interpolation,
filtered to objective metric.
df_raw: The original MapData dataframe (before interpolation).
wide_df: Dataframe of partial results after applying interpolation,
filtered to objective metric (wide format, non-hierarchical).
long_df: The original MapData dataframe (long format, before interpolation).
minimize: Whether objective value is being minimized.

Returns:
@@ -201,18 +192,18 @@
"""

stopping_eligible, reason = self.is_eligible(
trial_index=trial_index, experiment=experiment, df=df_raw
trial_index=trial_index, experiment=experiment, df=long_df
)
if not stopping_eligible:
return False, reason

# Extract the metric curve for the trial under consideration
trial_series = df[trial_index]
trial_series = wide_df[trial_index]
# Find the latest progression with a recorded value for this trial
trial_latest_prog = trial_series.last_valid_index()

# Get objective values for all trials at this progression
objective_latest_prog = df.loc[trial_latest_prog]
objective_latest_prog = wide_df.loc[trial_latest_prog]
# Filter to trials that have reached this progression (exclude NaN values)
ref_selector = objective_latest_prog.notna()
ref_objectives_latest_prog = objective_latest_prog[ref_selector]
@@ -248,10 +239,8 @@ def _should_stop_trial_early(
trial_objective_value = objective_latest_prog[trial_index]
# Determine if this trial should be stopped based on its performance
# relative to the threshold
should_early_stop = (
trial_objective_value > ref_threshold_value
if minimize
else trial_objective_value < ref_threshold_value
should_early_stop = _is_worse(
trial_objective_value, ref_threshold_value, minimize=minimize
)

# Build the percentile threshold message that explains performance
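For reviewers tracing the refactored `_should_stop_trial_early`, here is a self-contained sketch of the percentile comparison it performs. The `is_worse` helper is a local stand-in assuming the semantics implied by the `_is_worse` import (worse means larger when minimizing, smaller when maximizing); the exact percentile orientation used by Ax may differ.

```python
import numpy as np
import pandas as pd

def is_worse(value: float, reference: float, *, minimize: bool) -> bool:
    # Assumed semantics of a _is_worse-style helper: "worse" is larger when
    # minimizing and smaller when maximizing.
    return value > reference if minimize else value < reference

# Objective values of all trials that have reached the candidate trial's
# latest progression (NaNs already filtered out), keyed by trial index.
objective_latest_prog = pd.Series({0: 0.70, 1: 0.80, 2: 0.95, 3: 0.60})

trial_index = 2
minimize = True
percentile_threshold = 70.0  # stop trials performing worse than this percentile

# When minimizing, large objective values are bad, so the cutoff sits at the
# percentile_threshold-th percentile; when maximizing, at its complement.
prc = percentile_threshold if minimize else 100.0 - percentile_threshold
ref_threshold_value = np.percentile(objective_latest_prog.to_numpy(), prc)

should_early_stop = is_worse(
    objective_latest_prog[trial_index], ref_threshold_value, minimize=minimize
)
print(should_early_stop)  # True: trial 2 (0.95) lies above the 70th percentile
```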
7 changes: 6 additions & 1 deletion ax/early_stopping/strategies/threshold.py
@@ -33,6 +33,7 @@ def __init__(
min_curves: int | None = 5,
trial_indices_to_ignore: list[int] | None = None,
normalize_progressions: bool = False,
check_safe: bool = False,
) -> None:
"""Construct a ThresholdEarlyStoppingStrategy instance.

@@ -60,6 +61,9 @@
specified in the transformed space. IMPORTANT: Typically, `min_curves`
should be > 0 to ensure that at least one trial has completed and that
we have a reliable approximation for `prog_max`.
check_safe: If True, applies the relevant safety checks to gate
early-stopping when it is likely to be harmful. If False (default),
bypasses the safety check and directly applies early-stopping decisions.
"""
super().__init__(
metric_signatures=metric_signatures,
@@ -68,6 +72,7 @@
min_curves=min_curves,
trial_indices_to_ignore=trial_indices_to_ignore,
normalize_progressions=normalize_progressions,
check_safe=check_safe,
)
self.metric_threshold = metric_threshold

@@ -110,7 +115,7 @@ def _should_stop_trials_early(
metric_signature, minimize = self._default_objective_and_direction(
experiment=experiment
)
data = self._check_validity_and_get_data(
data = self._lookup_and_validate_data(
experiment=experiment, metric_signatures=[metric_signature]
)
if data is None:
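Finally, a hedged usage sketch of the new `check_safe` flag as it is threaded through this diff. The argument values are illustrative only; `check_safe` is the new parameter here, and the commented-out call marks where the `_is_harmful` gate now applies.

```python
from ax.early_stopping.strategies.percentile import PercentileEarlyStoppingStrategy

# With check_safe=True, should_stop_trials_early() first consults the
# _is_harmful safety check and returns {} if stopping looks harmful; with the
# default check_safe=False, decisions are applied directly (pre-diff behavior).
strategy = PercentileEarlyStoppingStrategy(
    percentile_threshold=75.0,
    min_curves=5,
    normalize_progressions=True,
    check_safe=True,
)

# decisions = strategy.should_stop_trials_early(
#     trial_indices=running_trial_indices,  # placeholder set of trial indices
#     experiment=experiment,                # placeholder Ax Experiment
#     current_node=None,                    # optional GenerationNode context
# )
```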