
Commit c773094

saitcakmak authored and facebook-github-bot committed
Some clean-up for MES-based acquisition functions (#2769)

Summary:
- Merges `DiscreteMaxValueBase` into `MaxValueBase`. An abstract base class that only has one subclass, which is itself an abstract base class, doesn't provide much value and creates unreachable code that can only be tested by defining dummy subclasses.
- Adds an explicit error when multi-output models are used. Both batched `SingleTaskGP` and `ModelListGP` would error out for different reasons. I suspect the underlying code supports it, but e2e support needs some modifications.
- Errors out if `expand` is provided to `qMultiFidelityLowerBoundMaxValueEntropy`. It produces outputs of the wrong shape, which points to some missing handling of the different tensor shapes within the underlying code.
- Adds a very basic test for `expand` with `qMultiFidelityMaxValueEntropy`, which produces output of the correct shape.
- Updates tests to use actual models rather than mock models. Testing that some functionality (like multi-output model support) works with a mock model doesn't actually mean it works. We should be using mock models much more sparingly in tests.

Reviewed By: esantorella

Differential Revision: D71051750
1 parent 641b16f commit c773094
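
For context, a minimal sketch of how the merged class is exercised after this change. The model, data, and shapes are illustrative assumptions, not taken from this commit or its tests:

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# Fit a single-outcome GP on toy data.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

# `candidate_set` discretizes the design space; max values are sampled from
# the (joint) posterior over these points. After this change that logic lives
# directly in `MaxValueBase` rather than in `DiscreteMaxValueBase`.
candidate_set = torch.rand(100, 2, dtype=torch.double)
mes = qMaxValueEntropy(model=model, candidate_set=candidate_set)

# Evaluate at a `batch_shape x 1 x d` batch of design points.
acq_value = mes(torch.rand(3, 1, 2, dtype=torch.double))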

File tree

3 files changed: +175 -322 lines changed

botorch/acquisition/max_value_entropy_search.py
Lines changed: 64 additions & 127 deletions
@@ -58,54 +58,69 @@
 
 
 class MaxValueBase(AcquisitionFunction, ABC):
-    r"""Abstract base class for acquisition functions based on Max-value Entropy Search.
+    r"""Abstract base class for acquisition functions based on Max-value Entropy Search,
+    using discrete max posterior sampling.
 
     This class provides the basic building blocks for constructing max-value
     entropy-based acquisition functions along the lines of [Wang2017mves]_.
+    It provides basic functionality for sampling posterior maximum values from
+    a surrogate Gaussian process model using a discrete set of candidates. It supports
+    either exact (w.r.t. the candidate set) sampling, or using a Gumbel approximation.
 
-    Subclasses need to implement `_sample_max_values` and `_compute_information_gain`
-    methods.
+    Subclasses must implement `_compute_information_gain`.
     """
 
     def __init__(
         self,
         model: Model,
-        num_mv_samples: int,
+        candidate_set: Tensor,
+        num_mv_samples: int = 10,
         posterior_transform: PosteriorTransform | None = None,
+        use_gumbel: bool = True,
         maximize: bool = True,
         X_pending: Tensor | None = None,
+        train_inputs: Tensor | None = None,
     ) -> None:
-        r"""Single-outcome max-value entropy search-based acquisition functions.
+        r"""Single-outcome max-value entropy search-based acquisition functions
+        based on discrete MV sampling.
 
         Args:
             model: A fitted single-outcome model.
+            candidate_set: A `n x d` Tensor including `n` candidate points to
+                discretize the design space. Max values are sampled from the
+                (joint) model posterior over these points.
             num_mv_samples: Number of max value samples.
             posterior_transform: A PosteriorTransform. If using a multi-output model,
                 a PosteriorTransform that transforms the multi-output posterior into a
                 single-output posterior is required.
+            use_gumbel: If True, use Gumbel approximation to sample the max values.
             maximize: If True, consider the problem a maximization problem.
             X_pending: A `m x d`-dim Tensor of `m` design points that have been
                 submitted for function evaluation but have not yet been evaluated.
+            train_inputs: A `n_train x d` Tensor that the model has been fitted on.
+                Not required if the model is an instance of a GPyTorch ExactGP model.
         """
         super().__init__(model=model)
 
-        if posterior_transform is None and model.num_outputs != 1:
+        if model.num_outputs > 1:
             raise UnsupportedError(
-                "Must specify a posterior transform when using a multi-output model."
+                f"Multi-output models are not supported by {self.__class__.__name__}."
             )
+        if train_inputs is None and hasattr(model, "train_inputs"):
+            train_inputs = model.train_inputs[0]
+        if train_inputs is not None:
+            if train_inputs.ndim > 2:
+                raise NotImplementedError(
+                    "Batched GP models (e.g., fantasized models) are not yet "
+                    f"supported by `{self.__class__.__name__}`."
+                )
+            train_inputs = match_batch_shape(train_inputs, candidate_set)
+            candidate_set = torch.cat([candidate_set, train_inputs], dim=0)
 
-        # Batched GP models are not currently supported
-        try:
-            batch_shape = model.batch_shape
-        except NotImplementedError:
-            batch_shape = torch.Size()
-        if len(batch_shape) > 0:
-            raise NotImplementedError(
-                "Batched GP models (e.g., fantasized models) are not yet "
-                f"supported by `{self.__class__.__name__}`."
-            )
+        self.candidate_set = candidate_set
         self.num_mv_samples = num_mv_samples
         self.posterior_transform = posterior_transform
+        self.use_gumbel = use_gumbel
         self.maximize = maximize
         self.weight = 1.0 if maximize else -1.0
         self.set_X_pending(X_pending)
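
The new check makes multi-output models fail fast with a clear message instead of erroring for model-dependent reasons deeper in the stack. A hedged sketch of the resulting behavior, using an illustrative two-output model:

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy
from botorch.exceptions.errors import UnsupportedError

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # two outputs
model = SingleTaskGP(train_X, train_Y)  # model.num_outputs == 2

try:
    qMaxValueEntropy(
        model=model, candidate_set=torch.rand(32, 2, dtype=torch.double)
    )
except UnsupportedError as e:
    print(e)  # "Multi-output models are not supported by qMaxValueEntropy."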
@@ -151,106 +166,6 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
             self._sample_max_values(num_samples=self.num_mv_samples, X_pending=X_pending)
         self.X_pending = X_pending
 
-    # ------- Abstract methods that need to be implemented by subclasses ------- #
-
-    @abstractmethod
-    def _compute_information_gain(self, X: Tensor) -> Tensor:
-        r"""Compute the information gain at the design points `X`.
-
-        `num_fantasies = 1` for non-fantasized models.
-
-        Args:
-            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
-                with `1` `d`-dim design point each.
-
-        Returns:
-            A `num_fantasies x batch_shape`-dim Tensor of information gains at the
-            given design points `X` (`num_fantasies=1` for non-fantasized models).
-        """
-        pass  # pragma: no cover
-
-    @abstractmethod
-    def _sample_max_values(
-        self, num_samples: int, X_pending: Tensor | None = None
-    ) -> None:
-        r"""Draw samples from the posterior over maximum values.
-
-        These samples are used to compute Monte Carlo approximations of expectations
-        over the posterior over the function maximum. This function sets
-        `self.posterior_max_values`.
-
-        Args:
-            num_samples: The number of samples to draw.
-            X_pending: A `m x d`-dim Tensor of `m` design points that have been
-                submitted for function evaluation but have not yet been evaluated.
-
-        Returns:
-            A `num_samples x num_fantasies` Tensor of posterior max value samples
-            (`num_fantasies=1` for non-fantasized models).
-        """
-        pass  # pragma: no cover
-
-
-class DiscreteMaxValueBase(MaxValueBase):
-    r"""Abstract base class for MES-like methods using discrete max posterior sampling.
-
-    This class provides basic functionality for sampling posterior maximum values from
-    a surrogate Gaussian process model using a discrete set of candidates. It supports
-    either exact (w.r.t. the candidate set) sampling, or using a Gumbel approximation.
-    """
-
-    def __init__(
-        self,
-        model: Model,
-        candidate_set: Tensor,
-        num_mv_samples: int = 10,
-        posterior_transform: PosteriorTransform | None = None,
-        use_gumbel: bool = True,
-        maximize: bool = True,
-        X_pending: Tensor | None = None,
-        train_inputs: Tensor | None = None,
-    ) -> None:
-        r"""Single-outcome MES-like acquisition functions based on discrete MV sampling.
-
-        Args:
-            model: A fitted single-outcome model.
-            candidate_set: A `n x d` Tensor including `n` candidate points to
-                discretize the design space. Max values are sampled from the
-                (joint) model posterior over these points.
-            num_mv_samples: Number of max value samples.
-            posterior_transform: A PosteriorTransform. If using a multi-output model,
-                a PosteriorTransform that transforms the multi-output posterior into a
-                single-output posterior is required.
-            use_gumbel: If True, use Gumbel approximation to sample the max values.
-            maximize: If True, consider the problem a maximization problem.
-            X_pending: A `m x d`-dim Tensor of `m` design points that have been
-                submitted for function evaluation but have not yet been evaluated.
-            train_inputs: A `n_train x d` Tensor that the model has been fitted on.
-                Not required if the model is an instance of a GPyTorch ExactGP model.
-        """
-        self.use_gumbel = use_gumbel
-
-        if train_inputs is None and hasattr(model, "train_inputs"):
-            train_inputs = model.train_inputs[0]
-        if train_inputs is not None:
-            if train_inputs.ndim > 2:
-                raise NotImplementedError(
-                    "Batch GP models (e.g. fantasized models) "
-                    "are not yet supported by `MaxValueBase`"
-                )
-            train_inputs = match_batch_shape(train_inputs, candidate_set)
-            candidate_set = torch.cat([candidate_set, train_inputs], dim=0)
-
-        self.candidate_set = candidate_set
-
-        super().__init__(
-            model=model,
-            num_mv_samples=num_mv_samples,
-            posterior_transform=posterior_transform,
-            maximize=maximize,
-            X_pending=X_pending,
-        )
-
     def _sample_max_values(
         self, num_samples: int, X_pending: Tensor | None = None
     ) -> None:
@@ -291,13 +206,30 @@ def _sample_max_values(
         self.posterior_max_values = sample_max_values(
             model=self.model,
             candidate_set=candidate_set,
-            num_samples=self.num_mv_samples,
+            num_samples=num_samples,
             posterior_transform=self.posterior_transform,
             maximize=self.maximize,
         )
 
+    # ------- Abstract methods that need to be implemented by subclasses ------- #
+
+    @abstractmethod
+    def _compute_information_gain(self, X: Tensor) -> Tensor:
+        r"""Compute the information gain at the design points `X`.
+
+        `num_fantasies = 1` for non-fantasized models.
+
+        Args:
+            X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
+                with `1` `d`-dim design point each.
 
-class qMaxValueEntropy(DiscreteMaxValueBase, MCSamplerMixin):
+        Returns:
+            A `num_fantasies x batch_shape`-dim Tensor of information gains at the
+            given design points `X` (`num_fantasies=1` for non-fantasized models).
+        """
+
+
+class qMaxValueEntropy(MaxValueBase, MCSamplerMixin):
     r"""The acquisition function for Max-value Entropy Search.
 
     This acquisition function computes the mutual information of max values and
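
Note the one-line behavioral fix in the hunk above: `_sample_max_values` previously ignored its `num_samples` argument and always drew `self.num_mv_samples` samples. A sketch of the fixed behavior (toy model; the shape check reflects the `num_samples x num_fantasies` contract in the docstring):

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

model = SingleTaskGP(
    torch.rand(8, 2, dtype=torch.double), torch.rand(8, 1, dtype=torch.double)
)
mes = qMaxValueEntropy(
    model=model, candidate_set=torch.rand(64, 2, dtype=torch.double)
)
mes._sample_max_values(num_samples=5)
# The first dimension now honors the argument; before the fix it was always
# pinned to num_mv_samples (default 10).
assert mes.posterior_max_values.shape[0] == 5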
@@ -432,13 +364,14 @@ def _compute_information_gain(
         )
         # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
         mean_m = self.weight * posterior_m.mean.squeeze(-1)
-        # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
+        # batch_shape x num_fantasies x (m)
+        # x (1 + num_trace_observations) x (1 + num_trace_observations)
         variance_m = posterior_m.distribution.covariance_matrix
         check_no_nans(variance_m)
 
         # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
         samples_m = self.weight * self.get_posterior_samples(posterior_m).squeeze(-1)
-        # s_m x batch_shape x num_fantasies x (m) (1 + num_trace_observations)
+        # s_m x batch_shape x num_fantasies x (m) x (1 + num_trace) x (1 + num_trace)
         L = psd_safe_cholesky(variance_m)
         temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
         # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
@@ -515,7 +448,7 @@ def _compute_information_gain(
         return ig
 
 
-class qLowerBoundMaxValueEntropy(DiscreteMaxValueBase):
+class qLowerBoundMaxValueEntropy(MaxValueBase):
     r"""The acquisition function for General-purpose Information-Based
     Bayesian Optimisation (GIBBON).
 
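
`qLowerBoundMaxValueEntropy` (GIBBON) only changes its parent class here; construction and evaluation are unaffected. A minimal sketch with illustrative toy data:

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy

model = SingleTaskGP(
    torch.rand(8, 2, dtype=torch.double), torch.rand(8, 1, dtype=torch.double)
)
gibbon = qLowerBoundMaxValueEntropy(
    model=model, candidate_set=torch.rand(64, 2, dtype=torch.double)
)
acq_value = gibbon(torch.rand(3, 1, 2, dtype=torch.double))  # `batch_shape x 1 x d`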
@@ -672,7 +605,7 @@ class qMultiFidelityMaxValueEntropy(qMaxValueEntropy):
     for a detailed discussion of the basic ideas on multi-fidelity MES
     (note that this implementation is somewhat different).
 
-    The model must be single-outcome, unless using a PosteriorTransform.
+    The model must be single-outcome.
     The batch case `q > 1` is supported through cyclic optimization and fantasies.
 
     Example:
@@ -757,7 +690,7 @@ def __init__(
 
         # resample max values after initializing self.project
         # so that the max value samples are at the highest fidelity
-        self._sample_max_values(self.num_mv_samples)
+        self._sample_max_values(num_samples=self.num_mv_samples)
 
     @property
     def cost_sampler(self):
@@ -846,7 +779,7 @@ def __init__(
         maximize: bool = True,
         cost_aware_utility: CostAwareUtility | None = None,
         project: Callable[[Tensor], Tensor] = lambda X: X,
-        expand: Callable[[Tensor], Tensor] = lambda X: X,
+        expand: Callable[[Tensor], Tensor] | None = None,
     ) -> None:
         r"""Single-outcome max-value entropy search acquisition function.
 
@@ -878,7 +811,12 @@ def __init__(
                 a `batch_shape x (q + q_e)' x d`-dim output tensor, where the
                 `q_e` additional points in each q-batch correspond to
                 additional ("trace") observations.
+                NOTE: This is currently not supported. It leads to wrong outputs.
         """
+        if expand is not None:
+            raise UnsupportedError(
+                f"{self.__class__.__name__} does not support trace observations. "
+            )
         super().__init__(
             model=model,
             candidate_set=candidate_set,
@@ -890,7 +828,6 @@ def __init__(
             maximize=maximize,
             cost_aware_utility=cost_aware_utility,
             project=project,
-            expand=expand,
         )
 
     def _compute_information_gain(
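
With `expand` defaulting to `None` and rejected when supplied, trace observations now fail fast instead of silently producing wrongly shaped outputs. A sketch under illustrative assumptions (the last input column stands in for fidelity):

import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.max_value_entropy_search import (
    qMultiFidelityLowerBoundMaxValueEntropy,
)
from botorch.exceptions.errors import UnsupportedError

train_X = torch.rand(10, 3, dtype=torch.double)  # last column: fidelity
train_Y = torch.rand(10, 1, dtype=torch.double)
model = SingleTaskGP(train_X, train_Y)

try:
    qMultiFidelityLowerBoundMaxValueEntropy(
        model=model,
        candidate_set=torch.rand(32, 3, dtype=torch.double),
        expand=lambda X: X,  # any non-None callable now raises
    )
except UnsupportedError as e:
    print(e)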
@@ -1000,7 +937,7 @@ def _sample_max_value_Gumbel(
     quantiles = torch.zeros(num_fantasies, 3, device=device, dtype=dtype)
     for i in range(num_fantasies):
         lo_, hi_ = lo[i], hi[i]
-        N = norm(mu[:, i].cpu().numpy(), sigma[:, i].cpu().numpy())
+        N = norm(mu[:, i].numpy(force=True), sigma[:, i].numpy(force=True))
         quantiles[i, :] = torch.tensor(
             [
                 brentq(lambda y: np.exp(np.sum(N.logcdf(y))) - p, lo_, hi_)
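
The last change swaps in the modern tensor-to-NumPy idiom: `Tensor.numpy(force=True)` behaves like `.detach().cpu().numpy()` (plus resolving conjugate/negative views), so it also handles tensors that require grad, which a bare `.cpu().numpy()` rejects. A self-contained illustration, not from the commit:

import torch

t = torch.rand(4, dtype=torch.double, requires_grad=True)
# t.cpu().numpy() would raise: "Can't call numpy() on Tensor that requires grad."
arr = t.numpy(force=True)  # detaches and moves to CPU as needed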

botorch/acquisition/multi_objective/max_value_entropy_search.py
Lines changed: 3 additions & 10 deletions
@@ -20,7 +20,6 @@
 from __future__ import annotations
 
 from collections.abc import Callable
-
 from math import pi
 
 import torch
@@ -139,19 +138,13 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
                 sampler=self.fantasies_sampler,
             )
             self.mo_model = fantasy_model
-            # convert model to batched single outcome model.
-            self.model = batched_multi_output_to_single_output(
-                batch_mo_model=self.mo_model
-            )
-            self._sample_max_values()
         else:
             # This is mainly for setting the model to the original model
             # after the sequential optimization at q > 1
             self.mo_model = self._init_model
-            self.model = batched_multi_output_to_single_output(
-                batch_mo_model=self.mo_model
-            )
-            self._sample_max_values()
+        # convert model to batched single outcome model.
+        self.model = batched_multi_output_to_single_output(batch_mo_model=self.mo_model)
+        self._sample_max_values()
 
     def _sample_max_values(self) -> None:
         """Sample max values for MC approximation of the expectation in MES.
