Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ax/api/utils/generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ def _get_sobol_node(
- If the initialization budget is not specified, it defaults to 5.
- The TC will not block generation if `allow_exceeding_initialization_budget`
is set to True.
- The TC is currently not restricted to any trial statuses and will
count all trials.
- The TC excludes FAILED and ABANDONED trials from the count, so that
more trials can be generated to meet the
`min_observed_initialization_trials` requirement.
- `use_existing_trials_for_initialization` controls whether trials previously
attached to the experiment are counted as part of the initialization budget.
- MinTrials enforcing the minimum number of observed initialization trials.
Expand All @@ -72,6 +73,7 @@ def _get_sobol_node(
block_gen_if_met=(not allow_exceeding_initialization_budget),
block_transition_if_unmet=True,
use_all_trials_in_exp=use_existing_trials_for_initialization,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials( # This represents minimum observed trials requirement.
threshold=min_observed_initialization_trials,
Expand Down
5 changes: 4 additions & 1 deletion ax/api/utils/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ class GenerationStrategyDispatchStruct:
``choose_generation_strategy``. This is an advanced option
and should not be considered a part of the public API.
initialization_budget: The number of trials to use for initialization.
If ``None``, a default budget of 5 trials is used.
If ``None``, a default budget of 5 trials is used. Note that FAILED
and ABANDONED trials are excluded from this count, allowing more
trials to be generated to meet the
`min_observed_initialization_trials` requirement.
initialization_random_seed: The random seed to use with the Sobol generator
that generates the initialization trials.
initialize_with_center: If True, the center of the search space is used as the
Expand Down
52 changes: 52 additions & 0 deletions ax/api/utils/tests/test_generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def test_choose_gs_fast_with_options(self) -> None:
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=False,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
),
MinTrials(
threshold=4,
Expand Down Expand Up @@ -370,3 +371,54 @@ def test_choose_gs_with_custom_botorch_acqf_class(self) -> None:
self.assertEqual(
mbm_spec.generator_kwargs["surrogate_spec"], expected_ss
)

def test_abandoned_and_failed_trials_excluded_from_initialization_budget(
self,
) -> None:
"""Test that FAILED and ABANDONED trials don't count toward init budget."""
struct = GenerationStrategyDispatchStruct(
method="fast",
initialization_budget=5,
allow_exceeding_initialization_budget=False,
)
gs = choose_generation_strategy(struct=struct)

# Verify the first MinTrials criterion excludes FAILED and ABANDONED
sobol_node = gs._nodes[1] # Node 0 is Center
first_tc = assert_is_instance(sobol_node._transition_criteria[0], MinTrials)
self.assertEqual(
first_tc.not_in_statuses, [TrialStatus.FAILED, TrialStatus.ABANDONED]
)
self.assertEqual(first_tc.threshold, 5)
self.assertTrue(first_tc.block_gen_if_met)

# Test the actual behavior: Generate 5 trials, mark 3 as ABANDONED,
# verify that Sobol can still generate more trials
experiment = get_branin_experiment()
gs.experiment = experiment

# Generate 5 initial trials
for _ in range(5):
gr = gs.gen_single_trial(experiment)
trial = experiment.new_trial(generator_run=gr)
trial.mark_running(no_runner_required=True)

# Mark trials 2, 3, 4 as ABANDONED
if trial.index in [2, 3, 4]:
trial.mark_abandoned()
else:
trial.mark_completed()

# Check we have 2 COMPLETED and 3 ABANDONED
self.assertEqual(
len(experiment.trial_indices_by_status[TrialStatus.COMPLETED]), 2
)
self.assertEqual(
len(experiment.trial_indices_by_status[TrialStatus.ABANDONED]), 3
)

# Should still be able to generate from Sobol since only 2 "valid" trials exist
gr = gs.gen_single_trial(experiment)
self.assertIsNotNone(gr)
# Verify it's from Sobol
self.assertEqual(gr._generator_key, "Sobol")