
Commit 0c37aac

hvarfner authored and facebook-github-bot committed
Fixed condition_on_observations in fully Bayesian models (#2151)
Summary:

## Motivation

Support conditioning on observations in fully Bayesian models, which enables fully Bayesian JES and (potentially) KG.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes.

Pull Request resolved: #2151

Test Plan:
Tests are written to ensure functionality for both inferred and fixed noise.

__Note that the `_aug_batch_shape` attribute assignment was removed in `condition_on_observations`.__ In fully Bayesian GPs, this attribute could not be assigned (hence the removal). I could not find a use for it, and all tests passed after removing it.

Other changes are commented throughout. They were made to ensure that fully Bayesian GPs can keep one set of training data throughout. However, conditioning on observations adds a batch dim to the training data, which is necessary in GPyTorch [here](https://github.com/cornellius-gp/gpytorch/blob/58c033564d28a5537397bc464827783313534e56/gpytorch/models/exact_gp.py#L176) to infer the correct batch dim.

Reviewed By: dme65

Differential Revision: D52256296

Pulled By: saitcakmak

fbshipit-source-id: e340897d76e02c32ef7a981bef8a77c49e030ad1
1 parent 967535f commit 0c37aac
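For context, a minimal sketch of the usage this change enables. The data here is hypothetical; `SaasFullyBayesianSingleTaskGP` and `fit_fully_bayesian_model_nuts` are existing BoTorch APIs, and the forward pass before conditioning follows the test plan below:

```python
import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models import SaasFullyBayesianSingleTaskGP

# Hypothetical training data: 10 points in a 4-dim feature space.
train_X = torch.rand(10, 4, dtype=torch.double)
train_Y = torch.sin(train_X[:, :1])

model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
# Small MCMC budgets just to keep the sketch fast.
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
)

# Per the test plan, run a forward pass before conditioning.
model.posterior(train_X)

# Unbatched new observations: after this change, the conditioned model ends up
# with one copy of the training data per MCMC sample (see the diff below).
new_X = torch.rand(2, 4, dtype=torch.double)
new_Y = torch.sin(new_X[:, :1])
cond_model = model.condition_on_observations(new_X, new_Y)
```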

3 files changed (+158, -6 lines changed)


botorch/models/fully_bayesian.py

Lines changed: 37 additions & 4 deletions
@@ -498,9 +498,8 @@ def forward(self, X: Tensor) -> MultivariateNormal:
         rest of this method will not run.
         """
         self._check_if_fitted()
-        x = X.unsqueeze(MCMC_DIM)
-        mean_x = self.mean_module(x)
-        covar_x = self.covar_module(x)
+        mean_x = self.mean_module(X)
+        covar_x = self.covar_module(X)
         return MultivariateNormal(mean_x, covar_x)

     # pyre-ignore[14]: Inconsistent override
@@ -534,11 +533,45 @@ def posterior(
         """
         self._check_if_fitted()
         posterior = super().posterior(
-            X=X,
+            X=X.unsqueeze(MCMC_DIM),
             output_indices=output_indices,
             observation_noise=observation_noise,
             posterior_transform=posterior_transform,
             **kwargs,
         )
         posterior = GaussianMixturePosterior(distribution=posterior.distribution)
         return posterior
+
+    def condition_on_observations(
+        self, X: Tensor, Y: Tensor, **kwargs: Any
+    ) -> BatchedMultiOutputGPyTorchModel:
+        """Conditions on additional observations for a fully Bayesian model
+        (either identical across models or unique per-model).
+
+        Args:
+            X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
+                the dimension of the feature space and `batch_shape` is the
+                number of sampled models.
+            Y: A `batch_shape x num_samples x 1`-dim Tensor, where `batch_shape`
+                is the number of sampled models.
+
+        Returns:
+            BatchedMultiOutputGPyTorchModel: A fully Bayesian model conditioned
+                on the given observations. The returned model has `batch_shape`
+                copies of the training data in the case of identical
+                observations (and `batch_shape` training datasets otherwise).
+        """
+        if X.ndim == 2 and Y.ndim == 2:
+            # To avoid an error in GPyTorch when inferring the batch dimension,
+            # we add the explicit batch shape here. The result is that the
+            # conditioned model will have `batch_shape` copies of the training
+            # data.
+            X = X.repeat(self.batch_shape + (1, 1))
+            Y = Y.repeat(self.batch_shape + (1, 1))
+
+        elif X.ndim < Y.ndim:
+            # We need to duplicate the training data to enable correct batch
+            # size inference in GPyTorch.
+            X = X.repeat(*(Y.shape[:-2] + (1, 1)))
+
+        return super().condition_on_observations(X, Y, **kwargs)
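To make the shape handling above concrete, a standalone sketch of the two branches. The shapes are illustrative, and `batch_shape` stands in for the model's number of MCMC samples:

```python
import torch

batch_shape = torch.Size([3])  # e.g., 3 MCMC-sampled models

# Case 1 (both unbatched, first branch): tile X and Y so the conditioned
# model gets one copy of the new data per sampled model.
X = torch.rand(2, 4)  # n x d
Y = torch.rand(2, 1)  # n x 1
X = X.repeat(batch_shape + (1, 1))
Y = Y.repeat(batch_shape + (1, 1))
print(X.shape, Y.shape)  # torch.Size([3, 2, 4]) torch.Size([3, 2, 1])

# Case 2 (X unbatched, Y batched, elif branch): tile only X to match Y's
# batch dimensions so GPyTorch can infer the batch size.
X2 = torch.rand(2, 4)
Y2 = torch.rand(3, 2, 1)
X2 = X2.repeat(*(Y2.shape[:-2] + (1, 1)))
print(X2.shape)  # torch.Size([3, 2, 4])
```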

botorch/models/gpytorch.py

Lines changed: 5 additions & 2 deletions
@@ -226,7 +226,8 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
         >>> new_Y = torch.sin(new_X[:, 0]) + torch.cos(new_X[:, 1])
         >>> model = model.condition_on_observations(X=new_X, Y=new_Y)
         """
-        Yvar = kwargs.get("noise", None)
+        Yvar = kwargs.pop("noise", None)
+
         if hasattr(self, "outcome_transform"):
             # pass the transformed data to get_fantasy_model below
             # (unless we've already transformed if BatchedMultiOutputGPyTorchModel)
@@ -242,6 +243,7 @@ def condition_on_observations(self, X: Tensor, Y: Tensor, **kwargs: Any) -> Mode
             kwargs.update({"noise": Yvar.squeeze(-1)})
         # get_fantasy_model will properly copy any existing outcome transforms
         # (since it deepcopies the original model)
+
         return self.get_fantasy_model(inputs=X, targets=Y, **kwargs)
@@ -492,7 +494,8 @@ def condition_on_observations(
         fantasy_model._input_batch_shape = fantasy_model.train_targets.shape[
             : (-1 if self._num_outputs == 1 else -2)
         ]
-        fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
+        if not self._is_fully_bayesian:
+            fantasy_model._aug_batch_shape = fantasy_model.train_targets.shape[:-1]
         return fantasy_model

     def subset_output(self, idcs: List[int]) -> BatchedMultiOutputGPyTorchModel:
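A brief note on the `kwargs.get` to `kwargs.pop` change above: popping removes the caller's raw `noise` tensor from `kwargs`, so only the processed (squeezed and, in the real method, outcome-transformed) value is re-inserted before `get_fantasy_model` is called. A toy sketch of the pattern; the helper name is illustrative, not BoTorch API:

```python
import torch

def _prepare_noise_kwargs(**kwargs):
    # Pop the raw value so it cannot be forwarded unprocessed;
    # re-insert only the processed version.
    yvar = kwargs.pop("noise", None)
    if yvar is not None:
        kwargs["noise"] = yvar.squeeze(-1)  # drop the trailing output dim
    return kwargs

kwargs = _prepare_noise_kwargs(noise=0.1 * torch.ones(5, 1))
print(kwargs["noise"].shape)  # torch.Size([5])
```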

test/models/test_fully_bayesian.py

Lines changed: 116 additions & 0 deletions
@@ -124,6 +124,18 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
         )
         return train_X, train_Y, train_Yvar, test_X

+    def _get_unnormalized_condition_data(
+        self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
+    ):
+        with torch.random.fork_rng():
+            torch.manual_seed(0)
+            cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
+            cond_Y = 10 + torch.sin(cond_X[..., :1])
+            cond_Yvar = (
+                None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
+            )
+        return cond_X, cond_Y, cond_Yvar
+
     def _get_mcmc_samples(
         self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
     ):
@@ -656,6 +668,110 @@ def test_custom_pyro_model(self):
             atol=5e-4,
         )

+    def test_condition_on_observation(self):
+        # The following conditioned data shapes should work (the output column
+        # is the training data shape after conditioning; a batch shape in the
+        # output is required by GPyTorch):
+        # X: num_models x n x d, Y: num_models x n x 1 --> num_models x n x d
+        # X: n x d, Y: n x 1 --> num_models x n x d
+        # X: n x d, Y: num_models x n x 1 --> num_models x n x d
+        num_models = 3
+        num_cond = 2
+        for infer_noise, dtype in itertools.product(
+            (True, False), (torch.float, torch.double)
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
+                infer_noise=infer_noise, **tkwargs
+            )
+            num_train, num_dims = train_X.shape
+            # condition on different observations per model to obtain
+            # num_models sets of training data
+            cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
+                num_models=num_models,
+                num_cond=num_cond,
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            # need to forward pass before conditioning
+            model.posterior(train_X)
+            cond_model = model.condition_on_observations(
+                cond_X, cond_Y, noise=cond_Yvar
+            )
+            posterior = cond_model.posterior(test_X)
+            self.assertEqual(
+                posterior.mean.shape, torch.Size([num_models, len(test_X), 1])
+            )
+
+            # since the data is not equal for the conditioned points, a batch
+            # size is added to the training data
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+
+            # the batch shape of the conditioned model is added during
+            # conditioning
+            self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))
+
+            # condition on identical sets of data (i.e. one set) for all
+            # models, i.e., with no batch shape. This infers the batch shape.
+            cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            # conditioning without a batch size - the resulting conditioned
+            # model will still have a batch size
+            model.posterior(train_X)
+            cond_model = model.condition_on_observations(
+                cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+
+            # test repeated conditioning
+            repeat_cond_X = cond_X + 5
+            repeat_cond_model = cond_model.condition_on_observations(
+                repeat_cond_X, cond_Y, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 2 * num_cond, num_dims]),
+            )
+
+            # test repeated conditioning without a batch size
+            repeat_cond_X_nobatch = cond_X_nobatch + 10
+            repeat_cond_model2 = repeat_cond_model.condition_on_observations(
+                repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model2.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
+            )
+
     def test_bisect(self):
         def f(x):
             return 1 + x
