Update HigherOrderGP to use new priors #2555

Status: Closed (1 commit)
botorch/models/higher_order_gp.py (15 additions, 11 deletions)
@@ -27,18 +27,20 @@
 from botorch.models.utils import gpt_posterior_settings
 from botorch.models.utils.assorted import fantasize as fantasize_flag
 from botorch.models.utils.gpytorch_modules import (
-    get_gaussian_likelihood_with_gamma_prior,
+    get_covar_module_with_dim_scaled_prior,
+    get_gaussian_likelihood_with_lognormal_prior,
 )
 from botorch.posteriors import (
     GPyTorchPosterior,
     HigherOrderGPPosterior,
     TransformedPosterior,
 )
+from botorch.utils.types import _DefaultType, DEFAULT
 from gpytorch.distributions import MultivariateNormal
-from gpytorch.kernels import Kernel, MaternKernel
+from gpytorch.kernels import Kernel
 from gpytorch.likelihoods import Likelihood
 from gpytorch.models import ExactGP
-from gpytorch.priors.torch_priors import GammaPrior, MultivariateNormalPrior
+from gpytorch.priors.torch_priors import MultivariateNormalPrior
 from gpytorch.settings import fast_pred_var, skip_posterior_variances
 from linear_operator.operators import (
     BatchRepeatLinearOperator,
@@ -183,7 +185,7 @@ def __init__(
         num_latent_dims: Optional[list[int]] = None,
         learn_latent_pars: bool = True,
         latent_init: str = "default",
-        outcome_transform: Optional[OutcomeTransform] = None,
+        outcome_transform: Union[OutcomeTransform, _DefaultType, None] = DEFAULT,
         input_transform: Optional[InputTransform] = None,
     ):
         r"""
@@ -196,7 +198,6 @@
             learn_latent_pars: If true, learn the latent parameters.
             latent_init: [default or gp] how to initialize the latent parameters.
         """
-
         if input_transform is not None:
             input_transform.to(train_X)

@@ -207,7 +208,11 @@
             raise NotImplementedError(
                 "HigherOrderGP currently only supports 1-dim `batch_shape`."
             )
-
+        if outcome_transform == DEFAULT:
+            outcome_transform = FlattenedStandardize(
+                output_shape=train_Y.shape[-num_output_dims:],
+                batch_shape=batch_shape,
+            )
         if outcome_transform is not None:
             if isinstance(outcome_transform, Standardize) and not isinstance(
                 outcome_transform, FlattenedStandardize
@@ -218,6 +223,7 @@
                     f"{train_Y.shape[- num_output_dims:]} and batch_shape="
                     f"{batch_shape} instead.",
                     RuntimeWarning,
+                    stacklevel=2,
                 )
                 outcome_transform = FlattenedStandardize(
                     output_shape=train_Y.shape[-num_output_dims:],
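
Taken together, these two hunks change the default: `outcome_transform` now defaults to the `DEFAULT` sentinel, which resolves to a `FlattenedStandardize` over the trailing output dimensions, while an explicit `None` opts out entirely. A minimal sketch of the resulting behavior (shapes mirror the unbatched model in the tests below; the `hasattr` check follows the pattern used in test/posteriors/test_higher_order.py):

```python
import torch
from botorch.models.higher_order_gp import HigherOrderGP
from botorch.models.transforms.outcome import FlattenedStandardize

train_x = torch.rand(10, 1)
train_y = torch.randn(10, 3, 5)

# Omitting `outcome_transform` now attaches a FlattenedStandardize over the
# output shape (3, 5); passing None retains the previous raw-scale behavior.
model_default = HigherOrderGP(train_x, train_y)
assert isinstance(model_default.outcome_transform, FlattenedStandardize)

model_raw = HigherOrderGP(train_x, train_y, outcome_transform=None)
assert not hasattr(model_raw, "outcome_transform")
```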
@@ -232,7 +238,7 @@
         self._input_batch_shape = batch_shape

         if likelihood is None:
-            likelihood = get_gaussian_likelihood_with_gamma_prior(
+            likelihood = get_gaussian_likelihood_with_lognormal_prior(
                 batch_shape=self._aug_batch_shape
             )
         else:
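
The default noise model thus moves from the Gamma prior to the LogNormal prior used by BoTorch's other models. A rough sketch of what `get_gaussian_likelihood_with_lognormal_prior` constructs, assuming constants recalled from the BoTorch source (they may drift between releases; `1e-4` stands in for `MIN_INFERRED_NOISE_LEVEL`):

```python
import torch
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.priors import LogNormalPrior


def gaussian_likelihood_with_lognormal_prior(batch_shape=None):
    # Low-noise prior: log-noise centered at -4 with unit scale.
    batch_shape = batch_shape or torch.Size()
    noise_prior = LogNormalPrior(loc=-4.0, scale=1.0)
    return GaussianLikelihood(
        noise_prior=noise_prior,
        batch_shape=batch_shape,
        noise_constraint=GreaterThan(
            1e-4,  # assumed stand-in for MIN_INFERRED_NOISE_LEVEL
            transform=None,
            initial_value=noise_prior.mode,
        ),
    )
```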
@@ -249,11 +255,9 @@
         else:
             self.covar_modules = ModuleList(
                 [
-                    MaternKernel(
-                        nu=2.5,
-                        lengthscale_prior=GammaPrior(3.0, 6.0),
-                        batch_shape=self._aug_batch_shape,
-                        ard_num_dims=1 if dim > 0 else train_X.shape[-1],
-                    )
+                    get_covar_module_with_dim_scaled_prior(
+                        ard_num_dims=1 if dim > 0 else train_X.shape[-1],
+                        batch_shape=self._aug_batch_shape,
+                    )
                     for dim in range(self._num_dimensions)
                 ]
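
This swaps the Matern-5/2 kernel with `GammaPrior(3.0, 6.0)` lengthscales for the dimension-scaled RBF setup (after Hvarfner et al., 2024), in which the location of the lengthscale prior grows with the log of the input dimension. A sketch of roughly what the helper builds, under the assumption that the constants match the BoTorch source and that, as in the newer defaults, there is no `ScaleKernel` wrapper:

```python
from math import log, sqrt

import torch
from gpytorch.kernels import RBFKernel
from gpytorch.priors import LogNormalPrior


def covar_module_with_dim_scaled_prior(ard_num_dims, batch_shape=None):
    # Location grows with log(d): higher-dimensional inputs get a prior
    # favoring longer lengthscales, which regularizes high-d fits.
    lengthscale_prior = LogNormalPrior(
        loc=sqrt(2.0) + 0.5 * log(ard_num_dims),
        scale=sqrt(3.0),
    )
    return RBFKernel(
        ard_num_dims=ard_num_dims,
        batch_shape=batch_shape or torch.Size(),
        lengthscale_prior=lengthscale_prior,
    )
```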
botorch/sampling/normal.py (11 additions, 7 deletions)
@@ -66,7 +66,7 @@ def _construct_base_samples(self, posterior: Posterior) -> None:
         pass  # pragma: no cover

     def _update_base_samples(
-        self, posterior: Posterior, base_sampler: NormalMCSampler
+        self, posterior: Posterior, base_sampler: MCSampler
     ) -> None:
         r"""Update the sampler to use the original base samples for X_baseline.

@@ -102,7 +102,15 @@ def _update_base_samples(
         expanded_samples = current_base_samples.view(view_shape).expand(
             expanded_shape
         )
-        if isinstance(posterior, (HigherOrderGPPosterior, MultitaskGPPosterior)):
+        non_transformed_posterior = (
+            posterior._posterior
+            if isinstance(posterior, TransformedPosterior)
+            else posterior
+        )
+        if isinstance(
+            non_transformed_posterior,
+            (HigherOrderGPPosterior, MultitaskGPPosterior),
+        ):
             n_train_samples = current_base_samples.shape[-1] // 2
             # The train base samples.
             self.base_samples[..., :n_train_samples] = expanded_samples[
@@ -113,11 +121,7 @@
                 ..., -n_train_samples:
             ]
         else:
-            batch_shape = (
-                posterior._posterior.batch_shape
-                if isinstance(posterior, TransformedPosterior)
-                else posterior.batch_shape
-            )
+            batch_shape = non_transformed_posterior.batch_shape
             single_output = (
                 len(posterior.base_sample_shape) - len(batch_shape)
             ) == 1
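
Because HigherOrderGP now standardizes outcomes by default, `model.posterior(X)` returns a `TransformedPosterior` wrapping the `HigherOrderGPPosterior`, so the isinstance check must look through the wrapper. The unwrapping step in isolation (`unwrap_posterior` is a name introduced here for illustration):

```python
from botorch.posteriors import TransformedPosterior


def unwrap_posterior(posterior):
    # TransformedPosterior stores the wrapped posterior in `_posterior`;
    # type checks and batch_shape lookups should target the inner object.
    if isinstance(posterior, TransformedPosterior):
        return posterior._posterior
    return posterior
```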
test/acquisition/multi_objective/test_monte_carlo.py (1 addition, 1 deletion)
@@ -1857,7 +1857,7 @@ def _test_with_multitask(self, acqf_class: type[AcquisitionFunction]):
         def get_acqf(model):
             return acqf_class(
                 model=model,
-                ref_point=torch.tensor([0.0, 0.0], **tkwargs),
+                ref_point=torch.tensor([-1.0, -1.0], **tkwargs),
                 X_baseline=train_x,
                 sampler=IIDNormalSampler(sample_shape=torch.Size([2])),
                 objective=hogp_obj if isinstance(model, HigherOrderGP) else None,
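
With outcomes standardized by the new default transform, objective values center around zero, so a reference point at the origin may dominate little or none of the observed front; `[-1.0, -1.0]` stays safely dominated. A hypothetical illustration of the effect on hypervolume:

```python
import torch
from botorch.utils.multi_objective.box_decompositions.dominated import (
    DominatedPartitioning,
)

Y = torch.randn(20, 2)  # stand-in for roughly standardized objectives
for ref in ([0.0, 0.0], [-1.0, -1.0]):
    bd = DominatedPartitioning(ref_point=torch.tensor(ref), Y=Y)
    # Hypervolume w.r.t. the origin is typically much smaller here.
    print(ref, bd.compute_hypervolume().item())
```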
test/acquisition/test_analytic.py (1 addition, 1 deletion)
@@ -342,7 +342,7 @@ def test_posterior_stddev_batch(self):
         acqf = PosteriorStandardDeviation(model=mm)
         X = torch.empty(3, 1, 1, device=self.device, dtype=dtype)
         pm = acqf(X)
-        self.assertTrue(torch.equal(pm, std.view(-1)))
+        self.assertAllClose(pm, std.view(-1))
         # check for proper error if multi-output model
         mean2 = torch.rand(3, 1, 2, device=self.device, dtype=dtype)
         std2 = torch.rand_like(mean2)
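
A small test hardening in passing: bitwise tensor equality breaks under floating-point rounding, so the assertion switches to BotorchTestCase's tolerance-based `assertAllClose`. The classic failure mode:

```python
import torch

a = torch.tensor(0.1, dtype=torch.float64) + torch.tensor(0.2, dtype=torch.float64)
b = torch.tensor(0.3, dtype=torch.float64)
print(torch.equal(a, b))     # False: 0.1 + 0.2 != 0.3 under IEEE-754 rounding
print(torch.allclose(a, b))  # True: comparison within relative/absolute tolerance
```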
test/models/test_higher_order_gp.py (2 additions, 1 deletion)
@@ -39,7 +39,7 @@ def setUp(self):
         train_x = torch.rand(2, 10, 1, device=self.device)
         train_y = torch.randn(2, 10, 3, 5, device=self.device)

-        self.model = HigherOrderGP(train_x, train_y)
+        self.model = HigherOrderGP(train_x, train_y, outcome_transform=None)

         # check that we can assign different kernels and likelihoods
         model_2 = HigherOrderGP(
@@ -48,6 +48,7 @@
             covar_modules=[RBFKernel(), RBFKernel(), RBFKernel()],
             likelihood=GaussianLikelihood(),
         )
+        self.assertIsInstance(model_2.outcome_transform, FlattenedStandardize)

         model_3 = HigherOrderGP(
             train_X=train_x,
test/posteriors/test_higher_order.py (11 additions, 6 deletions)
@@ -9,6 +9,7 @@
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.higher_order_gp import HigherOrderGP
 from botorch.posteriors.higher_order import HigherOrderGPPosterior
+from botorch.posteriors.transformed import TransformedPosterior
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase

@@ -22,7 +23,7 @@ def setUp(self):
         train_y = torch.randn(2, 10, 3, 5, device=self.device)

         m1 = HigherOrderGP(train_x, train_y)
-        m2 = HigherOrderGP(train_x[0], train_y[0])
+        m2 = HigherOrderGP(train_x[0], train_y[0], outcome_transform=None)

         torch.random.manual_seed(0)
         test_x = torch.rand(2, 5, 1, device=self.device)
@@ -32,18 +33,18 @@
         posterior3 = m2.posterior(test_x)

         self.post_list = [
-            [m1, test_x, posterior1],
-            [m2, test_x[0], posterior2],
-            [m2, test_x, posterior3],
+            [m1, test_x, posterior1, TransformedPosterior],
+            [m2, test_x[0], posterior2, HigherOrderGPPosterior],
+            [m2, test_x, posterior3, HigherOrderGPPosterior],
         ]

     def test_HigherOrderGPPosterior(self):
         sample_shaping = torch.Size([5, 3, 5])

         for post_collection in self.post_list:
-            model, test_x, posterior = post_collection
+            model, test_x, posterior, posterior_class = post_collection

-            self.assertIsInstance(posterior, HigherOrderGPPosterior)
+            self.assertIsInstance(posterior, posterior_class)

             batch_shape = test_x.shape[:-2]
             expected_extended_shape = batch_shape + sample_shaping
@@ -105,6 +106,10 @@ def test_HigherOrderGPPosterior(self):

         model.eval()
         eval_mode_variance = model(test_x).variance.reshape_as(posterior_variance)
+        if hasattr(model, "outcome_transform"):
+            eval_mode_variance = model.outcome_transform.untransform(
+                eval_mode_variance, eval_mode_variance
+            )[1]
         self.assertLess(
             (posterior_variance - eval_mode_variance).norm()
             / eval_mode_variance.norm(),
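
The new branch undoes the standardization before comparing variances. Note the `[1]` index: `untransform` returns a `(mean, variance)` pair, and only the variance slot gets the pure `stdvs**2` rescaling without the mean shift. A minimal sketch of that semantics (shapes are illustrative):

```python
import torch
from botorch.models.transforms.outcome import FlattenedStandardize

tf = FlattenedStandardize(output_shape=torch.Size([3, 5]))
Y = torch.randn(10, 3, 5)
tf(Y)  # train-mode call fits the standardization parameters

var = torch.rand(10, 3, 5)
# Passing the variance in both slots and keeping element [1] applies only
# the stdvs**2 rescaling; element [0] would also add back the mean shift.
var_untf = tf.untransform(var, var)[1]
```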