Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,42 +1101,43 @@ def __new__(
# is set in `GenerationStrategy` constructor, because only then is the order
# of the generation steps actually known.
transition_criteria: list[TransitionCriterion] = []
# Placeholder - will be overwritten in _validate_and_set_step_sequence in GS
placeholder_transition_to = f"GenerationStep_{str(index)}"

if num_trials != -1:
transition_criteria.append(
MinTrials(
threshold=num_trials,
transition_to=placeholder_transition_to,
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
block_gen_if_met=enforce_num_trials,
block_transition_if_unmet=True,
use_all_trials_in_exp=use_all_trials_in_exp,
transition_to=None, # Re-set in GS constructor.
)
)

if min_trials_observed > 0:
transition_criteria.append(
MinTrials(
threshold=min_trials_observed,
transition_to=placeholder_transition_to,
only_in_statuses=[
TrialStatus.COMPLETED,
TrialStatus.EARLY_STOPPED,
],
threshold=min_trials_observed,
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=use_all_trials_in_exp,
transition_to=None, # Re-set in GS constructor.
)
)
if max_parallelism is not None:
transition_criteria.append(
MaxGenerationParallelism(
threshold=max_parallelism,
transition_to=placeholder_transition_to,
only_in_statuses=[TrialStatus.RUNNING],
block_gen_if_met=True,
block_transition_if_unmet=False,
# MaxParallelism transitions to self,
# this will be confirmed in GS init
transition_to=f"GenerationStep_{str(index)}",
)
)

Expand Down
18 changes: 15 additions & 3 deletions ax/generation_strategy/tests/test_aepsych_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ class TestAEPsychCriterion(TestCase):
"""

def test_single_criterion(self) -> None:
criterion = MinimumPreferenceOccurances(metric_signature="m1", threshold=3)
criterion = MinimumPreferenceOccurances(
metric_signature="m1",
threshold=3,
transition_to="next_node", # overwritten during GS init
)

experiment = get_experiment()

Expand Down Expand Up @@ -91,8 +95,16 @@ def test_single_criterion(self) -> None:

def test_many_criteria(self) -> None:
criteria = [
MinimumPreferenceOccurances(metric_signature="m1", threshold=3),
MinTrials(only_in_statuses=[TrialStatus.COMPLETED], threshold=5),
MinimumPreferenceOccurances(
metric_signature="m1",
threshold=3,
transition_to="next_node", # overwritten during GS init
),
MinTrials(
only_in_statuses=[TrialStatus.COMPLETED],
threshold=5,
transition_to="next_node", # overwritten during GS init
),
]

experiment = get_experiment()
Expand Down
13 changes: 10 additions & 3 deletions ax/generation_strategy/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,11 @@ def test_node_string_representation(self) -> None:
self.mbm_generator_spec,
],
transition_criteria=[
MinTrials(threshold=5, only_in_statuses=[TrialStatus.RUNNING])
MinTrials(
threshold=5,
transition_to="next_node",
only_in_statuses=[TrialStatus.RUNNING],
)
],
)
string_rep = str(node)
Expand All @@ -331,7 +335,7 @@ def test_node_string_representation(self) -> None:
"GenerationNode(name='test', "
"generator_specs=[GeneratorSpec(generator_enum=BoTorch, "
"generator_key_override=None)], "
"transition_criteria=[MinTrials(transition_to='None')])",
"transition_criteria=[MinTrials(transition_to='next_node')])",
)

def test_single_fixed_features(self) -> None:
Expand Down Expand Up @@ -439,6 +443,7 @@ def test_init(self) -> None:
[
MinTrials(
threshold=5,
transition_to="GenerationStep_-1", # overwritten during GS init
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
block_gen_if_met=True,
block_transition_if_unmet=True,
Expand All @@ -464,17 +469,19 @@ def test_init(self) -> None:
[
MinTrials(
threshold=5,
transition_to="GenerationStep_-1", # overwritten during GS init
not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED],
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=True,
),
MinTrials(
threshold=3,
transition_to="GenerationStep_-1", # overwritten during GS init
only_in_statuses=[
TrialStatus.COMPLETED,
TrialStatus.EARLY_STOPPED,
],
threshold=3,
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=True,
Expand Down
25 changes: 16 additions & 9 deletions ax/generation_strategy/tests/test_transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def setUp(self) -> None:
self.branin_experiment = get_branin_experiment()

def test_minimum_preference_criterion(self) -> None:
criterion = MinimumPreferenceOccurances(metric_signature="m1", threshold=3)
criterion = MinimumPreferenceOccurances(
metric_signature="m1", threshold=3, transition_to="next_node"
)
experiment = get_experiment()
generation_strategy = GenerationStrategy(
name="SOBOL::default",
Expand Down Expand Up @@ -313,6 +315,7 @@ def test_min_trials_is_met(self) -> None:
# Check mixed status MinTrials
min_criterion = MinTrials(
threshold=3,
transition_to="next_node", # placeholder for testing, transition not used
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
)
self.assertFalse(
Expand Down Expand Up @@ -444,10 +447,13 @@ def test_trials_from_node_empty(self) -> None:
gs.experiment = experiment
max_criterion_with_status = MinTrials(
threshold=2,
transition_to="next_node",
block_gen_if_met=True,
only_in_statuses=[TrialStatus.COMPLETED],
)
max_criterion = MinTrials(threshold=2, block_gen_if_met=True)
max_criterion = MinTrials(
threshold=2, transition_to="next_node", block_gen_if_met=True
)
self.assertFalse(
max_criterion.is_met(experiment=experiment, curr_node=gs._nodes[0])
)
Expand Down Expand Up @@ -478,51 +484,52 @@ def test_repr(self) -> None:
self.maxDiff = None
min_trials_criterion = MinTrials(
threshold=5,
transition_to="GenerationStep_1",
block_gen_if_met=True,
block_transition_if_unmet=False,
transition_to="GenerationStep_1",
only_in_statuses=[TrialStatus.COMPLETED],
not_in_statuses=[TrialStatus.FAILED],
)
self.assertEqual(
str(min_trials_criterion),
"MinTrials({'threshold': 5, "
+ "'transition_to': 'GenerationStep_1', "
+ "'only_in_statuses': [<enum 'TrialStatus'>.COMPLETED], "
+ "'not_in_statuses': [<enum 'TrialStatus'>.FAILED], "
+ "'transition_to': 'GenerationStep_1', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False, "
+ "'count_only_trials_with_data': False})",
)
minimum_trials_in_status_criterion = MinTrials(
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
threshold=0,
transition_to="GenerationStep_2",
only_in_statuses=[TrialStatus.COMPLETED, TrialStatus.EARLY_STOPPED],
block_gen_if_met=True,
block_transition_if_unmet=False,
not_in_statuses=[TrialStatus.FAILED],
)
self.assertEqual(
str(minimum_trials_in_status_criterion),
"MinTrials({'threshold': 0, 'only_in_statuses': "
"MinTrials({'threshold': 0, "
+ "'transition_to': 'GenerationStep_2', "
+ "'only_in_statuses': "
+ "[<enum 'TrialStatus'>.COMPLETED, <enum 'TrialStatus'>.EARLY_STOPPED], "
+ "'not_in_statuses': [<enum 'TrialStatus'>.FAILED], "
+ "'transition_to': 'GenerationStep_2', "
+ "'block_transition_if_unmet': False, "
+ "'block_gen_if_met': True, "
+ "'use_all_trials_in_exp': False, "
+ "'continue_trial_generation': False, "
+ "'count_only_trials_with_data': False})",
)
minimum_preference_occurrences_criterion = MinimumPreferenceOccurances(
metric_signature="m1", threshold=3
metric_signature="m1", threshold=3, transition_to="next_node"
)
self.assertEqual(
str(minimum_preference_occurrences_criterion),
"MinimumPreferenceOccurances({'metric_signature': 'm1', 'threshold': 3, "
+ "'transition_to': None, 'block_gen_if_met': False, "
+ "'transition_to': 'next_node', 'block_gen_if_met': False, "
"'block_transition_if_unmet': True})",
)
max_parallelism = MaxGenerationParallelism(
Expand Down
22 changes: 11 additions & 11 deletions ax/generation_strategy/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TransitionCriterion(SortableBase):

Args:
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
transition to when this criterion is met.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: ``MinTrials`` has not been met yet, but
Expand All @@ -58,11 +58,11 @@ class TransitionCriterion(SortableBase):
different ``GenerationNodes`` by setting this flag to True.
"""

_transition_to: str | None = None
_transition_to: str

def __init__(
self,
transition_to: str | None = None,
transition_to: str,
block_transition_if_unmet: bool | None = True,
block_gen_if_met: bool | None = False,
continue_trial_generation: bool | None = False,
Expand All @@ -73,9 +73,9 @@ def __init__(
self.continue_trial_generation = continue_trial_generation

@property
def transition_to(self) -> str | None:
def transition_to(self) -> str:
"""The name of the next GenerationNode after this TransitionCriterion is
completed, if it exists.
completed.
"""
return self._transition_to

Expand Down Expand Up @@ -264,11 +264,11 @@ class TrialBasedCriterion(TransitionCriterion):
def __init__(
self,
threshold: int,
transition_to: str,
block_transition_if_unmet: bool | None = True,
block_gen_if_met: bool | None = False,
only_in_statuses: list[TrialStatus] | None = None,
not_in_statuses: list[TrialStatus] | None = None,
transition_to: str | None = None,
use_all_trials_in_exp: bool | None = False,
continue_trial_generation: bool | None = False,
count_only_trials_with_data: bool = False,
Expand Down Expand Up @@ -480,7 +480,7 @@ class MinTrials(TrialBasedCriterion):
not_in_statuses: A list of trial statuses to exclude when checking the
criterion threshold.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
transition to when this criterion is met.
block_transition_if_unmet: A flag to prevent the node from completing and
being able to transition to another node. Ex: MaxGenerationParallelism
defaults to setting this to False since we can complete and move on from
Expand All @@ -505,9 +505,9 @@ class MinTrials(TrialBasedCriterion):
def __init__(
self,
threshold: int,
transition_to: str,
only_in_statuses: list[TrialStatus] | None = None,
not_in_statuses: list[TrialStatus] | None = None,
transition_to: str | None = None,
block_transition_if_unmet: bool | None = True,
block_gen_if_met: bool | None = False,
use_all_trials_in_exp: bool | None = False,
Expand All @@ -516,9 +516,9 @@ def __init__(
) -> None:
super().__init__(
threshold=threshold,
transition_to=transition_to,
only_in_statuses=only_in_statuses,
not_in_statuses=not_in_statuses,
transition_to=transition_to,
block_gen_if_met=block_gen_if_met,
block_transition_if_unmet=block_transition_if_unmet,
use_all_trials_in_exp=use_all_trials_in_exp,
Expand Down Expand Up @@ -551,7 +551,7 @@ class MinimumPreferenceOccurances(TransitionCriterion):
threshold: The threshold as an integer for this criterion. Ex: If we want to
generate at most 3 trials, then the threshold is 3.
transition_to: The name of the GenerationNode the GenerationStrategy should
transition to when this criterion is met, if it exists.
transition to when this criterion is met.
block_gen_if_met: A flag to prevent continued generation from the
associated GenerationNode if this criterion is met but other criterion
remain unmet. Ex: ``MinTrials`` has not been met yet, but
Expand All @@ -568,7 +568,7 @@ def __init__(
self,
metric_signature: str,
threshold: int,
transition_to: str | None = None,
transition_to: str,
block_gen_if_met: bool | None = False,
block_transition_if_unmet: bool | None = True,
) -> None:
Expand Down
1 change: 1 addition & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def get_trial_based_criterion() -> list[TransitionCriterion]:
return [
MinTrials(
threshold=3,
transition_to="next_node",
only_in_statuses=[TrialStatus.RUNNING, TrialStatus.COMPLETED],
not_in_statuses=None,
),
Expand Down