
Commit 059549b

Selective rollback of fantasize in Fully Bayesian GPs
1 parent f61c430 · commit 059549b

3 files changed: +71 −35 lines

botorch/models/fully_bayesian.py (19 additions, 14 deletions)

@@ -549,24 +549,29 @@ def condition_on_observations(
             identical across models or unique per-model).
 
         Args:
-            X: (Tensor): A `(batch_shape) x num_samples x d`-dim Tensor, where `d` is
+            X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
                 the dimension of the feature space and `batch_shape` is the number of
-                sampled models.
-            Y (Tensor): A `(batch_shape) x num_samples x 1`-dim Tensor, where `d` is
+                sampled models.
+            Y: A `batch_shape x num_samples x 1`-dim Tensor, where `d` is
                 the dimension of the feature space and `batch_shape` is the number of
-                sampled models.
+                sampled models.
 
         Returns:
-            BatchedMultiOutputGPyTorchModel: _description_
+            BatchedMultiOutputGPyTorchModel: A fully bayesian model conditioned on
+                given observations. The returned model has `batch_shape` copies of the
+                training data in case of identical observations (and `batch_shape`
+                training datasets otherwise).
         """
-        if X.ndim < 3 or Y.ndim < 3:
-            # The can either be thrown here or in GPyTorch, when the inference of the
-            # batch dimension fails since the training data by default does not have
-            # a batch shape.
-            raise ValueError(
-                "Conditioning in fully Bayesian models must contain a batch dimension."
-                "Add a batch dimension (the leading dim) with length matching the "
-                "number of hyperparameter sets to the conditioned data."
-            )
+        if X.ndim == 2 and Y.ndim == 2:
+            # To avoid an error in GPyTorch when inferring the batch dimension, we add
+            # the explicit batch shape here. The result is that the conditioned model
+            # will have 'batch_shape' copies of the training data.
+            X = X.repeat(self.batch_shape + (1, 1))
+            Y = Y.repeat(self.batch_shape + (1, 1))
+
+        elif X.ndim < Y.ndim:
+            # We need to duplicate the training data to enable correct batch
+            # size inference in gpytorch.
+            X = X.repeat(*(Y.shape[:-2] + (1, 1)))
 
         return super().condition_on_observations(X, Y, **kwargs)
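The two new branches only tile unbatched conditioning data across the leading model-batch dimension so that GPyTorch can infer batch sizes downstream. A minimal sketch of that shape promotion on plain tensors (the sizes here are illustrative, not taken from the commit):

import torch

batch_shape = torch.Size([3])   # e.g. 3 MCMC-sampled hyperparameter sets
X = torch.rand(5, 2)            # unbatched inputs: n x d
Y = torch.rand(5, 1)            # unbatched outcomes: n x 1

if X.ndim == 2 and Y.ndim == 2:
    # same observations for every model: add an explicit batch dimension
    X = X.repeat(batch_shape + (1, 1))  # -> 3 x 5 x 2
    Y = Y.repeat(batch_shape + (1, 1))  # -> 3 x 5 x 1
elif X.ndim < Y.ndim:
    # shared X but per-model Y: tile X to match Y's batch dimensions
    X = X.repeat(*(Y.shape[:-2] + (1, 1)))

print(X.shape, Y.shape)  # torch.Size([3, 5, 2]) torch.Size([3, 5, 1])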

botorch/models/gpytorch.py (3 additions, 7 deletions)

@@ -223,12 +223,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
         >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
-        Yvar = kwargs.get("noise", None)
-
-        # for fully bayesian models, the keyword argument "noise": None
-        # throws an error in LinearOperator releted to inferring batch dims
-        if "noise" in kwargs and kwargs["noise"] is None:
-            del kwargs["noise"]
+        Yvar = kwargs.pop("noise", None)
 
         if hasattr(self, "outcome_transform"):
             # pass the transformed data to get_fantasy_model below

@@ -496,7 +491,8 @@ def condition_on_observations(
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
-
+        if not self._is_fully_bayesian:
+            fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
         return fantasy_model
 
     def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
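The first hunk collapses the get/check/delete dance into a single dict.pop, which returns the value and removes the key in one step, so a "noise": None entry can no longer leak through **kwargs into LinearOperator. A small standalone illustration of the equivalence (not code from the commit):

kwargs = {"noise": None, "other": 1}

# before: fetch, then delete only when the value is None
Yvar = kwargs.get("noise", None)
if "noise" in kwargs and kwargs["noise"] is None:
    del kwargs["noise"]

# after: fetch and remove in one step; "noise" never stays behind
kwargs = {"noise": None, "other": 1}
Yvar = kwargs.pop("noise", None)
assert "noise" not in kwargs

One subtle difference: pop also removes a non-None noise from kwargs, whereas the old code left it in place; presumably the downstream calls receive it via Yvar instead.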

test/models/test_fully_bayesian.py (49 additions, 14 deletions)

@@ -125,13 +125,15 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
         return train_X, train_Y, train_Yvar, test_X
 
     def _get_unnormalized_condition_data(
-        self, num_models: int, infer_noise: bool, **tkwargs
+        self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
     ):
         with torch.random.fork_rng():
             torch.manual_seed(0)
-            cond_X = 5 + 5 * torch.rand(num_models, 2, 4, **tkwargs)
+            cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
             cond_Y = 10 + torch.sin(cond_X[..., :1])
-            cond_Yvar = None if infer_noise else 0.1 * torch.ones(cond_Y.shape)
+            cond_Yvar = (
+                None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
+            )
         return cond_X, cond_Y, cond_Yvar
 
     def _get_mcmc_samples(

@@ -667,11 +669,15 @@ def test_custom_pyro_model(self):
         )
 
     def test_condition_on_observation(self):
-
+        # The following conditioned data shapes should work; the output describes
+        # the training data shape after conditioning (batch shape is req. in gpytorch):
+        # X: num_models x n x d, Y: num_models x n x 1 --> num_models x n x d
+        # X: n x d, Y: n x 1 --> num_models x n x d
+        # X: n x d, Y: num_models x n x 1 --> num_models x n x d
         num_models = 3
         num_cond = 2
         for infer_noise, dtype in itertools.product(
-            (True,), (torch.float, torch.double)
+            (True, False), (torch.float, torch.double)
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
             train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(

@@ -681,7 +687,10 @@ def test_condition_on_observation(self):
             # condition on different observations per model to obtain num_models sets
             # of training data
             cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
-                num_models=num_models, infer_noise=infer_noise, **tkwargs
+                num_models=num_models,
+                num_cond=num_cond,
+                infer_noise=infer_noise,
+                **tkwargs
             )
             model = SaasFullyBayesianSingleTaskGP(
                 train_X=train_X,

@@ -712,8 +721,12 @@ def test_condition_on_observation(self):
                 cond_model.train_inputs[0].shape,
                 torch.Size([num_models, num_train + num_cond, num_dims]),
             )
+
+            # the batch shape of the conditioned model is added during conditioning
+            self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))
+
             # condition on identical sets of data (i.e. one set) for all models
-            # i.e, with no batch shape. This should not work.
+            # i.e., with no batch shape. This infers the batch shape.
             cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
             model = SaasFullyBayesianSingleTaskGP(
                 train_X=train_X,

@@ -728,14 +741,36 @@ def test_condition_on_observation(self):
             )
             model.load_mcmc_samples(mcmc_samples)
 
-            # This should __NOT__ work - conditioning must have a batch size for the
-            # conditioned point and is not supported (the training data by default
-            # does not have a batch size)
+            # conditioning without a batch size - the resulting conditioned model
+            # will still have a batch size
             model.posterior(train_X)
-            with self.assertRaises(ValueError):
-                model.condition_on_observations(
-                    cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
-                )
+            cond_model = model.condition_on_observations(
+                cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+
+            # test repeated conditioning
+            repeat_cond_X = cond_X + 5
+            repeat_cond_model = cond_model.condition_on_observations(
+                repeat_cond_X, cond_Y, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 2 * num_cond, num_dims]),
+            )
+
+            # test repeated conditioning without a batch size
+            repeat_cond_X_nobatch = cond_X_nobatch + 10
+            repeat_cond_model2 = repeat_cond_model.condition_on_observations(
+                repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model2.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
+            )
 
     def test_bisect(self):
         def f(x):
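Taken together, the tests exercise the new behavior end to end: unbatched conditioning data is now accepted and tiled across the MCMC model batch instead of raising a ValueError. A minimal usage sketch against the public API (hyperparameters and shapes are illustrative, and it assumes a BoTorch build that includes this commit):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])

model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
)
num_models = model.batch_shape[0]  # number of retained MCMC samples

model.posterior(train_X)  # evaluate once, as the test does before conditioning

# unbatched conditioning data: n x d and n x 1, no leading model-batch dim
cond_X = torch.rand(2, 4, dtype=torch.double)
cond_Y = torch.sin(cond_X[:, :1])

# previously raised ValueError; now the data is tiled across the model batch
cond_model = model.condition_on_observations(cond_X, cond_Y)
print(cond_model.train_inputs[0].shape)  # (num_models, 12, 4)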
