Commit b4040b5

Added fantasizing to fully Bayesian models, expanded the tests to cover fantasization and repeated conditioning, and allowed conditioning on data without a batch shape (the batch shape is inferred in sensible cases)
1 parent f61c430 commit b4040b5

3 files changed: +121 additions, -35 deletions

botorch/models/fully_bayesian.py

Lines changed: 24 additions & 15 deletions
@@ -39,6 +39,7 @@
 import torch
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
+from botorch.models.model import FantasizeMixin
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
 from botorch.models.utils import validate_input_scaling
@@ -309,7 +310,9 @@ def load_mcmc_samples(
     return mean_module, covar_module, likelihood


-class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel):
+class SaasFullyBayesianSingleTaskGP(
+    ExactGP, BatchedMultiOutputGPyTorchModel, FantasizeMixin
+):
     r"""A fully Bayesian single-task GP model with the SAAS prior.

     This model assumes that the inputs have been normalized to [0, 1]^d and that
@@ -549,24 +552,30 @@ def condition_on_observations(
         identical across models or unique per-model).

         Args:
-            X (Tensor): A `(batch_shape) x num_samples x d`-dim Tensor, where `d` is
+            X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
                 the dimension of the feature space and `batch_shape` is the number of
-                sampled models.
-            Y (Tensor): A `(batch_shape) x num_samples x 1`-dim Tensor, where `d` is
+                sampled models.
+            Y: A `batch_shape x num_samples x 1`-dim Tensor, where `d` is
                 the dimension of the feature space and `batch_shape` is the number of
-                sampled models.
+                sampled models.

         Returns:
-            BatchedMultiOutputGPyTorchModel: _description_
+            BatchedMultiOutputGPyTorchModel: A fully Bayesian model conditioned on
+                the given observations. The returned model has `batch_shape` copies of
+                the training data in case of identical observations (and `batch_shape`
+                training datasets otherwise).
         """
-        if X.ndim < 3 or Y.ndim < 3:
-            # This can either be thrown here or in GPyTorch, when the inference of the
-            # batch dimension fails since the training data by default does not have
-            # a batch shape.
-            raise ValueError(
-                "Conditioning in fully Bayesian models must contain a batch dimension. "
-                "Add a batch dimension (the leading dim) with length matching the "
-                "number of hyperparameter sets to the conditioned data."
-            )
+        if X.ndim == 2 and Y.ndim == 2:
+            # To avoid an error in GPyTorch when inferring the batch dimension, we add
+            # the explicit batch shape here. The result is that the conditioned model
+            # will have `batch_shape` copies of the training data.
+            X = X.repeat(self.batch_shape + (1, 1))
+            Y = Y.repeat(self.batch_shape + (1, 1))
+
+        elif X.ndim < Y.ndim:
+            # This happens when fantasizing: one set of training data and multiple Y.
+            # We need to duplicate the training data to enable correct batch
+            # size inference in GPyTorch.
+            X = X.repeat(*(Y.shape[:-2] + (1, 1)))

         return super().condition_on_observations(X, Y, **kwargs)
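Both branches are pure shape bookkeeping. Here is a minimal sketch of what they do with plain tensors; the concrete sizes (3 models, 2 points, 19 fantasies) are illustrative assumptions, not values from the commit:

import torch

batch_shape = torch.Size([3])  # e.g. three MCMC hyperparameter samples
X = torch.rand(2, 4)           # num_cond x d, no batch dimension
Y = torch.rand(2, 1)           # num_cond x 1, no batch dimension

# Branch 1: un-batched X and Y are replicated across all sampled models.
X_cond = X.repeat(batch_shape + (1, 1))  # -> 3 x 2 x 4
Y_cond = Y.repeat(batch_shape + (1, 1))  # -> 3 x 2 x 1

# Branch 2 (fantasizing): one X, a batch of sampled Y; X is tiled to match.
Y_fant = torch.rand(19, 3, 2, 1)  # fantasy_size x num_models x num_cond x 1
X_fant = X.repeat(*(Y_fant.shape[:-2] + (1, 1)))  # -> 19 x 3 x 2 x 4

assert X_cond.shape == torch.Size([3, 2, 4])
assert X_fant.shape == torch.Size([19, 3, 2, 4])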

botorch/models/gpytorch.py

Lines changed: 3 additions & 7 deletions
@@ -223,12 +223,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
         >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
-        Yvar = kwargs.get("noise", None)
-
-        # for fully Bayesian models, the keyword argument "noise": None
-        # throws an error in LinearOperator related to inferring batch dims
-        if "noise" in kwargs and kwargs["noise"] is None:
-            del kwargs["noise"]
+        Yvar = kwargs.pop("noise", None)

         if hasattr(self, "outcome_transform"):
             # pass the transformed data to get_fantasy_model below
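The simplification works because `dict.pop` with a default both returns the value (or the default) and removes the key, so a `"noise": None` entry can no longer leak through to LinearOperator's batch-dim inference. A quick illustration of the equivalence:

kwargs = {"noise": None}
Yvar = kwargs.pop("noise", None)  # returns None and removes the key
assert Yvar is None and "noise" not in kwargs

kwargs = {}
Yvar = kwargs.pop("noise", None)  # key absent: returns the default, no KeyError
assert Yvar is None and kwargs == {}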
@@ -496,7 +491,8 @@ def condition_on_observations(
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
-
+        if not self._is_fully_bayesian:
+            fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
         return fantasy_model

     def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
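For intuition, a rough sketch of the slicing on hypothetical fantasy-model target shapes (the sizes below are made up for illustration); the guard presumably reflects that a fully Bayesian model's leading batch dim holds MCMC samples rather than outputs:

import torch

targets_single = torch.zeros(19, 3, 12)    # fantasy_size x num_models x n
targets_multi = torch.zeros(19, 3, 2, 12)  # ... x num_outputs x n

# _input_batch_shape drops the data dim, plus the output dim if there is one.
print(targets_single.shape[:-1])  # torch.Size([19, 3])
print(targets_multi.shape[:-2])   # torch.Size([19, 3])

# _aug_batch_shape keeps the output dim; per the diff it is only set when the
# model is not fully Bayesian.
print(targets_multi.shape[:-1])   # torch.Size([19, 3, 2])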

test/models/test_fully_bayesian.py

Lines changed: 94 additions & 13 deletions
@@ -52,6 +52,7 @@
 from botorch.models.transforms import Normalize, Standardize
 from botorch.posteriors.fully_bayesian import batched_bisect, GaussianMixturePosterior
 from botorch.sampling.get_sampler import get_sampler
+from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.multi_objective.box_decompositions.non_dominated import (
     NondominatedPartitioning,
@@ -125,15 +126,28 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
         return train_X, train_Y, train_Yvar, test_X

     def _get_unnormalized_condition_data(
-        self, num_models: int, infer_noise: bool, **tkwargs
+        self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
     ):
         with torch.random.fork_rng():
             torch.manual_seed(0)
-            cond_X = 5 + 5 * torch.rand(num_models, 2, 4, **tkwargs)
+            cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
             cond_Y = 10 + torch.sin(cond_X[..., :1])
-            cond_Yvar = None if infer_noise else 0.1 * torch.ones(cond_Y.shape)
+            cond_Yvar = (
+                None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
+            )
         return cond_X, cond_Y, cond_Yvar

+    def _get_unnormalized_fantasy_data(
+        self, num_cond: int, infer_noise: bool, **tkwargs
+    ):
+        with torch.random.fork_rng():
+            torch.manual_seed(0)
+            fantasy_X = 5 + 5 * torch.rand(num_cond, 4, **tkwargs)
+            fantasy_Yvar = (
+                None if infer_noise else 0.1 * torch.ones((num_cond, 1), **tkwargs)
+            )
+        return fantasy_X, fantasy_Yvar
+
     def _get_mcmc_samples(
         self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
     ):
@@ -671,7 +685,7 @@ def test_condition_on_observation(self):
         num_models = 3
         num_cond = 2
         for infer_noise, dtype in itertools.product(
-            (True,), (torch.float, torch.double)
+            (True, False), (torch.float, torch.double)
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
             train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
@@ -681,7 +695,10 @@ def test_condition_on_observation(self):
             # condition on different observations per model to obtain num_models sets
             # of training data
             cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
-                num_models=num_models, infer_noise=infer_noise, **tkwargs
+                num_models=num_models,
+                num_cond=num_cond,
+                infer_noise=infer_noise,
+                **tkwargs
             )
             model = SaasFullyBayesianSingleTaskGP(
                 train_X=train_X,
@@ -712,8 +729,12 @@ def test_condition_on_observation(self):
                 cond_model.train_inputs[0].shape,
                 torch.Size([num_models, num_train + num_cond, num_dims]),
             )
+
+            # the batch shape of the conditioned model is added during conditioning
+            self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))
+
             # condition on identical sets of data (i.e. one set) for all models,
-            # i.e. with no batch shape. This should not work.
+            # i.e. with no batch shape. This infers the batch shape.
             cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
             model = SaasFullyBayesianSingleTaskGP(
                 train_X=train_X,
@@ -728,14 +749,74 @@ def test_condition_on_observation(self):
             )
             model.load_mcmc_samples(mcmc_samples)

-            # This should __NOT__ work - conditioning must have a batch size for the
-            # conditioned point and is not supported (the training data by default
-            # does not have a batch size)
+            # conditioning without a batch size - the resulting conditioned model
+            # will still have a batch size
             model.posterior(train_X)
-            with self.assertRaises(ValueError):
-                model.condition_on_observations(
-                    cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
-                )
+            cond_model = model.condition_on_observations(
+                cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+
+            # test repeated conditioning
+            repeat_cond_X = cond_X + 5
+            repeat_cond_model = cond_model.condition_on_observations(
+                repeat_cond_X, cond_Y, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 2 * num_cond, num_dims]),
+            )
+
+            # test repeated conditioning without a batch size
+            repeat_cond_X_nobatch = cond_X_nobatch + 10
+            repeat_cond_model2 = repeat_cond_model.condition_on_observations(
+                repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model2.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
+            )
+
+    def test_fantasize(self):
+        num_models = 3
+        fantasy_size = 19
+        num_cond = 2
+        for infer_noise, dtype in itertools.product(
+            (True, False), (torch.float, torch.double)
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            train_X, train_Y, train_Yvar, _ = self._get_unnormalized_data(
+                infer_noise=infer_noise, **tkwargs
+            )
+            num_train, num_dims = train_X.shape
+
+            # the fantasized X should not have a batch dim
+            fantasy_X, fantasy_Yvar = self._get_unnormalized_fantasy_data(
+                infer_noise=infer_noise, num_cond=num_cond, **tkwargs
+            )
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs
+            )
+            model.load_mcmc_samples(mcmc_samples)
+            sampler = SobolQMCNormalSampler(torch.Size([fantasy_size]))
+            fantasy_model = model.fantasize(
+                fantasy_X, sampler, observation_noise=fantasy_Yvar
+            )
+            self.assertEqual(
+                fantasy_model.train_inputs[0].shape,
+                torch.Size([fantasy_size, num_models, num_train + num_cond, num_dims]),
+            )

     def test_bisect(self):
         def f(x):
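Taken together, the new test exercises a flow like the following minimal sketch. The NUTS settings are deliberately tiny illustrative assumptions (real use needs far more samples), and the exact fantasize signature may vary across BoTorch versions:

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler

train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])

model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
)
num_models = model.batch_shape[0]  # number of retained MCMC samples

# Un-batched fantasy points: the model batch shape is inferred internally.
fantasy_X = torch.rand(2, 4, dtype=torch.double)
sampler = SobolQMCNormalSampler(torch.Size([19]))
fantasy_model = model.fantasize(fantasy_X, sampler)

# fantasy_size x num_models x (num_train + num_fantasy) x d
assert fantasy_model.train_inputs[0].shape == torch.Size([19, num_models, 12, 4])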
