Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions botorch/optim/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,16 @@ def sample_all_priors(model: GPyTorchModel, max_retries: int = 100) -> None:
try:
# Set sample shape, so that the prior samples have the same shape
# as `closure(module)` without having to be repeated.
closure_shape = closure(module).shape
prior_shape = prior._extended_shape()
sample_shape = closure_shape[: -len(prior_shape)]
setting_closure(module, prior.sample(sample_shape=sample_shape))
if prior_shape.numel() == 1:
# For a univariate prior we can sample the size of the closure.
# Otherwise we will sample exactly the same value for all
# lengthscales where we commonly specify a univariate prior.
setting_closure(module, prior.sample(closure(module).shape))
else:
closure_shape = closure(module).shape
sample_shape = closure_shape[: -len(prior_shape)]
setting_closure(module, prior.sample(sample_shape=sample_shape))
break
except NotImplementedError:
warn(
Expand Down
32 changes: 27 additions & 5 deletions test/optim/utils/test_model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import warnings
from copy import deepcopy
from string import ascii_lowercase
from typing import Any, Dict
from unittest.mock import MagicMock, patch

import torch
Expand Down Expand Up @@ -246,19 +247,40 @@ def test_sample_all_priors(self):
with self.assertRaises(RuntimeError):
sample_all_priors(model)

def test_univariate_prior(self) -> None:
tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
for batch in (torch.Size([]), torch.Size([2, 2])):
model = SingleTaskGP(
train_X=torch.rand(*batch, 5, 3, **tkwargs),
train_Y=torch.randn(*batch, 5, 1, **tkwargs),
covar_module=RBFKernel(
ard_num_dims=3,
batch_shape=batch,
lengthscale_prior=GammaPrior(6.0, 3.0), # univariate
),
)
original_lengthscales = model.covar_module.lengthscale
self.assertEqual(len(torch.unique(original_lengthscales)), 1)
sample_all_priors(model)
new_lengthscales = model.covar_module.lengthscale
self.assertFalse(torch.allclose(original_lengthscales, new_lengthscales))
# Make sure we sampled different lengthscales (happens with probability 1)
self.assertEqual(len(torch.unique(new_lengthscales)), 3 * batch.numel())

def test_with_multivariate_prior(self) -> None:
# This is modified from https://github.com/pytorch/botorch/issues/780.
tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
for batch in (torch.Size([]), torch.Size([3])):
model = SingleTaskGP(
train_X=torch.randn(*batch, 2, 2),
train_Y=torch.randn(*batch, 2, 1),
train_X=torch.rand(*batch, 2, 2, **tkwargs),
train_Y=torch.randn(*batch, 2, 1, **tkwargs),
covar_module=RBFKernel(
ard_num_dims=2,
batch_shape=batch,
lengthscale_prior=NormalPrior(
# Make this almost singular for easy comparison below.
torch.tensor([[1.0, 1.0]]),
torch.tensor(1e-10),
torch.tensor([[1.0, 1.0]], **tkwargs),
torch.tensor(1e-10, **tkwargs),
),
),
)
Expand All @@ -267,4 +289,4 @@ def test_with_multivariate_prior(self) -> None:
sample_all_priors(model)
new_lengthscale = model.covar_module.lengthscale
self.assertFalse(torch.allclose(original_lengthscale, new_lengthscale))
self.assertAllClose(new_lengthscale, torch.ones(*batch, 1, 2))
self.assertAllClose(new_lengthscale, torch.ones(*batch, 1, 2, **tkwargs))