Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 57 additions & 41 deletions botorch/models/fully_bayesian_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
reshape_and_detach,
SaasPyroModel,
)
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import MaternKernel
from gpytorch.kernels.index_kernel import IndexKernel
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
from gpytorch.means.mean import Mean
Expand Down Expand Up @@ -132,7 +134,7 @@ def sample_task_lengthscale(

def load_mcmc_samples(
self, mcmc_samples: dict[str, Tensor]
) -> tuple[Mean, Kernel, Likelihood, Kernel, Parameter]:
) -> tuple[Mean, Kernel, Likelihood, Kernel]:
r"""Load the MCMC samples into the mean_module, covar_module, and likelihood."""
tkwargs = {"device": self.train_X.device, "dtype": self.train_X.dtype}
num_mcmc_samples = len(mcmc_samples["mean"])
Expand All @@ -142,27 +144,32 @@ def load_mcmc_samples(
mcmc_samples=mcmc_samples
)

task_covar_module = MaternKernel(
latent_covar_module = MaternKernel(
nu=2.5,
ard_num_dims=self.task_rank,
batch_shape=batch_shape,
).to(**tkwargs)
task_covar_module.lengthscale = reshape_and_detach(
target=task_covar_module.lengthscale,
latent_covar_module.lengthscale = reshape_and_detach(
target=latent_covar_module.lengthscale,
new_value=mcmc_samples["task_lengthscale"],
)
latent_features = Parameter(
torch.rand(
batch_shape + torch.Size([self.num_tasks, self.task_rank]),
requires_grad=True,
**tkwargs,
)
latent_features = mcmc_samples["latent_features"]
task_covar = latent_covar_module(latent_features)
task_covar_module = IndexKernel(
num_tasks=self.num_tasks,
rank=self.task_rank,
batch_shape=latent_features.shape[:-2],
)
latent_features = reshape_and_detach(
target=latent_features,
new_value=mcmc_samples["latent_features"],
task_covar_module.covar_factor = Parameter(
task_covar.cholesky().to_dense().detach()
)
return mean_module, covar_module, likelihood, task_covar_module, latent_features

# NOTE: 'var' is implicitly assumed to be zero from the sampling procedure in
# the FBMTGP model but not in the regular MTGP. I dont how if the var parameter
# affects predictions in practice, but setting it to zero is consistent with the
# previous implementation.
task_covar_module.var = torch.zeros_like(task_covar_module.var)
return mean_module, covar_module, likelihood, task_covar_module


class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
Expand Down Expand Up @@ -361,7 +368,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
self.covar_module,
self.likelihood,
self.task_covar_module,
self.latent_features,
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

def posterior(
Expand Down Expand Up @@ -391,30 +397,7 @@ def posterior(

def forward(self, X: Tensor) -> MultivariateNormal:
self._check_if_fitted()
x_basic, task_idcs = self._split_inputs(X)

mean_x = self.mean_module(x_basic)
covar_x = self.covar_module(x_basic)

tsub_idcs = task_idcs.squeeze(-1)
if tsub_idcs.ndim > 1:
tsub_idcs = tsub_idcs.squeeze(-2)
latent_features = self.latent_features[:, tsub_idcs, :]

if X.ndim > 3:
# batch eval mode
# for X (batch_shape x num_samples x q x d), task_idcs[:,i,:,] are the same
# reshape X to (batch_shape x num_samples x q x d)
latent_features = latent_features.permute(
[-i for i in range(X.ndim - 1, 2, -1)]
+ [0]
+ [-i for i in range(2, 0, -1)]
)

# Combine the two in an ICM fashion
covar_i = self.task_covar_module(latent_features)
covar = covar_x.mul(covar_i)
return MultivariateNormal(mean_x, covar)
return super().forward(X)

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
r"""Custom logic for loading the state dict.
Expand Down Expand Up @@ -456,7 +439,40 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.covar_module,
self.likelihood,
self.task_covar_module,
self.latent_features,
) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
# Load the actual samples from the state dict
super().load_state_dict(state_dict=state_dict, strict=strict)

def condition_on_observations(
self, X: Tensor, Y: Tensor, **kwargs: Any
) -> BatchedMultiOutputGPyTorchModel:
"""Conditions on additional observations for a Fully Bayesian model (either
identical across models or unique per-model).

Args:
X: A `batch_shape x num_samples x d`-dim Tensor, where `d` is
the dimension of the feature space and `batch_shape` is the number of
sampled models.
Y: A `batch_shape x num_samples x 1`-dim Tensor, where `d` is
the dimension of the feature space and `batch_shape` is the number of
sampled models.

Returns:
BatchedMultiOutputGPyTorchModel: A fully bayesian model conditioned on
given observations. The returned model has `batch_shape` copies of the
training data in case of identical observations (and `batch_shape`
training datasets otherwise).
"""
if X.ndim == 2 and Y.ndim == 2:
# To avoid an error in GPyTorch when inferring the batch dimension, we add
# the explicit batch shape here. The result is that the conditioned model
# will have 'batch_shape' copies of the training data.
X = X.repeat(self.batch_shape + (1, 1))
Y = Y.repeat(self.batch_shape + (1, 1))

elif X.ndim < Y.ndim:
# We need to duplicate the training data to enable correct batch
# size inference in gpytorch.
X = X.repeat(*(Y.shape[:-2] + (1, 1)))

return super().condition_on_observations(X, Y, **kwargs)
2 changes: 0 additions & 2 deletions botorch/models/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,6 @@ def _apply_noise(
self,
X: Tensor,
mvn: MultivariateNormal,
num_outputs: int,
observation_noise: bool | Tensor,
) -> MultivariateNormal:
"""Adds the observation noise to the posterior.
Expand Down Expand Up @@ -948,7 +947,6 @@ def posterior(
mvn = self._apply_noise(
X=X_full,
mvn=mvn,
num_outputs=num_outputs,
observation_noise=observation_noise,
)
# If single-output, return the posterior of a single-output model
Expand Down
Loading