Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions ax/adapter/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ax.adapter.base import Adapter, unwrap_observation_data
from ax.core.observation import Observation, ObservationData
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UnsupportedError
from ax.utils.common.logger import get_logger
from ax.utils.stats.model_fit_stats import (
coefficient_of_determination,
Expand Down Expand Up @@ -108,12 +109,14 @@ def cross_validate(
]
arm_names = {obs.arm_name for obs in training_data}
n = len(arm_names)
if folds > n:
raise ValueError(f"Training data only has {n} arms, which is less than folds")
elif n == 0:
if n < 2:
raise UnsupportedError(
"Cross validation requires at least two in-design arms in the training "
f"data. Only {n} in-design arms were found."
)
elif folds > n:
raise ValueError(
f"{model.__class__.__name__} has no training data. Either it has been "
"incorrectly initialized or should not be cross validated."
f"Training data only has {n} arms, which is less than {folds} folds."
)
elif folds < 2 and folds != -1:
raise ValueError("Folds must be -1 for LOO, or > 1.")
Expand Down
18 changes: 12 additions & 6 deletions ax/adapter/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.types import ComparisonOp
from ax.exceptions.core import UnsupportedError
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -74,7 +75,7 @@ def setUp(self) -> None:

def test_cross_validate_base(self) -> None:
# Do cross validation
with self.assertRaisesRegex(ValueError, "which is less than folds"):
with self.assertRaisesRegex(ValueError, "which is less than 4 folds"):
cross_validate(model=self.adapter, folds=4)
with self.assertRaisesRegex(ValueError, "Folds must be"):
cross_validate(model=self.adapter, folds=0)
Expand Down Expand Up @@ -206,11 +207,16 @@ def test_selector(obs: Observation) -> bool:
call_kwargs = mock_cv.call_args.kwargs
self.assertTrue(call_kwargs["use_posterior_predictive"])

def test_cross_validate_gives_a_useful_error_for_model_with_no_data(self) -> None:
    """An adapter that was never given training data cannot be cross validated."""
    experiment = get_branin_experiment()
    adapter = Generators.SOBOL(
        experiment=experiment, search_space=experiment.search_space
    )
    # Sobol has no observations, so CV must fail with an explanatory message.
    with self.assertRaisesRegex(ValueError, "no training data"):
        cross_validate(model=adapter)
def test_cross_validate_gives_a_useful_error_for_insufficient_data(self) -> None:
    """Fewer than two in-design arms should raise ``UnsupportedError``."""
    # Two under-provisioned adapters: Sobol with no data at all, and a
    # BoTorch adapter with only a single completed trial.
    empty_experiment = get_branin_experiment()
    one_point_experiment = get_branin_experiment(with_completed_trial=True)
    adapters = (
        Generators.SOBOL(experiment=empty_experiment),
        Generators.BOTORCH_MODULAR(experiment=one_point_experiment),
    )
    for adapter in adapters:
        with self.assertRaisesRegex(UnsupportedError, "at least two in-design"):
            cross_validate(model=adapter)

@mock_botorch_optimize
def test_cross_validate_catches_warnings(self) -> None:
Expand Down
18 changes: 14 additions & 4 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,10 +652,20 @@ def _pick_fitted_adapter_to_gen_from(self) -> GeneratorSpec:
"This can be caused by model fitting errors, which should be "
"diagnosed by following the exception logs produced earlier."
)
best_model = none_throws(self.best_model_selector).best_model(
generator_specs=fitted_specs,
)
return best_model
try:
best_model = none_throws(self.best_model_selector).best_model(
generator_specs=fitted_specs,
)
return best_model
except Exception as e:
logger.warning(
"The `BestModelSelector` raised an error when selecting the best "
"generator. This can happen if the generator ran into issues during "
"computing the relevant diagnostics, such as insufficient training "
"data. Returning the first generator that was successfully fit. "
f"Original error message: {e}."
)
return fitted_specs[0]

# ------------------------- Trial logic helpers. -------------------------
@property
Expand Down
1 change: 0 additions & 1 deletion ax/generation_strategy/generator_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
class GeneratorSpecJSONEncoder(json.JSONEncoder):
    """JSON encoder that never raises on unserializable values.

    Used by ``GeneratorSpec.__repr__``: any object the ``json`` module
    cannot encode natively is rendered via its ``repr`` instead of
    triggering a ``TypeError``.
    """

    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
    def default(self, o: Any) -> str:
        # Fall back to the developer-facing representation of the object.
        return f"{o!r}"

Expand Down
16 changes: 16 additions & 0 deletions ax/generation_strategy/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,19 @@ def test_pick_fitted_adapter_with_fit_errors(self) -> None:
self.assertEqual(
self.model_selection_node.generator_spec_to_gen_from, self.ms_botorch
)

@mock_botorch_optimize
def test_best_model_selection_errors(self) -> None:
    """Errors raised inside the ``BestModelSelector`` are handled gracefully.

    Cross validation requires at least two in-design arms, so an
    experiment with a single completed trial makes the selector's
    diagnostics raise. The node should fall back to the first
    successfully fitted generator and log a warning rather than
    propagating the error.
    """
    exp = get_branin_experiment(with_completed_trial=True)
    self.model_selection_node._fit(experiment=exp)
    # Check that it selected the first generator and logged a warning.
    with self.assertLogs(logger=logger) as logs:
        self.assertEqual(
            self.model_selection_node.generator_spec_to_gen_from, self.ms_mixed
        )
    # Scan the captured log lines (``logs.output``) individually. Iterating
    # ``logs`` directly walks the watcher namedtuple's two fields
    # (records, output) and only matched via their stringified list reprs.
    self.assertTrue(
        any("raised an error when selecting" in line for line in logs.output)
    )