Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 285 additions & 11 deletions ax/adapter/cross_validation.py

Large diffs are not rendered by default.

641 changes: 616 additions & 25 deletions ax/adapter/tests/test_cross_validation.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ax/adapter/tests/test_hierarchical_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _base_test_predict_and_cv(
)
]
)
cv_res = cross_validate(model=mbm)
cv_res = cross_validate(adapter=mbm)
self.assertEqual(len(cv_res), len(experiment.trials))

def test_with_non_hierarchical_hss(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _test_no_early_stopping(self, with_progression: bool) -> None:
self.assertListEqual(dataset.feature_names, ["width", "height"])

# Check that cross validation works.
cross_validate(model=adapter)
cross_validate(adapter=adapter)

def _test_early_stopping(self, complete_with_progression: bool) -> None:
self._simulate(
Expand Down Expand Up @@ -247,7 +247,7 @@ def _test_early_stopping(self, complete_with_progression: bool) -> None:
self.assertEqual(int(candidate_metadata["step"]), 1.0)

# Check that cross validation works.
cross_validate(model=adapter)
cross_validate(adapter=adapter)

def test_no_early_stopping_with_progression(self) -> None:
self._test_no_early_stopping(with_progression=True)
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def compute(

cards = []
cv_results = cross_validate(
model=relevant_adapter, folds=self.folds, untransform=self.untransform
adapter=relevant_adapter, folds=self.folds, untransform=self.untransform
)
relevant_adapter_metric_names = [
relevant_adapter._experiment.signature_to_metric[signature].name
Expand Down
2 changes: 1 addition & 1 deletion ax/generation_strategy/generator_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def cross_validate(

self._assert_fitted()
try:
self._cv_results = cross_validate(model=self.fitted_adapter, **cv_kwargs)
self._cv_results = cross_validate(adapter=self.fitted_adapter, **cv_kwargs)
except NotImplementedError:
warnings.warn(
f"{self.generator_enum.value} cannot be cross validated", stacklevel=2
Expand Down
8 changes: 4 additions & 4 deletions ax/generation_strategy/tests/test_generator_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_cross_validate_with_GP_model(
data=self.experiment.trials[0].fetch_data(),
)
cv_results, cv_diagnostics = ms.cross_validate()
mock_cv.assert_called_with(model=fake_mb, test_key="test-value")
mock_cv.assert_called_with(adapter=fake_mb, test_key="test-value")
mock_diagnostics.assert_called_with(["fake-cv-result"])

self.assertIsNotNone(cv_results)
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_cross_validate_with_GP_model(

self.assertIsNotNone(cv_results)
self.assertIsNotNone(cv_diagnostics)
mock_cv.assert_called_with(model=fake_mb, test_key="test-value")
mock_cv.assert_called_with(adapter=fake_mb, test_key="test-value")
mock_diagnostics.assert_called_with(["fake-cv-result"])

with self.subTest("pass in optional kwargs"):
Expand All @@ -138,7 +138,7 @@ def test_cross_validate_with_GP_model(

self.assertIsNotNone(cv_results)
self.assertIsNotNone(cv_diagnostics)
mock_cv.assert_called_with(model=fake_mb, test_key="test-value", test=1)
mock_cv.assert_called_with(adapter=fake_mb, test_key="test-value", test=1)
self.assertEqual(ms._last_cv_kwargs, {"test": 1, "test_key": "test-value"})

@patch(f"{GeneratorSpec.__module__}.compute_diagnostics")
Expand All @@ -165,7 +165,7 @@ def test_cross_validate_with_non_GP_model(
self.assertIsNone(cv_results)
self.assertIsNone(cv_diagnostics)

mock_cv.assert_called_with(model="fake-adapter", test_key="test-value")
mock_cv.assert_called_with(adapter="fake-adapter", test_key="test-value")
mock_diagnostics.assert_not_called()

def test_fixed_features(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions ax/plot/tests/test_diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def setUp(self) -> None:
super().setUp()
exp = get_branin_experiment(with_batch=True)
exp.trials[0].run()
self.model = Generators.BOTORCH_MODULAR(
self.adapter = Generators.BOTORCH_MODULAR(
# Adapter kwargs
experiment=exp,
data=exp.fetch_data(),
)

def test_cross_validation(self) -> None:
for autoset_axis_limits in [False, True]:
cv = cross_validate(self.model)
cv = cross_validate(adapter=self.adapter)
# Assert that each type of plot can be constructed successfully
label_dict = {"branin": "BrAnIn"}
plot = interact_cross_validation_plotly(
Expand Down
2 changes: 1 addition & 1 deletion ax/service/utils/best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def get_best_parameters_from_model_predictions_with_trial_index(
return _extract_best_arm_from_gr(gr=gr, trials=experiment.trials)

# Check to see if the adapter is worth using.
cv_results = cross_validate(model=adapter)
cv_results = cross_validate(adapter=adapter)
diagnostics = compute_diagnostics(result=cv_results)
assess_model_fit_results = assess_model_fit(diagnostics=diagnostics)

Expand Down
6 changes: 3 additions & 3 deletions ax/service/utils/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@
)


def _get_cross_validation_plots(model: Adapter) -> list[go.Figure]:
cv = cross_validate(model=model)
def _get_cross_validation_plots(adapter: Adapter) -> list[go.Figure]:
cv = cross_validate(adapter=adapter)
return [
interact_cross_validation_plotly(
cv_results=cv, caption=CROSS_VALIDATION_CAPTION
Expand Down Expand Up @@ -407,7 +407,7 @@ def get_standard_plots(

try:
logger.debug("Starting cross validation plot.")
output_plot_list.extend(_get_cross_validation_plots(model=model))
output_plot_list.extend(_get_cross_validation_plots(adapter=model))
logger.debug("Finished cross validation plot.")
except Exception as e:
logger.exception(f"Cross-validation plot failed with error: {e}")
Expand Down
Loading