77 changes: 51 additions & 26 deletions botorch/acquisition/bayesian_active_learning.py
@@ -22,10 +22,11 @@

from typing import Optional

import torch
from botorch.acquisition.acquisition import AcquisitionFunction, MCSamplerMixin
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.acquisition.objective import PosteriorTransform
from botorch.models.fully_bayesian import MCMC_DIM, SaasFullyBayesianSingleTaskGP
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor

@@ -54,48 +55,72 @@ class qBayesianActiveLearningByDisagreement(
def __init__(
self,
model: SaasFullyBayesianSingleTaskGP,
sampler: Optional[MCSampler] = None,
posterior_transform: Optional[PosteriorTransform] = None,
X_pending: Optional[Tensor] = None,
) -> None:
"""
Batch implementation [kirsch2019batchbald]_ of BALD [Houlsby2011bald]_,
which maximizes the mutual information between the next observation and the
hyperparameters of the model. Computed by informational lower bound.
hyperparameters of the model. Computed by Monte Carlo integration.

Args:
model: A fully bayesian single-outcome model.
X_pending: A `batch_shape, m x d`-dim Tensor of `m` design points.
model: A fully Bayesian model (SaasFullyBayesianSingleTaskGP).
sampler: The sampler used for drawing samples to approximate the entropy
of the Gaussian Mixture posterior.
posterior_transform: A PosteriorTransform. If using a multi-output model,
a PosteriorTransform that transforms the multi-output posterior into a
single-output posterior is required.
X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points.

"""
super().__init__(model)
super().__init__(model=model)
MCSamplerMixin.__init__(self, sampler=sampler)
self.set_X_pending(X_pending)
self.posterior_transform = posterior_transform

@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate qBayesianActiveLearningByDisagreement on the candidate set `X`.
A Monte Carlo-estimated information gain is computed over a Gaussian mixture
marginal posterior and the Gaussian conditional posteriors of the individual
models in the mixture.

Args:
X: `batch_shape x q x D`-dim Tensor of input points.

Returns:
A `batch_shape x num_models`-dim Tensor of BALD values.
"""
return self._compute_lower_bound_information_gain(X)

def _compute_lower_bound_information_gain(self, X: Tensor) -> Tensor:
r"""Evaluates the lower bounded information gain on the candidate set `X`.

Args:
X: `batch_shape x q x D`-dim Tensor of input points.

Returns:
A `batch_shape x num_models`-dim Tensor of information gains.
"""
posterior = self.model.posterior(X, observation_noise=True)
marg_covar = posterior.mixture_covariance_matrix
cond_variances = posterior.variance

prev_entropy = torch.logdet(marg_covar).unsqueeze(-1)
# squeeze excess dim and mean over q-batch
post_ub_entropy = torch.log(cond_variances).squeeze(-1).mean(-1)

return prev_entropy - post_ub_entropy
posterior = self.model.posterior(
X, observation_noise=True, posterior_transform=self.posterior_transform
)
# draw samples from the mixture posterior.
# samples: num_samples x batch_shape x num_models x q x num_outputs
samples = self.get_posterior_samples(posterior=posterior)

# Estimate the entropy of `num_samples` samples from `num_models` models by
# evaluating the log_prob of each sample under the mixture posterior
# (which consists of `num_models` models), i.e. on the order of N * M^2
# computations, where N = num_samples and M = num_models.

# Make room and move the model dim to the front, squeeze the num_outputs dim.
# prev_samples: num_models x num_samples x batch_shape x 1 x q
prev_samples = samples.unsqueeze(0).transpose(0, MCMC_DIM).squeeze(-1)

# Average the probabilities over the models in the mixture: dim (-2) is
# broadcast against the num_models dim of the posterior, so every sample is
# evaluated under every model. posterior.mvn takes q-dimensional input by
# default, which removes the q-dim.
# component_sample_probs: num_models x num_samples x batch_shape x num_models
component_sample_probs = posterior.mvn.log_prob(prev_samples).exp()

# average over mixture components
mixture_sample_probs = component_sample_probs.mean(dim=-1)

# average over the model and sample dims
prev_entropy = -mixture_sample_probs.log().mean(dim=[0, 1])

# the posterior entropy is an average entropy over Gaussians, so no mixture term
post_entropy = -posterior.mvn.log_prob(samples.squeeze(-1)).mean(0)
bald = prev_entropy.unsqueeze(-1) - post_entropy
return bald
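
For intuition, here is a minimal, self-contained sketch of the same estimator using plain torch Normals (illustrative only, not part of this PR): BALD at a candidate point is the MC-estimated entropy of the Gaussian mixture over models minus the average entropy of the individual Gaussians, which is exactly where the N * M^2 log_prob evaluations come from.

import torch
from torch.distributions import Normal

torch.manual_seed(0)
M, N = 3, 256  # num_models (mixture components), num_samples
means = torch.randn(M)      # per-model predictive means at one candidate point
stds = 0.5 + torch.rand(M)  # per-model predictive stddevs (incl. observation noise)
components = Normal(means, stds)

samples = components.sample((N,))  # N x M: N draws from each component

# log_prob of every sample under every component via broadcasting: N x M x M
component_probs = components.log_prob(samples.unsqueeze(-1)).exp()
mixture_probs = component_probs.mean(dim=-1)  # average over mixture components

prev_entropy = -mixture_probs.log().mean()           # MC entropy of the mixture
post_entropy = -components.log_prob(samples).mean()  # avg entropy of the Gaussians
bald = prev_entropy - post_entropy  # mutual information estimate (>= 0 in expectation)
print(bald)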
4 changes: 4 additions & 0 deletions botorch/acquisition/input_constructors.py
@@ -1678,9 +1678,13 @@ def construct_inputs_qJES(
def construct_inputs_BALD(
model: Model,
X_pending: Optional[Tensor] = None,
sampler: Optional[MCSampler] = None,
posterior_transform: Optional[PosteriorTransform] = None,
):
inputs = {
"model": model,
"X_pending": X_pending,
"sampler": sampler,
"posterior_transform": posterior_transform,
}
return inputs
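
For context on how this constructor is consumed, a hedged sketch follows; `model` stands in for an already-fitted fully Bayesian model, and the assumption is that construct_inputs_BALD is registered for the BALD class via this module's usual decorator convention.

from botorch.acquisition.bayesian_active_learning import (
    qBayesianActiveLearningByDisagreement,
)
from botorch.acquisition.input_constructors import get_acqf_input_constructor

# Look up the registered constructor and build the acqf from keyword inputs.
constructor = get_acqf_input_constructor(qBayesianActiveLearningByDisagreement)
inputs = constructor(model=model, X_pending=None)  # `model`: fitted SAAS GP (assumed)
acqf = qBayesianActiveLearningByDisagreement(**inputs)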
88 changes: 62 additions & 26 deletions test/acquisition/test_bayesian_active_learning.py
@@ -13,9 +13,32 @@
from botorch.models import SingleTaskGP
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.transforms.outcome import Standardize
from botorch.sampling.normal import IIDNormalSampler
from botorch.utils.testing import BotorchTestCase


def get_model(
train_X,
train_Y,
standardize_model,
**tkwargs,
):
num_objectives = train_Y.shape[-1]

if standardize_model:
outcome_transform = Standardize(m=num_objectives)
else:
outcome_transform = None

model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
outcome_transform=outcome_transform,
)

return model


def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):

mcmc_samples = {
@@ -28,7 +51,7 @@ def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
return mcmc_samples


def get_model(
def get_fully_bayesian_model(
train_X,
train_Y,
num_models,
@@ -72,21 +95,26 @@ def test_q_bayesian_active_learning_by_disagreement(self):
tkwargs = {"device": self.device}
num_objectives = 1
num_models = 3
input_dim = 2

X_pending_list = [None, torch.rand(2, input_dim)]
for (
dtype,
standardize_model,
infer_noise,
X_pending,
) in product(
(torch.float, torch.double),
(False, True), # standardize_model
(True,), # infer_noise - only one option avail in PyroModels
X_pending_list,
):
X_pending = X_pending.to(**tkwargs) if X_pending is not None else None
tkwargs["dtype"] = dtype
input_dim = 2
train_X = torch.rand(4, input_dim, **tkwargs)
train_Y = torch.rand(4, num_objectives, **tkwargs)

model = get_model(
model = get_fully_bayesian_model(
train_X,
train_Y,
num_models,
@@ -96,32 +124,40 @@
)

# test acquisition
X_pending_list = [None, torch.rand(2, input_dim, **tkwargs)]
for i in range(len(X_pending_list)):
X_pending = X_pending_list[i]

acq = qBayesianActiveLearningByDisagreement(
model=model,
X_pending=X_pending,
)

test_Xs = [
torch.rand(4, 1, input_dim, **tkwargs),
torch.rand(4, 3, input_dim, **tkwargs),
torch.rand(4, 5, 1, input_dim, **tkwargs),
torch.rand(4, 5, 3, input_dim, **tkwargs),
]

for j in range(len(test_Xs)):
acq_X = acq.forward(test_Xs[j])
acq_X = acq(test_Xs[j])
# assess shape
self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])
acq = qBayesianActiveLearningByDisagreement(
model=model,
X_pending=X_pending,
)

acq2 = qBayesianActiveLearningByDisagreement(
model=model, sampler=IIDNormalSampler(torch.Size([9]))
)
self.assertIsInstance(acq2.sampler, IIDNormalSampler)

test_Xs = [
torch.rand(4, 1, input_dim, **tkwargs),
torch.rand(4, 3, input_dim, **tkwargs),
torch.rand(4, 5, 1, input_dim, **tkwargs),
torch.rand(4, 5, 3, input_dim, **tkwargs),
torch.rand(5, 13, input_dim, **tkwargs),
]

for j in range(len(test_Xs)):
acq_X = acq.forward(test_Xs[j])
acq_X = acq(test_Xs[j])
# assess shape
self.assertTrue(acq_X.shape == test_Xs[j].shape[:-2])

self.assertTrue(torch.all(acq_X > 0))

# Non-fully Bayesian models are not supported, so a ValueError is raised.
non_fully_bayesian_model = SingleTaskGP(train_X, train_Y)
with self.assertRaises(ValueError):
non_fully_bayesian_model = get_model(train_X, train_Y, False)
with self.assertRaisesRegex(
ValueError,
"Fully Bayesian acquisition functions require a "
"SaasFullyBayesianSingleTaskGP to run.",
):
acq = qBayesianActiveLearningByDisagreement(
model=non_fully_bayesian_model,
)
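
Finally, a hedged end-to-end sketch of the intended usage (assumed setup, not taken from this PR): fit a SAAS fully Bayesian GP with NUTS, then score a t-batch of candidates with the new estimator. The MCMC budget is deliberately tiny to keep the sketch fast.

import torch
from botorch.acquisition.bayesian_active_learning import (
    qBayesianActiveLearningByDisagreement,
)
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 1, dtype=torch.double)

model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
# Small MCMC budget for illustration; real use needs many more samples.
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=32, thinning=16, disable_progbar=True
)

acqf = qBayesianActiveLearningByDisagreement(model=model)
candidates = torch.rand(5, 1, 2, dtype=torch.double)  # batch_shape=5, q=1, d=2
print(acqf(candidates))  # per-candidate (and per-model) information gain estimates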