33 changes: 25 additions & 8 deletions ax/models/torch/botorch_modular/surrogate.py
@@ -178,9 +178,9 @@ def __init__(
model_options: Optional[dict[str, Any]] = None,
mll_class: type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
mll_options: Optional[dict[str, Any]] = None,
outcome_transform_classes: Optional[list[type[OutcomeTransform]]] = None,
outcome_transform_classes: Optional[Sequence[type[OutcomeTransform]]] = None,
outcome_transform_options: Optional[dict[str, dict[str, Any]]] = None,
input_transform_classes: Optional[list[type[InputTransform]]] = None,
input_transform_classes: Optional[Sequence[type[InputTransform]]] = None,
input_transform_options: Optional[dict[str, dict[str, Any]]] = None,
covar_module_class: Optional[type[Kernel]] = None,
covar_module_options: Optional[dict[str, Any]] = None,
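The signature change above widens `list[type[...]]` to `Sequence[type[...]]` for the transform-class arguments. A minimal sketch of what the wider annotation permits, with hypothetical variable names not taken from this diff:

```python
from collections.abc import Sequence

from botorch.models.transforms.input import (
    InputPerturbation,
    InputTransform,
    Normalize,
)

# A tuple now satisfies the annotation. Under list[type[InputTransform]],
# only a list would have type-checked, since list is invariant.
transform_classes: Sequence[type[InputTransform]] = (Normalize, InputPerturbation)
```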
@@ -355,7 +355,9 @@ def _set_formatted_inputs(
dataset: SupervisedDataset,
botorch_model_class_args: list[str],
search_space_digest: SearchSpaceDigest,
botorch_model_class: type[Model],
) -> None:
"""Modifies `formatted_model_inputs` in place."""
for input_name, input_class, input_options in inputs:
if input_class is None:
# This is a temporary solution until all BoTorch models use
@@ -376,7 +378,7 @@
# to be expanded to a ModelFactory, see D22457664, to accommodate
# different models in the future.
raise UserInputError(
f"The BoTorch model class {self.botorch_model_class} does not "
f"The BoTorch model class {botorch_model_class.__name__} does not "
f"support the input {input_name}."
)
input_options = deepcopy(input_options) or {}
@@ -385,7 +387,7 @@
covar_module_with_defaults = covar_module_argparse(
input_class,
dataset=dataset,
botorch_model_class=self.botorch_model_class,
botorch_model_class=botorch_model_class,
**input_options,
)

@@ -664,6 +666,19 @@ def best_out_of_sample_point(
) -> tuple[Tensor, Tensor]:
"""Finds the best predicted point and the corresponding value of the
appropriate best point acquisition function.

Args:
search_space_digest: A `SearchSpaceDigest`.
torch_opt_config: A `TorchOptConfig`; non-None `fixed_features` is
not supported.
options: Optional. If present, `seed_inner` (default None) and `qmc`
(default True) will be parsed from `options`; any other keys
will be ignored.

Returns:
A two-tuple (`candidate`, `acqf_value`), where `candidate` is a 1d
Tensor of the best predicted point and `acqf_value` is a scalar (0d)
Tensor of the acquisition function value at the best point.
"""
if torch_opt_config.fixed_features:
# When we have fixed features, we need `FixedFeatureAcquisitionFunction`
@@ -690,15 +705,15 @@ def best_out_of_sample_point(
torch_opt_config=torch_opt_config,
options=acqf_options,
)
candidates, acqf_values, _ = acqf.optimize(
candidates, acqf_value, _ = acqf.optimize(
n=1,
search_space_digest=search_space_digest,
inequality_constraints=_to_inequality_constraints(
linear_constraints=torch_opt_config.linear_constraints
),
fixed_features=torch_opt_config.fixed_features,
)
return candidates[0], acqf_values[0]
return candidates[0], acqf_value

def pareto_frontier(self) -> tuple[Tensor, Tensor]:
"""For multi-objective optimization, retrieve Pareto frontier instead
@@ -736,7 +751,7 @@ def _serialize_attributes_as_kwargs(self) -> dict[str, Any]:

def _extract_construct_input_transform_args(
self, search_space_digest: SearchSpaceDigest
) -> tuple[Optional[list[type[InputTransform]]], dict[str, dict[str, Any]]]:
) -> tuple[Optional[Sequence[type[InputTransform]]], dict[str, dict[str, Any]]]:
"""
Extracts input transform classes and input transform options that will
be used in `self._set_formatted_inputs` and ultimately passed to
@@ -764,7 +779,7 @@ def _extract_construct_input_transform_args(
)
}

submodel_input_transform_classes: list[type[InputTransform]] = [
submodel_input_transform_classes: Sequence[type[InputTransform]] = [
InputPerturbation
]

@@ -862,6 +877,8 @@ def _submodel_input_constructor_base(
search_space_digest=search_space_digest,
# This is used to check if the arguments are supported.
botorch_model_class_args=botorch_model_class_args,
# Used to raise the appropriate error if arguments are not supported
botorch_model_class=botorch_model_class,
)
return formatted_model_inputs

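The expanded docstring documents that `best_out_of_sample_point` returns a 1d candidate and a 0d acquisition value, which is why the return changed from `acqf_values[0]` to `acqf_value`. A minimal usage sketch of the documented API; the fitted `surrogate`, the three-feature search space, and the literal option keys are illustrative assumptions, not taken from this diff:

```python
import torch

from ax.core.search_space import SearchSpaceDigest
from ax.models.torch_base import TorchOptConfig

# Assume `surrogate` is an already-fit Surrogate over three features.
candidate, acqf_value = surrogate.best_out_of_sample_point(
    search_space_digest=SearchSpaceDigest(
        feature_names=["x1", "x2", "x3"],
        bounds=[(0.0, 1.0)] * 3,
    ),
    # Per the docstring, non-None `fixed_features` is not supported here.
    torch_opt_config=TorchOptConfig(objective_weights=torch.tensor([1.0])),
    options={"seed_inner": 0, "qmc": True},  # other keys are ignored
)
assert candidate.shape == torch.Size([3])  # 1d best predicted point
assert acqf_value.shape == torch.Size([])  # 0d acquisition value
```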
86 changes: 49 additions & 37 deletions ax/models/torch/tests/test_surrogate.py
@@ -21,13 +21,11 @@
from ax.models.torch.botorch_modular.surrogate import _extract_model_kwargs, Surrogate
from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.testing.mock import fast_botorch_optimize
from ax.utils.testing.torch_stubs import get_torch_test_data
from ax.utils.testing.utils import generic_equals
from botorch.acquisition.monte_carlo import qSimpleRegret
from botorch.models import ModelListGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
@@ -37,12 +35,12 @@
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.transforms.input import InputPerturbation, Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.utils.datasets import SupervisedDataset
from gpytorch.constraints import GreaterThan, Interval
from gpytorch.kernels import Kernel, MaternKernel, RBFKernel, ScaleKernel
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood, GaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood, LeaveOneOutPseudoLikelihood
from pyre_extensions import assert_is_instance
from torch import Tensor
from torch.nn import ModuleList # @manual -- autodeps can't figure it out.

@@ -150,6 +148,7 @@ def setUp(self) -> None:
self.training_data = [
SupervisedDataset(
X=self.Xs[0],
# Note: using 1d Y does not match the 2d TorchOptConfig
Y=self.Ys[0],
feature_names=self.feature_names,
outcome_names=self.metric_names,
@@ -186,16 +185,26 @@
)

def _get_surrogate(
self, botorch_model_class: type[Model]
self, botorch_model_class: type[Model], use_outcome_transform: bool = True
) -> tuple[Surrogate, dict[str, Any]]:
if botorch_model_class is SaasFullyBayesianSingleTaskGP:
mll_options = {"jit_compile": True}
else:
mll_options = None

if use_outcome_transform:
outcome_transform_classes = [Standardize]
outcome_transform_options = {"Standardize": {"m": 1}}
else:
outcome_transform_classes = None
outcome_transform_options = None

surrogate = Surrogate(
botorch_model_class=botorch_model_class,
mll_class=self.mll_class,
mll_options=mll_options,
outcome_transform_classes=outcome_transform_classes,
outcome_transform_options=outcome_transform_options,
)
surrogate_kwargs = botorch_model_class.construct_inputs(self.training_data[0])
return surrogate, surrogate_kwargs
@@ -357,7 +366,9 @@ def test_dtype_and_device_properties(self) -> None:
@patch.object(SingleTaskGP, "__init__", return_value=None)
@patch(f"{SURROGATE_PATH}.fit_botorch_model")
def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None:
surrogate, _ = self._get_surrogate(botorch_model_class=SingleTaskGP)
surrogate, _ = self._get_surrogate(
botorch_model_class=SingleTaskGP, use_outcome_transform=False
)
search_space_digest = SearchSpaceDigest(
feature_names=self.feature_names,
bounds=self.bounds,
@@ -405,7 +416,12 @@ def test_fit_model_reuse(self, mock_fit: Mock, mock_init: Mock) -> None:

def test_construct_model(self) -> None:
for botorch_model_class in (SaasFullyBayesianSingleTaskGP, SingleTaskGP):
surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class)
# Don't use an outcome transform here because the
# botorch_model_class will change to one that is not compatible with
# outcome transforms below
surrogate, _ = self._get_surrogate(
botorch_model_class=botorch_model_class, use_outcome_transform=False
)
with self.assertRaisesRegex(TypeError, "posterior"):
# Base `Model` does not implement `posterior`, so instantiating it here
# will fail.
@@ -582,25 +598,8 @@ def test_best_in_sample_point(self) -> None:
self.assertTrue(generic_equals(ckwargs[attr], getattr(self, attr)))

@fast_botorch_optimize
@patch(f"{ACQUISITION_PATH}.Acquisition.__init__", return_value=None)
@patch(
f"{ACQUISITION_PATH}.Acquisition.optimize",
return_value=(
torch.tensor([[0.0]]),
torch.tensor([1.0]),
torch.tensor([1.0]),
),
)
@patch(
f"{SURROGATE_PATH}.pick_best_out_of_sample_point_acqf_class",
return_value=(qSimpleRegret, {Keys.SAMPLER: SobolQMCNormalSampler}),
)
def test_best_out_of_sample_point(
self,
mock_best_point_util: Mock,
mock_acqf_optimize: Mock,
mock_acqf_init: Mock,
) -> None:
def test_best_out_of_sample_point(self) -> None:
torch.manual_seed(0)
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class)
surrogate.fit(
@@ -613,24 +612,34 @@ def test_best_out_of_sample_point(
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
)
torch_opt_config = dataclasses.replace(
self.torch_opt_config,
fixed_features=None,

surrogate, _ = self._get_surrogate(botorch_model_class=botorch_model_class)
surrogate.fit(
datasets=self.training_data,
search_space_digest=self.search_space_digest,
)
torch_opt_config = TorchOptConfig(objective_weights=torch.tensor([1.0]))
candidate, acqf_value = surrogate.best_out_of_sample_point(
search_space_digest=self.search_space_digest,
torch_opt_config=torch_opt_config,
options=self.options,
)
mock_acqf_init.assert_called_with(
surrogates={"self": surrogate},
botorch_acqf_class=qSimpleRegret,
search_space_digest=self.search_space_digest,
torch_opt_config=torch_opt_config,
options={Keys.SAMPLER: SobolQMCNormalSampler},
candidate_in_bounds = all(
((x >= b[0]) & (x <= b[1]) for x, b in zip(candidate, self.bounds))
)
self.assertTrue(candidate_in_bounds)
self.assertEqual(candidate.shape, torch.Size([3]))

# self.training_data has length 1
sample_mean = self.training_data[0].Y.mean().item()
self.assertEqual(acqf_value.shape, torch.Size([]))
# In realistic cases the maximum posterior mean would exceed the
# sample mean (because the data is standardized), but that might not
# be true when using `fast_botorch_optimize`
eps = 1
self.assertGreaterEqual(
acqf_value.item(), assert_is_instance(sample_mean, float) - eps
)
self.assertTrue(torch.equal(candidate, torch.tensor([0.0])))
self.assertTrue(torch.equal(acqf_value, torch.tensor(1.0)))

def test_serialize_attributes_as_kwargs(self) -> None:
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
@@ -1016,7 +1025,10 @@ def test_with_botorch_transforms(self) -> None:
outcome_transform_classes=[Standardize],
outcome_transform_options={"Standardize": {"m": 1}},
)
with self.assertRaisesRegex(UserInputError, "The BoTorch model class"):
with self.assertRaisesRegex(
UserInputError,
"The BoTorch model class SingleTaskGPWithDifferentConstructor",
):
surrogate.fit(
datasets=self.supervised_training_data,
search_space_digest=SearchSpaceDigest(
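The tightened `assertRaisesRegex` mirrors the `surrogate.py` change that interpolates `botorch_model_class.__name__` into the `UserInputError`. A sketch of the message format the test now asserts against; the input name here is an illustrative assumption:

```python
# Mirrors the f-string in `_set_formatted_inputs`; `input_name` is illustrative.
botorch_model_class_name = "SingleTaskGPWithDifferentConstructor"
input_name = "outcome_transform"
message = (
    f"The BoTorch model class {botorch_model_class_name} does not "
    f"support the input {input_name}."
)
print(message)
```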