Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions botorch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
)
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.higher_order_gp import HigherOrderGP

from botorch.models.map_saas import add_saas_prior, AdditiveMapSaasSingleTaskGP
from botorch.models.map_saas import (
add_saas_prior,
AdditiveMapSaasSingleTaskGP,
EnsembleMapSaasGP,
)
from botorch.models.model import ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import KroneckerMultiTaskGP, MultiTaskGP
Expand All @@ -34,6 +36,7 @@
"AffineDeterministicModel",
"AffineFidelityCostModel",
"ApproximateGPyTorchModel",
"EnsembleMapSaasGP",
"SaasFullyBayesianSingleTaskGP",
"SaasFullyBayesianMultiTaskGP",
"GenericDeterministicModel",
Expand Down
152 changes: 146 additions & 6 deletions botorch/models/map_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any

import torch
from botorch.acquisition.objective import PosteriorTransform
from botorch.exceptions import UnsupportedError
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.utils.gpytorch_modules import (
get_gaussian_likelihood_with_lognormal_prior,
)
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
from botorch.utils.constraints import LogTransformedInterval
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.constraints import Interval
Expand All @@ -30,12 +33,15 @@
class SaasPriorHelper:
"""Helper class for specifying parameter and setting closures."""

def __init__(self, tau: float | None = None):
def __init__(self, tau: Tensor | float | None = None):
"""Instantiates a new helper object.

Args:
tau: Value of the global shrinkage parameter. If `None`, the tau will be
a free parameter and inferred from the data.
Tau can be a tensor for batched models, like `EnsembleMapSaasGP`,
where each batch has a different sparsity prior. If tau is a tensor,
it must have shape `batch_shape`.
"""
self._tau = torch.as_tensor(tau) if tau is not None else None

