Some clean-up for MES-based acquisition functions #2769

Closed · wants to merge 1 commit
191 changes: 64 additions & 127 deletions botorch/acquisition/max_value_entropy_search.py
@@ -58,54 +58,69 @@


class MaxValueBase(AcquisitionFunction, ABC):
r"""Abstract base class for acquisition functions based on Max-value Entropy Search.
r"""Abstract base class for acquisition functions based on Max-value Entropy Search,
using discrete max posterior sampling.

This class provides the basic building blocks for constructing max-value
entropy-based acquisition functions along the lines of [Wang2017mves]_.
It provides basic functionality for sampling posterior maximum values from
a surrogate Gaussian process model using a discrete set of candidates. It supports
either exact (w.r.t. the candidate set) sampling, or using a Gumbel approximation.

Subclasses need to implement `_sample_max_values` and `_compute_information_gain`
methods.
Subclasses must implement `_compute_information_gain`.
"""

def __init__(
self,
model: Model,
num_mv_samples: int,
candidate_set: Tensor,
num_mv_samples: int = 10,
posterior_transform: PosteriorTransform | None = None,
use_gumbel: bool = True,
maximize: bool = True,
X_pending: Tensor | None = None,
train_inputs: Tensor | None = None,
) -> None:
r"""Single-outcome max-value entropy search-based acquisition functions.
r"""Single-outcome max-value entropy search-based acquisition functions
based on discrete MV sampling.

Args:
model: A fitted single-outcome model.
candidate_set: A `n x d` Tensor including `n` candidate points to
discretize the design space. Max values are sampled from the
(joint) model posterior over these points.
num_mv_samples: Number of max value samples.
posterior_transform: A PosteriorTransform. If using a multi-output model,
a PosteriorTransform that transforms the multi-output posterior into a
single-output posterior is required.
use_gumbel: If True, use Gumbel approximation to sample the max values.
maximize: If True, consider the problem a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for function evaluation but have not yet been evaluated.
train_inputs: A `n_train x d` Tensor that the model has been fitted on.
Not required if the model is an instance of a GPyTorch ExactGP model.
"""
super().__init__(model=model)

if posterior_transform is None and model.num_outputs != 1:
if model.num_outputs > 1:
raise UnsupportedError(
"Must specify a posterior transform when using a multi-output model."
f"Multi-output models are not supported by {self.__class__.__name__}."
)
if train_inputs is None and hasattr(model, "train_inputs"):
train_inputs = model.train_inputs[0]
if train_inputs is not None:
if train_inputs.ndim > 2:
raise NotImplementedError(
"Batched GP models (e.g., fantasized models) are not yet "
f"supported by `{self.__class__.__name__}`."
)
train_inputs = match_batch_shape(train_inputs, candidate_set)
candidate_set = torch.cat([candidate_set, train_inputs], dim=0)

# Batched GP models are not currently supported
try:
batch_shape = model.batch_shape
except NotImplementedError:
batch_shape = torch.Size()
if len(batch_shape) > 0:
raise NotImplementedError(
"Batched GP models (e.g., fantasized models) are not yet "
f"supported by `{self.__class__.__name__}`."
)
self.candidate_set = candidate_set
self.num_mv_samples = num_mv_samples
self.posterior_transform = posterior_transform
self.use_gumbel = use_gumbel
self.maximize = maximize
self.weight = 1.0 if maximize else -1.0
self.set_X_pending(X_pending)
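With the two base classes merged, all discrete-sampling MES variants construct through the signature above. A minimal usage sketch via `qMaxValueEntropy` (illustrative only: `train_X`/`train_Y` are made-up data, and `num_mv_samples=10` / `use_gumbel=True` are the defaults in this diff):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.randn(20, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

# Discretize the design space; max values are sampled from the joint
# posterior over these points (train inputs are appended automatically
# by the new __init__).
candidate_set = torch.rand(1000, 2, dtype=torch.double)
mes = qMaxValueEntropy(model=model, candidate_set=candidate_set)
value = mes(torch.rand(4, 1, 2, dtype=torch.double))  # one point per t-batch
```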
@@ -151,106 +166,6 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
self._sample_max_values(num_samples=self.num_mv_samples, X_pending=X_pending)
self.X_pending = X_pending

# ------- Abstract methods that need to be implemented by subclasses ------- #

@abstractmethod
def _compute_information_gain(self, X: Tensor) -> Tensor:
r"""Compute the information gain at the design points `X`.

`num_fantasies = 1` for non-fantasized models.

Args:
X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
with `1` `d`-dim design point each.

Returns:
A `num_fantasies x batch_shape`-dim Tensor of information gains at the
given design points `X` (`num_fantasies=1` for non-fantasized models).
"""
pass # pragma: no cover

@abstractmethod
def _sample_max_values(
self, num_samples: int, X_pending: Tensor | None = None
) -> None:
r"""Draw samples from the posterior over maximum values.

These samples are used to compute Monte Carlo approximations of expectations
over the posterior over the function maximum. This function sets
`self.posterior_max_values`.

Args:
num_samples: The number of samples to draw.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for function evaluation but have not yet been evaluated.

Returns:
A `num_samples x num_fantasies` Tensor of posterior max value samples
(`num_fantasies=1` for non-fantasized models).
"""
pass # pragma: no cover
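This PR moves these abstract hooks below `_sample_max_values` (see further down) and drops `_sample_max_values` from the required overrides, so only `_compute_information_gain` remains abstract. A toy sketch of the resulting subclass contract (`NegVarianceMES` is hypothetical and scores points by posterior variance, not a real information gain):

```python
from torch import Tensor
from botorch.acquisition.max_value_entropy_search import MaxValueBase


class NegVarianceMES(MaxValueBase):
    def _compute_information_gain(self, X: Tensor) -> Tensor:
        posterior = self.model.posterior(
            X, posterior_transform=self.posterior_transform
        )
        # Return a `num_fantasies x batch_shape` tensor (num_fantasies = 1).
        return posterior.variance.squeeze(-1).squeeze(-1).unsqueeze(0)
```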


class DiscreteMaxValueBase(MaxValueBase):
r"""Abstract base class for MES-like methods using discrete max posterior sampling.

This class provides basic functionality for sampling posterior maximum values from
a surrogate Gaussian process model using a discrete set of candidates. It supports
either exact (w.r.t. the candidate set) sampling, or using a Gumbel approximation.
"""

def __init__(
self,
model: Model,
candidate_set: Tensor,
num_mv_samples: int = 10,
posterior_transform: PosteriorTransform | None = None,
use_gumbel: bool = True,
maximize: bool = True,
X_pending: Tensor | None = None,
train_inputs: Tensor | None = None,
) -> None:
r"""Single-outcome MES-like acquisition functions based on discrete MV sampling.

Args:
model: A fitted single-outcome model.
candidate_set: A `n x d` Tensor including `n` candidate points to
discretize the design space. Max values are sampled from the
(joint) model posterior over these points.
num_mv_samples: Number of max value samples.
posterior_transform: A PosteriorTransform. If using a multi-output model,
a PosteriorTransform that transforms the multi-output posterior into a
single-output posterior is required.
use_gumbel: If True, use Gumbel approximation to sample the max values.
maximize: If True, consider the problem a maximization problem.
X_pending: A `m x d`-dim Tensor of `m` design points that have been
submitted for function evaluation but have not yet been evaluated.
train_inputs: A `n_train x d` Tensor that the model has been fitted on.
Not required if the model is an instance of a GPyTorch ExactGP model.
"""
self.use_gumbel = use_gumbel

if train_inputs is None and hasattr(model, "train_inputs"):
train_inputs = model.train_inputs[0]
if train_inputs is not None:
if train_inputs.ndim > 2:
raise NotImplementedError(
"Batch GP models (e.g. fantasized models) "
"are not yet supported by `MaxValueBase`"
)
train_inputs = match_batch_shape(train_inputs, candidate_set)
candidate_set = torch.cat([candidate_set, train_inputs], dim=0)

self.candidate_set = candidate_set

super().__init__(
model=model,
num_mv_samples=num_mv_samples,
posterior_transform=posterior_transform,
maximize=maximize,
X_pending=X_pending,
)

def _sample_max_values(
self, num_samples: int, X_pending: Tensor | None = None
) -> None:
@@ -291,13 +206,30 @@ def _sample_max_values(
self.posterior_max_values = sample_max_values(
model=self.model,
candidate_set=candidate_set,
num_samples=self.num_mv_samples,
num_samples=num_samples,
posterior_transform=self.posterior_transform,
maximize=self.maximize,
)
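When `use_gumbel=True`, the distribution of the maximum over the candidate set is approximated by a Gumbel distribution fitted via quantile matching. A rough standalone sketch of that trick (generic, not the exact `sample_max_values` implementation; `mu`/`sigma` are the posterior means and standard deviations at the candidate points):

```python
import numpy as np
import torch
from scipy.optimize import brentq
from scipy.stats import norm


def gumbel_max_value_samples(mu, sigma, num_samples: int) -> torch.Tensor:
    # P(max_i f_i <= y) ~= prod_i Phi((y - mu_i) / sigma_i); match the Gumbel
    # CDF exp(-exp(-(y - a) / b)) to it at the 25% / 50% / 75% quantiles.
    N = norm(mu, sigma)
    lo = float(mu.min() - 5.0 * sigma.max())
    hi = float(mu.max() + 5.0 * sigma.max())
    q25, q50, q75 = (
        brentq(lambda y: np.exp(np.sum(N.logcdf(y))) - p, lo, hi)
        for p in (0.25, 0.5, 0.75)
    )
    b = (q75 - q25) / (np.log(np.log(4.0)) - np.log(np.log(4.0 / 3.0)))
    a = q50 + b * np.log(np.log(2.0))
    u = torch.rand(num_samples, dtype=torch.double)
    return a - b * torch.log(-torch.log(u))  # inverse-CDF Gumbel samples
```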

# ------- Abstract methods that need to be implemented by subclasses ------- #

@abstractmethod
def _compute_information_gain(self, X: Tensor) -> Tensor:
r"""Compute the information gain at the design points `X`.

`num_fantasies = 1` for non-fantasized models.

Args:
X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
with `1` `d`-dim design point each.

Returns:
A `num_fantasies x batch_shape`-dim Tensor of information gains at the
given design points `X` (`num_fantasies=1` for non-fantasized models).
"""


class qMaxValueEntropy(DiscreteMaxValueBase, MCSamplerMixin):
class qMaxValueEntropy(MaxValueBase, MCSamplerMixin):
r"""The acquisition function for Max-value Entropy Search.

This acquisition function computes the mutual information of max values and
@@ -432,13 +364,14 @@ def _compute_information_gain(
)
# batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
mean_m = self.weight * posterior_m.mean.squeeze(-1)
# batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
# batch_shape x num_fantasies x (m)
# x (1 + num_trace_observations) x (1 + num_trace_observations)
variance_m = posterior_m.distribution.covariance_matrix
check_no_nans(variance_m)

# compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
samples_m = self.weight * self.get_posterior_samples(posterior_m).squeeze(-1)
# s_m x batch_shape x num_fantasies x (m) (1 + num_trace_observations)
# s_m x batch_shape x num_fantasies x (m) x (1 + num_trace) x (1 + num_trace)
L = psd_safe_cholesky(variance_m)
temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
# equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
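The Cholesky-based solve above is the standard Gaussian conditioning identity, sketched here in isolation (generic shapes, not the exact batched tensors used in the diff): for jointly Gaussian `(f_M, y_m)`, `E[f_M | y_m] = mu_M + k_mM^T K_mm^{-1} (y_m - mu_m)` and `Var[f_M | y_m] = sigma_M^2 - k_mM^T K_mm^{-1} k_mM`.

```python
import torch
from linear_operator.utils.cholesky import psd_safe_cholesky


def condition_on_observation(mu_M, var_M, mu_m, K_mm, k_mM, y_m):
    # K_mm = L L^T; cholesky_solve computes K_mm^{-1} k_mM without forming
    # the inverse (mirrors the `temp_term` computation above).
    L = psd_safe_cholesky(K_mm)
    alpha = torch.cholesky_solve(k_mM.unsqueeze(-1), L).squeeze(-1)
    mean = mu_M + (alpha * (y_m - mu_m)).sum(-1)
    var = (var_M - (alpha * k_mM).sum(-1)).clamp_min(0.0)
    return mean, var
```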
@@ -515,7 +448,7 @@ def _compute_information_gain(
return ig


class qLowerBoundMaxValueEntropy(DiscreteMaxValueBase):
class qLowerBoundMaxValueEntropy(MaxValueBase):
r"""The acquisition function for General-purpose Information-Based
Bayesian Optimisation (GIBBON).

@@ -672,7 +605,7 @@ class qMultiFidelityMaxValueEntropy(qMaxValueEntropy):
for a detailed discussion of the basic ideas on multi-fidelity MES
(note that this implementation is somewhat different).

The model must be single-outcome, unless using a PosteriorTransform.
The model must be single-outcome.
The batch case `q > 1` is supported through cyclic optimization and fantasies.

Example:
@@ -757,7 +690,7 @@ def __init__(

# resample max values after initializing self.project
# so that the max value samples are at the highest fidelity
self._sample_max_values(self.num_mv_samples)
self._sample_max_values(num_samples=self.num_mv_samples)

@property
def cost_sampler(self):
@@ -846,7 +779,7 @@ def __init__(
maximize: bool = True,
cost_aware_utility: CostAwareUtility | None = None,
project: Callable[[Tensor], Tensor] = lambda X: X,
expand: Callable[[Tensor], Tensor] = lambda X: X,
expand: Callable[[Tensor], Tensor] | None = None,
) -> None:
r"""Single-outcome max-value entropy search acquisition function.

@@ -878,7 +811,12 @@ def __init__(
a `batch_shape x (q + q_e)' x d`-dim output tensor, where the
`q_e` additional points in each q-batch correspond to
additional ("trace") observations.
NOTE: This is currently not supported. It leads to wrong outputs.
"""
if expand is not None:
raise UnsupportedError(
f"{self.__class__.__name__} does not support trace observations. "
)
super().__init__(
model=model,
candidate_set=candidate_set,
@@ -890,7 +828,6 @@
maximize=maximize,
cost_aware_utility=cost_aware_utility,
project=project,
expand=expand,
)
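Assuming this `__init__` belongs to `qMultiFidelityLowerBoundMaxValueEntropy` (the GIBBON multi-fidelity variant; the class line is outside the visible hunks), a custom `expand` now raises, since trace observations produced wrong outputs. An illustrative sketch, reusing `model`/`candidate_set` from the earlier example:

```python
from botorch.acquisition.max_value_entropy_search import (
    qMultiFidelityLowerBoundMaxValueEntropy,
)
from botorch.exceptions.errors import UnsupportedError

try:
    acqf = qMultiFidelityLowerBoundMaxValueEntropy(
        model=model,
        candidate_set=candidate_set,
        expand=lambda X: X,  # trace observations are no longer supported
    )
except UnsupportedError as err:
    print(err)
```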

def _compute_information_gain(
@@ -1000,7 +937,7 @@ def _sample_max_value_Gumbel(
quantiles = torch.zeros(num_fantasies, 3, device=device, dtype=dtype)
for i in range(num_fantasies):
lo_, hi_ = lo[i], hi[i]
N = norm(mu[:, i].cpu().numpy(), sigma[:, i].cpu().numpy())
N = norm(mu[:, i].numpy(force=True), sigma[:, i].numpy(force=True))
quantiles[i, :] = torch.tensor(
[
brentq(lambda y: np.exp(np.sum(N.logcdf(y))) - p, lo_, hi_)
13 changes: 3 additions & 10 deletions botorch/acquisition/multi_objective/max_value_entropy_search.py
@@ -20,7 +20,6 @@
from __future__ import annotations

from collections.abc import Callable

from math import pi

import torch
@@ -139,19 +138,13 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
sampler=self.fantasies_sampler,
)
self.mo_model = fantasy_model
# convert model to batched single outcome model.
self.model = batched_multi_output_to_single_output(
batch_mo_model=self.mo_model
)
self._sample_max_values()
else:
# This is mainly for setting the model to the original model
# after the sequential optimization at q > 1
self.mo_model = self._init_model
self.model = batched_multi_output_to_single_output(
batch_mo_model=self.mo_model
)
self._sample_max_values()
# convert model to batched single outcome model.
self.model = batched_multi_output_to_single_output(batch_mo_model=self.mo_model)
self._sample_max_values()

def _sample_max_values(self) -> None:
"""Sample max values for MC approximation of the expectation in MES.
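The multi-objective variant now performs the batched single-output conversion once, after either branch. For reference, a small sketch of what that helper does (assumes a two-outcome `SingleTaskGP`; `train_X` is made-up data):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.converter import batched_multi_output_to_single_output

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = torch.randn(20, 2, dtype=torch.double)  # two outcomes
mo_model = SingleTaskGP(train_X, train_Y)

# Outcomes are moved into a model batch dimension, so downstream code can
# sample per-outcome max values from a single-output posterior.
so_model = batched_multi_output_to_single_output(batch_mo_model=mo_model)
assert so_model.num_outputs == 1
```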