Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ax/adapter/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ def view_defaults(self) -> tuple[dict[str, Any], dict[str, Any]]:
self.GENERATOR_KEY_TO_GENERATOR_SETUP.get(self.value)
)
return (
self._get_model_kwargs(info=model_setup_info),
self._get_bridge_kwargs(info=model_setup_info),
self._get_generator_kwargs(info=model_setup_info),
self._get_adapter_kwargs(info=model_setup_info),
)

def view_kwargs(self) -> tuple[dict[str, Any], dict[str, Any]]:
Expand All @@ -400,7 +400,7 @@ def view_kwargs(self) -> tuple[dict[str, Any], dict[str, Any]]:
)

@staticmethod
def _get_model_kwargs(
def _get_generator_kwargs(
info: GeneratorSetup, kwargs: dict[str, Any] | None = None
) -> dict[str, Any]:
return consolidate_kwargs(
Expand All @@ -409,7 +409,7 @@ def _get_model_kwargs(
)

@staticmethod
def _get_bridge_kwargs(
def _get_adapter_kwargs(
info: GeneratorSetup, kwargs: dict[str, Any] | None = None
) -> dict[str, Any]:
return consolidate_kwargs(
Expand Down Expand Up @@ -442,7 +442,7 @@ class Generators(GeneratorRegistryBase):
with a `SobolGenerator(scramble=False)` underlying model.

NOTE: If you deprecate a model, please add its replacement to
`ax.storage.json_store.decoder._DEPRECATED_MODEL_TO_REPLACEMENT` to ensure
`ax.storage.json_store.decoder._DEPRECATED_GENERATOR_TO_REPLACEMENT` to ensure
backwards compatibility of the storage layer.
"""

Expand Down
2 changes: 1 addition & 1 deletion ax/adapter/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


class TestAdapterFactorySingleObjective(TestCase):
def test_model_kwargs(self) -> None:
def test_generator_kwargs(self) -> None:
"""Tests that model kwargs are passed correctly."""
exp = get_branin_experiment()
sobol = get_sobol(
Expand Down
2 changes: 1 addition & 1 deletion ax/adapter/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_SAASBO(self) -> None:
SaasFullyBayesianSingleTaskGP,
)

def test_enum_model_kwargs(self) -> None:
def test_enum_generator_kwargs(self) -> None:
"""Tests that kwargs are passed correctly when instantiating through the
Generators enum."""
exp = get_branin_experiment()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def setUp(self) -> None:
)
generator_spec = GeneratorSpec(
generator_enum=Generators.BOTORCH_MODULAR,
model_kwargs={
generator_kwargs={
"surrogate_spec": surrogate_spec,
"botorch_acqf_class": qLogExpectedImprovement,
"transforms": MBM_X_trans + Y_trans,
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def test_prepare_arm_data_out_of_distribution_arm(self) -> None:
)
gen_spec = self.client._generation_strategy._curr.generator_specs[0]
adapter = gen_spec.generator_enum(
experiment=self.client._experiment, **gen_spec.model_kwargs
experiment=self.client._experiment, **gen_spec.generator_kwargs
)
df = prepare_arm_data(
experiment=self.client._experiment,
Expand Down
4 changes: 2 additions & 2 deletions ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,13 +1504,13 @@ def test_configure_generation_strategy_with_simplify(self) -> None:
self.assertFalse(
client._generation_strategy._nodes[2]
.generator_specs[0]
.model_kwargs["acquisition_options"]["prune_irrelevant_parameters"]
.generator_kwargs["acquisition_options"]["prune_irrelevant_parameters"]
)
client.configure_generation_strategy(simplify_parameter_changes=True)
self.assertTrue(
client._generation_strategy._nodes[2]
.generator_specs[0]
.model_kwargs["acquisition_options"]["prune_irrelevant_parameters"]
.generator_kwargs["acquisition_options"]["prune_irrelevant_parameters"]
)

def test_configure_experiment_with_derived_parameter(self) -> None:
Expand Down
6 changes: 3 additions & 3 deletions ax/api/utils/generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _get_sobol_node(
generator_specs=[
GeneratorSpec(
generator_enum=Generators.SOBOL,
model_kwargs={"seed": initialization_random_seed},
generator_kwargs={"seed": initialization_random_seed},
)
],
transition_criteria=transition_criteria,
Expand Down Expand Up @@ -129,7 +129,7 @@ def _get_mbm_node(
generator_specs=[
GeneratorSpec(
generator_enum=Generators.BOTORCH_MODULAR,
model_kwargs={
generator_kwargs={
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
"torch_device": device,
"transform_configs": get_derelativize_config(
Expand Down Expand Up @@ -171,7 +171,7 @@ def choose_generation_strategy(
generator_specs=[
GeneratorSpec(
generator_enum=Generators.SOBOL,
model_kwargs={"seed": struct.initialization_random_seed},
generator_kwargs={"seed": struct.initialization_random_seed},
)
],
)
Expand Down
10 changes: 5 additions & 5 deletions ax/api/utils/tests/test_generation_strategy_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_choose_gs_random_search(self) -> None:
self.assertEqual(len(sobol_node.generator_specs), 1)
sobol_spec = sobol_node.generator_specs[0]
self.assertEqual(sobol_spec.generator_enum, Generators.SOBOL)
self.assertEqual(sobol_spec.model_kwargs, {"seed": None})
self.assertEqual(sobol_spec.generator_kwargs, {"seed": None})
self.assertEqual(sobol_node._transition_criteria, [])
# Make sure it generates.
run_trials_with_gs(
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_choose_gs_fast_with_options(self) -> None:
self.assertEqual(len(sobol_node.generator_specs), 1)
sobol_spec = sobol_node.generator_specs[0]
self.assertEqual(sobol_spec.generator_enum, Generators.SOBOL)
self.assertEqual(sobol_spec.model_kwargs, {"seed": 0})
self.assertEqual(sobol_spec.generator_kwargs, {"seed": 0})
expected_tc = [
MinTrials(
threshold=2,
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_choose_gs_fast_with_options(self) -> None:
self.assertEqual(mbm_spec.generator_enum, Generators.BOTORCH_MODULAR)
expected_ss = SurrogateSpec(model_configs=[ModelConfig(name="MBM defaults")])
self.assertEqual(
mbm_spec.model_kwargs,
mbm_spec.generator_kwargs,
{
"surrogate_spec": expected_ss,
"torch_device": torch.device("cpu"),
Expand Down Expand Up @@ -187,7 +187,7 @@ def test_choose_gs_quality_with_options(self) -> None:
]
)
self.assertEqual(
mbm_spec.model_kwargs,
mbm_spec.generator_kwargs,
{
"surrogate_spec": expected_ss,
"torch_device": torch.device("cpu"),
Expand Down Expand Up @@ -249,6 +249,6 @@ def test_gs_simplify_parameter_changes(self) -> None:
mbm_node = gs._nodes[2]
mbm_spec = mbm_node.generator_specs[0]
self.assertEqual(
mbm_spec.model_kwargs["acquisition_options"],
mbm_spec.generator_kwargs["acquisition_options"],
{"prune_irrelevant_parameters": simplify},
)
38 changes: 19 additions & 19 deletions ax/benchmark/methods/modular_botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def get_sobol_mbm_generation_strategy(
acquisition_cls: type[AcquisitionFunction] | None = None,
name: str | None = None,
num_sobol_trials: int = 5,
model_kwargs_override: dict[str, Any] | None = None,
model_gen_kwargs: dict[str, Any] | None = None,
generator_kwargs_override: dict[str, Any] | None = None,
generator_gen_kwargs: dict[str, Any] | None = None,
batch_size: int = 1,
) -> GenerationStrategy:
"""Get a `BenchmarkMethod` that uses Sobol followed by MBM.
Expand All @@ -54,9 +54,9 @@ def get_sobol_mbm_generation_strategy(
name: Name that will be attached to the `GenerationStrategy`.
num_sobol_trials: Number of Sobol trials; can refer to the number of
`BatchTrial`s.
model_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside
`model_kwargs`.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
generator_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside
`generator_kwargs`.
generator_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.

Example:
Expand All @@ -69,21 +69,21 @@ def get_sobol_mbm_generation_strategy(
... acquisition_cls=qLogNoisyExpectedImprovement,
... )
"""
model_kwargs: dict[str, Any] = {
generator_kwargs: dict[str, Any] = {
"surrogate_spec": SurrogateSpec(
model_configs=[ModelConfig(botorch_model_class=model_cls)]
),
}
if acquisition_cls is not None:
model_kwargs["botorch_acqf_class"] = acquisition_cls
generator_kwargs["botorch_acqf_class"] = acquisition_cls
acqf_name = acqf_name_abbreviations.get(
acquisition_cls.__name__, acquisition_cls.__name__
)
else:
acqf_name = ""

if model_kwargs_override is not None:
model_kwargs.update(model_kwargs_override)
if generator_kwargs_override is not None:
generator_kwargs.update(generator_kwargs_override)

model_name = model_names_abbrevations.get(model_cls.__name__, model_cls.__name__)
# Historically all benchmarks were sequential, so sequential benchmarks
Expand All @@ -106,8 +106,8 @@ def get_sobol_mbm_generation_strategy(
GenerationStep(
generator=Generators.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs=model_kwargs,
model_gen_kwargs=model_gen_kwargs or {},
generator_kwargs=generator_kwargs,
generator_gen_kwargs=generator_gen_kwargs or {},
),
],
)
Expand All @@ -119,8 +119,8 @@ def get_sobol_botorch_modular_acquisition(
acquisition_cls: type[AcquisitionFunction] | None = None,
name: str | None = None,
num_sobol_trials: int = 5,
model_kwargs_override: dict[str, Any] | None = None,
model_gen_kwargs: dict[str, Any] | None = None,
generator_kwargs_override: dict[str, Any] | None = None,
generator_gen_kwargs: dict[str, Any] | None = None,
batch_size: int = 1,
) -> BenchmarkMethod:
"""Get a `BenchmarkMethod` that uses Sobol followed by MBM.
Expand All @@ -133,9 +133,9 @@ def get_sobol_botorch_modular_acquisition(
num_sobol_trials: Number of Sobol trials; if the orchestrator_options
specify to use `BatchTrial`s, then this refers to the number of
`BatchTrial`s.
model_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside
`model_kwargs`.
model_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
generator_kwargs_override: Passed to the MBM BoTorch `GenerationStep` inside
`generator_kwargs`.
generator_gen_kwargs: Passed to the BoTorch `GenerationStep` and ultimately
to the BoTorch `Model`.
batch_size: Passed to the created ``BenchmarkMethod``.

Expand All @@ -154,7 +154,7 @@ def get_sobol_botorch_modular_acquisition(
... model_cls=SingleTaskGP,
... acquisition_cls=qLogNoisyExpectedImprovement,
... batch_size=5,
... model_gen_kwargs={
... generator_gen_kwargs={
... "model_gen_options": {
... "optimizer_kwargs": {"sequential": False}
... }
Expand All @@ -167,8 +167,8 @@ def get_sobol_botorch_modular_acquisition(
acquisition_cls=acquisition_cls,
name=name,
num_sobol_trials=num_sobol_trials,
model_kwargs_override=model_kwargs_override,
model_gen_kwargs=model_gen_kwargs,
generator_kwargs_override=generator_kwargs_override,
generator_gen_kwargs=generator_gen_kwargs,
batch_size=batch_size,
)

Expand Down
6 changes: 3 additions & 3 deletions ax/benchmark/tests/methods/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def _test_mbm_acquisition(self, batch_size: int) -> None:
gs = method.generation_strategy
sobol, kg = gs._steps
self.assertEqual(kg.generator, Generators.BOTORCH_MODULAR)
model_kwargs = none_throws(kg.model_kwargs)
self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient)
surrogate_spec = model_kwargs["surrogate_spec"]
generator_kwargs = none_throws(kg.generator_kwargs)
self.assertEqual(generator_kwargs["botorch_acqf_class"], qKnowledgeGradient)
surrogate_spec = generator_kwargs["surrogate_spec"]
self.assertEqual(
surrogate_spec.model_configs[0].botorch_model_class.__name__,
"SingleTaskGP",
Expand Down
4 changes: 2 additions & 2 deletions ax/benchmark/tests/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_batch(self) -> None:
model_cls=SingleTaskGP,
acquisition_cls=qLogNoisyExpectedImprovement,
batch_size=batch_size,
model_gen_kwargs={
generator_gen_kwargs={
"model_gen_options": {
"optimizer_kwargs": {"sequential": sequential}
}
Expand Down Expand Up @@ -937,7 +937,7 @@ def test_replication_with_generation_node(self) -> None:
name="Sobol",
generator_specs=[
GeneratorSpec(
Generators.SOBOL, model_kwargs={"deduplicate": True}
Generators.SOBOL, generator_kwargs={"deduplicate": True}
)
],
)
Expand Down
6 changes: 4 additions & 2 deletions ax/benchmark/tests/test_benchmark_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def test_benchmark_method(self) -> None:

# test that `fit_tracking_metrics` has been correctly set to False
for step in method.generation_strategy._steps:
self.assertFalse(none_throws(step.model_kwargs).get("fit_tracking_metrics"))
self.assertFalse(
none_throws(step.generator_kwargs).get("fit_tracking_metrics")
)

method = BenchmarkMethod(generation_strategy=gs)
self.assertEqual(method.name, method.generation_strategy.name)
Expand All @@ -28,7 +30,7 @@ def test_benchmark_method(self) -> None:
method = BenchmarkMethod(name="Sobol10", generation_strategy=gs)
for node in method.generation_strategy._nodes:
self.assertFalse(
none_throws(node.generator_spec_to_gen_from.model_kwargs).get(
none_throws(node.generator_spec_to_gen_from.generator_kwargs).get(
"fit_tracking_metrics"
)
)
Loading