Expand Down Expand Up @@ -102,10 +108,8 @@ def tau_prior_setting_closure(self, m: Kernel, value: Tensor) -> None:
"""
lb = m.raw_tau_constraint.lower_bound.to(m.raw_tau)
ub = m.raw_tau_constraint.upper_bound.to(m.raw_tau)
m.raw_tau.data.fill_(
m.raw_tau_constraint.inverse_transform(
value.to(m.raw_tau).clamp(lb + EPS, ub - EPS)
).item()
m.raw_tau.data = m.raw_tau_constraint.inverse_transform(
value.to(m.raw_tau).clamp(lb + EPS, ub - EPS)
)


Expand Down Expand Up @@ -218,7 +222,7 @@ def get_map_saas_model(
)
# NOTE: need to call `to` to set device and dtype before calling `add_saas_prior`,
# since the SAAS prior contains tensors that are not parameters of the model, and
# terefore not automatically moved to the correct device with a `to` call on the
# therefore not automatically moved to the correct device with a `to` call on the
# model.
base_kernel.to(train_X)
add_saas_prior(base_kernel=base_kernel, tau=tau)
Expand Down Expand Up @@ -421,3 +425,139 @@ def __init__(
)
# Make sure that all buffers and parameters have the correct device and dtype
self.to(dtype=train_X.dtype, device=train_X.device)


class EnsembleMapSaasGP(SingleTaskGP):
    r"""Batched ensemble of MAP-fitted SAAS GPs, one ensemble member per tau.

    The training data is replicated along a leading batch dimension of size
    ``num_taus``; each batch entry gets its own global shrinkage parameter
    (tau) through ``add_saas_prior``. Calling ``posterior`` returns a
    ``GaussianMixturePosterior`` built from the batched posterior.
    """

    _is_ensemble = True

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        train_Yvar: Tensor | None = None,
        num_taus: int = 4,
        taus: Tensor | None = None,
        outcome_transform: OutcomeTransform | _DefaultType | None = DEFAULT,
        input_transform: InputTransform | None = None,
    ) -> None:
        """Instantiates an ``EnsembleMapSaasGP``, which is a batched ensemble of
        ``SingleTaskGP``s with the Matern-5/2 kernel and a SAAS prior. The model is
        intended to be trained with ``ExactMarginalLogLikelihood`` and
        ``fit_gpytorch_mll``. Under the hood, the model is equivalent to a
        multi-output ``BatchedMultiOutputGPyTorchModel``, but it produces a
        ``GaussianMixturePosterior``, which leads to ensembling of the model outputs.

        Args:
            train_X: An `n x d` tensor of training features.
            train_Y: An `n x 1` tensor of training observations.
            train_Yvar: An optional `n x 1` tensor of observed measurement noise.
            num_taus: The number of taus to use (4 if omitted). Each tau is
                a sparsity parameter for the corresponding kernel in the ensemble.
            taus: An optional tensor of shape `num_taus` containing the taus to use.
                If omitted, the taus are sampled from a HalfCauchy(0.1) distribution.
                The taus are cast to the dtype and device of `train_X`.
            outcome_transform: An outcome transform that is applied to the
                training data during instantiation and to the posterior during
                inference (that is, the `Posterior` obtained by calling
                `.posterior` on the model will be on the original scale). We use a
                `Standardize` transform if no `outcome_transform` is specified.
                Pass down `None` to use no outcome transform. Note that `.train()` will
                be called on the outcome transform during instantiation of the model.
            input_transform: An input transform that is applied in the model's
                forward pass.

        Raises:
            ValueError: If `taus` is given and its shape is not `[num_taus]`.
            UnsupportedError: If `train_Y` has more than one output column or
                `train_X` is not 2-dimensional.
        """
        if taus is None:
            taus = HalfCauchy(torch.tensor(0.1)).sample([num_taus])
        elif taus.shape != torch.Size([num_taus]):
            raise ValueError(
                f"Expected taus to be of shape {[num_taus]}. Got {taus.shape=}."
            )
        if train_Y.shape[-1] != 1:
            raise UnsupportedError(
                f"EnsembleMapSAASGP only supports single-output. Got {train_Y.shape=}."
            )
        if train_X.ndim != 2:
            raise UnsupportedError(
                f"EnsembleMapSAASGP only supports 2D inputs. Got {train_X.ndim=}."
            )
        # Cast sampled AND user-supplied taus to the dtype / device of the
        # training data. The SAAS prior holds tensors that are not model
        # parameters, so a later `to` call on the model would not move them.
        taus = taus.to(train_X)
        # Add batch dimension for ensemble.
        train_X = train_X.repeat(num_taus, 1, 1)
        train_Y = train_Y.repeat(num_taus, 1, 1)
        if train_Yvar is not None:
            train_Yvar = train_Yvar.repeat(num_taus, 1, 1)
        # Construct the sub-modules. The ARD dimension is taken from the
        # transformed inputs, since a transform may change the feature count.
        if input_transform is not None:
            with torch.no_grad():
                transformed_X = input_transform(train_X)
            ard_num_dims = transformed_X.shape[-1]
        else:
            ard_num_dims = train_X.shape[-1]
        batch_shape = train_X.shape[:-2]  # This is torch.Size([num_taus]).
        mean_module = get_mean_module_with_normal_prior(batch_shape=batch_shape)
        base_kernel = MaternKernel(
            nu=2.5, ard_num_dims=ard_num_dims, batch_shape=batch_shape
        )
        # NOTE: need to call `to` to set device and dtype before calling
        # `add_saas_prior`, since the SAAS prior contains tensors that are not
        # parameters of the model, and therefore not automatically moved to the
        # correct device with a `to` call on the model.
        base_kernel.to(train_X)
        add_saas_prior(base_kernel=base_kernel, tau=taus)
        covar_module = ScaleKernel(
            base_kernel=base_kernel,
            outputscale_constraint=LogTransformedInterval(1e-2, 1e4, initial_value=10),
            batch_shape=batch_shape,
        )
        if train_Yvar is None:
            likelihood = get_gaussian_likelihood_with_gamma_prior(
                batch_shape=batch_shape
            )
        else:
            # With observed noise, leave the likelihood choice to the parent class.
            likelihood = None

        super().__init__(
            train_X=train_X,
            train_Y=train_Y,
            train_Yvar=train_Yvar,
            likelihood=likelihood,
            covar_module=covar_module,
            mean_module=mean_module,
            outcome_transform=outcome_transform,
            input_transform=input_transform,
        )

    def posterior(
        self,
        X: Tensor,
        output_indices: list[int] | None = None,
        observation_noise: bool = False,
        posterior_transform: PosteriorTransform | None = None,
        **kwargs: Any,
    ) -> GaussianMixturePosterior:
        r"""Computes the posterior over model outputs at the provided points.

        Args:
            X: A `(batch_shape) x q x d`-dim Tensor, where `d` is the dimension
                of the feature space and `q` is the number of points considered
                jointly.
            output_indices: A list of indices, corresponding to the outputs over
                which to compute the posterior (if the model is multi-output).
                Can be used to speed up computation if only a subset of the
                model's outputs are required for optimization. If omitted,
                computes the posterior over all model outputs.
            observation_noise: If True, add the observation noise from the
                likelihood to the posterior. If a Tensor, use it directly as the
                observation noise (must be of shape `(batch_shape) x q x m`).
            posterior_transform: An optional PosteriorTransform.
            **kwargs: Additional keyword arguments, forwarded unchanged to the
                parent class' `posterior` method.

        Returns:
            A `GaussianMixturePosterior` object. Includes observation noise
            if specified.
        """
        # Insert a singleton dimension at `MCMC_DIM` before delegating to the
        # parent; presumably this lets `X` broadcast against the ensemble
        # (tau) batch dimension of the model.
        posterior = super().posterior(
            X=X.unsqueeze(MCMC_DIM),
            output_indices=output_indices,
            observation_noise=observation_noise,
            posterior_transform=posterior_transform,
            **kwargs,
        )
        # Re-wrap the batched distribution so downstream code sees a mixture.
        return GaussianMixturePosterior(distribution=posterior.distribution)
56 changes: 55 additions & 1 deletion test/models/test_map_saas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from unittest import mock

import torch

from botorch.exceptions import UnsupportedError
from botorch.fit import (
fit_gpytorch_mll,
Expand All @@ -23,14 +22,17 @@
from botorch.models.map_saas import (
add_saas_prior,
AdditiveMapSaasSingleTaskGP,
EnsembleMapSaasGP,
get_additive_map_saas_covar_module,
get_gaussian_likelihood_with_gamma_prior,
get_mean_module_with_normal_prior,
)
from botorch.models.transforms.input import AppendFeatures, FilterFeatures, Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.optim.utils import get_parameters_and_bounds, sample_all_priors
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.test_utils.mock import mock_optimize
from botorch.utils.constraints import LogTransformedInterval
from botorch.utils.testing import BotorchTestCase
from gpytorch.constraints import Interval
Expand Down Expand Up @@ -510,6 +512,58 @@ def test_batch_model_fitting(self) -> None:
atol=1e-3,
)

@mock_optimize
def test_emsemble_map_saas(self) -> None:
    """Builds, fits and queries an ``EnsembleMapSaasGP`` with and without options.

    The first pass uses only the defaults; the second pass supplies observed
    noise, explicit taus, an input transform and no outcome transform.
    """
    train_X, train_Y, test_X = self._get_data()
    num_taus = 8
    dim = train_X.shape[-1]
    for use_options in (False, True):
        overrides = (
            {
                "train_Yvar": 0.1 * torch.rand_like(train_Y),
                "taus": torch.rand(num_taus).to(train_X),
                "input_transform": Normalize(d=dim),
                "outcome_transform": None,
            }
            if use_options
            else {}
        )
        model = EnsembleMapSaasGP(
            train_X=train_X, train_Y=train_Y, num_taus=num_taus, **overrides
        )
        # Sampling from every prior checks that the priors are set up correctly.
        sample_all_priors(model)
        fit_gpytorch_mll(
            ExactMarginalLogLikelihood(model=model, likelihood=model.likelihood)
        )
        # Kernel structure: a scaled Matern kernel with one lengthscale vector
        # per ensemble member.
        self.assertIsInstance(model.covar_module, ScaleKernel)
        matern = model.covar_module.base_kernel
        self.assertIsInstance(matern, MaternKernel)
        expected_ls_shape = torch.Size([num_taus, 1, dim])
        self.assertEqual(matern.lengthscale.shape, expected_ls_shape)
        self.assertEqual(model.batch_shape, torch.Size([num_taus]))
        # The posterior must come back as a Gaussian mixture.
        self.assertIsInstance(model.posterior(test_X), GaussianMixturePosterior)
        if not use_options:
            self.assertIsInstance(model.likelihood, GaussianLikelihood)
            self.assertIsInstance(model.outcome_transform, Standardize)
            self.assertFalse(hasattr(model, "input_transform"))
        else:
            self.assertIsInstance(model.likelihood, FixedNoiseGaussianLikelihood)
            self.assertIsInstance(model.input_transform, Normalize)
            self.assertFalse(hasattr(model, "outcome_transform"))

def test_ensemble_map_saas_validation(self) -> None:
    """Checks that ``EnsembleMapSaasGP`` rejects invalid constructor inputs."""
    # `taus` whose length disagrees with `num_taus`.
    with self.assertRaisesRegex(ValueError, "Expected taus to be of shape"):
        mismatched_taus = dict(
            train_X=torch.rand(5, 3),
            train_Y=torch.rand(5, 1),
            num_taus=3,
            taus=torch.rand(2),
        )
        EnsembleMapSaasGP(**mismatched_taus)
    # More than one outcome column.
    with self.assertRaisesRegex(UnsupportedError, "only supports single-output"):
        multi_output = dict(train_X=torch.rand(5, 3), train_Y=torch.rand(5, 2))
        EnsembleMapSaasGP(**multi_output)
    # Training inputs that already carry a batch dimension.
    with self.assertRaisesRegex(UnsupportedError, "only supports 2D inputs"):
        batched_inputs = dict(train_X=torch.rand(2, 5, 3), train_Y=torch.rand(2, 5, 1))
        EnsembleMapSaasGP(**batched_inputs)


class TestAdditiveMapSaasSingleTaskGP(BotorchTestCase):
def _get_data_and_model(
Expand Down
Loading