
Commit f61c430

Fixed condition_on_observations in fully Bayesian models
1 parent 8f1df5a commit f61c430

File tree

3 files changed: +121 -5 lines changed

botorch/models/fully_bayesian.py

Lines changed: 32 additions & 4 deletions
@@ -498,9 +498,8 @@ def forward(self, X: Tensor) -> MultivariateNormal:
         rest of this method will not run.
         """
         self._check_if_fitted()
-        x = X.unsqueeze(MCMC_DIM)
-        mean_x = self.mean_module(x)
-        covar_x = self.covar_module(x)
+        mean_x = self.mean_module(X)
+        covar_x = self.covar_module(X)
         return MultivariateNormal(mean_x, covar_x)

     # pyre-ignore[14]: Inconsistent override
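For orientation, MCMC_DIM marks the model-batch position (defined as -3 in this module); the hunk above removes the unsqueeze from forward because posterior (below) now batches the inputs before they reach forward. A minimal sketch of the shape effect, with illustrative sizes:

import torch

MCMC_DIM = -3  # model-batch dimension, as defined in botorch/models/fully_bayesian.py

X = torch.rand(5, 4)              # q x d test inputs (illustrative sizes)
x = X.unsqueeze(MCMC_DIM)         # 1 x q x d: broadcasts across the sampled models
assert x.shape == torch.Size([1, 5, 4])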
@@ -534,11 +533,40 @@ def posterior(
         """
         self._check_if_fitted()
         posterior = super().posterior(
-            X=X,
+            X=X.unsqueeze(MCMC_DIM),
             output_indices=output_indices,
             observation_noise=observation_noise,
             posterior_transform=posterior_transform,
             **kwargs,
         )
         posterior = GaussianMixturePosterior(distribution=posterior.distribution)
         return posterior
+
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> BatchedMultiOutputGPyTorchModel:
+        """Conditions on additional observations for a fully Bayesian model
+        (either identical across models or unique per-model).
+
+        Args:
+            X: A `(batch_shape) x num_samples x d`-dim Tensor, where `d` is the
+                dimension of the feature space and `batch_shape` is the number
+                of sampled models.
+            Y: A `(batch_shape) x num_samples x 1`-dim Tensor of observed
+                outcomes, where `batch_shape` is the number of sampled models.
+
+        Returns:
+            BatchedMultiOutputGPyTorchModel: A fully Bayesian model conditioned
+                on the additional observations.
+        """
+        if X.ndim < 3 or Y.ndim < 3:
+            # This error can either be raised here or in GPyTorch, when inference
+            # of the batch dimension fails since the training data by default
+            # does not have a batch shape.
+            raise ValueError(
+                "Conditioning in fully Bayesian models must contain a batch "
+                "dimension. Add a batch dimension (the leading dim) with length "
+                "matching the number of hyperparameter sets to the conditioned data."
+            )
+
+        return super().condition_on_observations(X, Y, **kwargs)
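For context, a minimal usage sketch of the new method under assumed shapes (the fit call and all sizes are illustrative, not part of this commit; the per-model conditioning pattern mirrors the test added below):

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 4, dtype=torch.double)  # assumed: 10 points, d=4
train_Y = torch.sin(train_X[..., :1])
model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(model)  # draws the hyperparameter samples

num_models = model.batch_shape[0]  # one GP per retained MCMC sample
model.posterior(train_X)           # forward pass before conditioning

# the conditioning data needs a leading batch dim of length num_models,
# here two new observations per sampled model
cond_X = torch.rand(num_models, 2, 4, dtype=torch.double)
cond_Y = torch.sin(cond_X[..., :1])
cond_model = model.condition_on_observations(cond_X, cond_Y)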

botorch/models/gpytorch.py

Lines changed: 8 additions & 1 deletion
@@ -224,6 +224,12 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
         Yvar = kwargs.get("noise", None)
+
+        # for fully Bayesian models, the keyword argument "noise": None
+        # throws an error in LinearOperator related to inferring batch dims
+        if "noise" in kwargs and kwargs["noise"] is None:
+            del kwargs["noise"]
+
         if hasattr(self, "outcome_transform"):
             # pass the transformed data to get_fantasy_model below
             # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
@@ -239,6 +245,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Model:
             kwargs.update({"noise": Yvar.squeeze(-1)})
         # get_fantasy_model will properly copy any existing outcome transforms
         # (since it deepcopies the original model)
+
         return self.get_fantasy_model(inputs=X, targets=Y, **kwargs)

@@ -489,7 +496,7 @@ def condition_on_observations(
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
-        fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
+
         return fantasy_model

     def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
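The stripped kwarg matters because callers may pass noise=None unconditionally in the inferred-noise case (the test below does so via noise=cond_Yvar). A sketch of the two call patterns this enables, continuing the usage sketch above:

# inferred noise: cond_Yvar is None, and the "noise" kwarg is now dropped
# before it reaches get_fantasy_model
cond_model = model.condition_on_observations(cond_X, cond_Y, noise=None)

# observed noise: the variance tensor is forwarded as before
cond_model = model.condition_on_observations(
    cond_X, cond_Y, noise=0.1 * torch.ones_like(cond_Y)
)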

test/models/test_fully_bayesian.py

Lines changed: 81 additions & 0 deletions
@@ -124,6 +124,16 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
         )
         return train_X, train_Y, train_Yvar, test_X

+    def _get_unnormalized_condition_data(
+        self, num_models: int, infer_noise: bool, **tkwargs
+    ):
+        with torch.random.fork_rng():
+            torch.manual_seed(0)
+            cond_X = 5 + 5 * torch.rand(num_models, 2, 4, **tkwargs)
+            cond_Y = 10 + torch.sin(cond_X[..., :1])
+            cond_Yvar = None if infer_noise else 0.1 * torch.ones(cond_Y.shape)
+        return cond_X, cond_Y, cond_Yvar
+
     def _get_mcmc_samples(
         self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
     ):
@@ -656,6 +666,77 @@ def test_custom_pyro_model(self):
                 atol=5e-4,
             )

+    def test_condition_on_observation(self):
+        num_models = 3
+        num_cond = 2
+        for infer_noise, dtype in itertools.product(
+            (True,), (torch.float, torch.double)
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
+                infer_noise=infer_noise, **tkwargs
+            )
+            num_train, num_dims = train_X.shape
+            # condition on different observations per model to obtain num_models
+            # sets of training data
+            cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
+                num_models=num_models, infer_noise=infer_noise, **tkwargs
+            )
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            # need to do a forward pass before conditioning
+            model.posterior(train_X)
+            cond_model = model.condition_on_observations(
+                cond_X, cond_Y, noise=cond_Yvar
+            )
+            posterior = cond_model.posterior(test_X)
+            self.assertEqual(
+                posterior.mean.shape, torch.Size([num_models, len(test_X), 1])
+            )
+
+            # since the conditioning data differs per model, a batch dimension
+            # is added to the training data
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+            # condition on an identical set of data (i.e. one set) for all
+            # models, i.e. with no batch shape. This should not work.
+            cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            # This should __NOT__ work - the conditioning data must have a batch
+            # dimension, since the training data by default does not have a
+            # batch shape from which one could be inferred
+            model.posterior(train_X)
+            with self.assertRaises(ValueError):
+                model.condition_on_observations(
+                    cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+                )
+
     def test_bisect(self):
         def f(x):
             return 1 + x