Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,12 @@ def __repr__(self) -> str:

NON_ABANDONED_STATUSES: Set[TrialStatus] = set(TrialStatus) - {TrialStatus.ABANDONED}

STATUSES_EXPECTING_DATA: List[TrialStatus] = [
TrialStatus.RUNNING,
TrialStatus.COMPLETED,
TrialStatus.EARLY_STOPPED,
]


# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def immutable_once_run(func: Callable) -> Callable:
Expand Down
3 changes: 3 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
from ax.modelbridge.registry import ModelRegistryBase
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGenCriterion,
MaxGenerationParallelism,
MaxTrials,
MinimumPreferenceOccurances,
Expand Down Expand Up @@ -183,6 +184,7 @@
AndEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
AugmentedBraninMetric: metric_to_dict,
AugmentedHartmann6Metric: metric_to_dict,
AutoTransitionAfterGenCriterion: transition_criterion_to_dict,
BatchTrial: batch_to_dict,
BenchmarkMetric: metric_to_dict,
BoTorchModel: botorch_model_to_dict,
Expand Down Expand Up @@ -290,6 +292,7 @@
"AndEarlyStoppingStrategy": AndEarlyStoppingStrategy,
"AugmentedBraninMetric": AugmentedBraninMetric,
"AugmentedHartmann6Metric": AugmentedHartmann6Metric,
"AutoTransitionAfterGenCriterion": AutoTransitionAfterGenCriterion,
"Arm": Arm,
"AggregatedBenchmarkResult": AggregatedBenchmarkResult,
"BatchTrial": BatchTrial,
Expand Down
4 changes: 4 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_model_selection=True),
),
(
"GenerationStrategy",
partial(sobol_gpei_generation_node_gs, with_auto_transition=True),
),
("GeneratorRun", get_generator_run),
("Hartmann6Metric", get_hartmann_metric),
("HierarchicalSearchSpace", get_hierarchical_search_space),
Expand Down
23 changes: 17 additions & 6 deletions ax/utils/testing/modeling_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.int_to_float import IntToFloat
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGenCriterion,
MaxGenerationParallelism,
MaxTrials,
MinimumPreferenceOccurances,
Expand Down Expand Up @@ -212,6 +213,7 @@ def get_generation_strategy(

def sobol_gpei_generation_node_gs(
with_model_selection: bool = False,
with_auto_transition: bool = False,
) -> GenerationStrategy:
"""Returns a basic SOBOL+MBM GS using GenerationNodes for testing.

Expand Down Expand Up @@ -255,6 +257,7 @@ def sobol_gpei_generation_node_gs(
not_in_statuses=None,
),
]
alt_mbm_criterion = [AutoTransitionAfterGenCriterion(transition_to="MBM_node")]
step_model_kwargs = {"silently_filter_kwargs": True}
sobol_model_spec = ModelSpec(
model_enum=Models.SOBOL,
Expand Down Expand Up @@ -284,12 +287,20 @@ def sobol_gpei_generation_node_gs(
else:
best_model_selector = None

mbm_node = GenerationNode(
node_name="MBM_node",
transition_criteria=mbm_criterion,
model_specs=mbm_model_specs,
best_model_selector=best_model_selector,
)
if with_auto_transition:
mbm_node = GenerationNode(
node_name="MBM_node",
transition_criteria=alt_mbm_criterion,
model_specs=mbm_model_specs,
best_model_selector=best_model_selector,
)
else:
mbm_node = GenerationNode(
node_name="MBM_node",
transition_criteria=mbm_criterion,
model_specs=mbm_model_specs,
best_model_selector=best_model_selector,
)

sobol_mbm_GS_nodes = GenerationStrategy(
name="Sobol+MBM_Nodes",
Expand Down