class MaxValueBase(AcquisitionFunction, ABC):
-     r"""Abstract base class for acquisition functions based on Max-value Entropy Search.
+     r"""Abstract base class for acquisition functions based on Max-value Entropy Search,
+     using discrete max posterior sampling.

    This class provides the basic building blocks for constructing max-value
    entropy-based acquisition functions along the lines of [Wang2017mves]_.
+     It provides basic functionality for sampling posterior maximum values from
+     a surrogate Gaussian process model using a discrete set of candidates. It supports
+     either exact (w.r.t. the candidate set) sampling, or using a Gumbel approximation.

-     Subclasses need to implement `_sample_max_values` and `_compute_information_gain`
-     methods.
+     Subclasses must implement `_compute_information_gain`.
    """

    def __init__(
        self,
        model: Model,
-         num_mv_samples: int,
+         candidate_set: Tensor,
+         num_mv_samples: int = 10,
        posterior_transform: PosteriorTransform | None = None,
+         use_gumbel: bool = True,
        maximize: bool = True,
        X_pending: Tensor | None = None,
+         train_inputs: Tensor | None = None,
    ) -> None:
-         r"""Single-outcome max-value entropy search-based acquisition functions.
+         r"""Single-outcome max-value entropy search-based acquisition functions
+         based on discrete MV sampling.

        Args:
            model: A fitted single-outcome model.
+             candidate_set: A `n x d` Tensor including `n` candidate points to
+                 discretize the design space. Max values are sampled from the
+                 (joint) model posterior over these points.
            num_mv_samples: Number of max value samples.
            posterior_transform: A PosteriorTransform. If using a multi-output model,
                a PosteriorTransform that transforms the multi-output posterior into a
                single-output posterior is required.
+             use_gumbel: If True, use Gumbel approximation to sample the max values.
            maximize: If True, consider the problem a maximization problem.
            X_pending: A `m x d`-dim Tensor of `m` design points that have been
                submitted for function evaluation but have not yet been evaluated.
+             train_inputs: A `n_train x d` Tensor that the model has been fitted on.
+                 Not required if the model is an instance of a GPyTorch ExactGP model.
        """
        super().__init__(model=model)

-         if posterior_transform is None and model.num_outputs != 1:
+         if model.num_outputs > 1:
            raise UnsupportedError(
-                 "Must specify a posterior transform when using a multi-output model."
+                 f"Multi-output models are not supported by {self.__class__.__name__}."
            )
+         if train_inputs is None and hasattr(model, "train_inputs"):
+             train_inputs = model.train_inputs[0]
+         if train_inputs is not None:
+             if train_inputs.ndim > 2:
+                 raise NotImplementedError(
+                     "Batched GP models (e.g., fantasized models) are not yet "
+                     f"supported by `{self.__class__.__name__}`."
+                 )
+             train_inputs = match_batch_shape(train_inputs, candidate_set)
+             candidate_set = torch.cat([candidate_set, train_inputs], dim=0)

-         # Batched GP models are not currently supported
-         try:
-             batch_shape = model.batch_shape
-         except NotImplementedError:
-             batch_shape = torch.Size()
-         if len(batch_shape) > 0:
-             raise NotImplementedError(
-                 "Batched GP models (e.g., fantasized models) are not yet "
-                 f"supported by `{self.__class__.__name__}`."
-             )
+         self.candidate_set = candidate_set
        self.num_mv_samples = num_mv_samples
        self.posterior_transform = posterior_transform
+         self.use_gumbel = use_gumbel
        self.maximize = maximize
        self.weight = 1.0 if maximize else -1.0
        self.set_X_pending(X_pending)
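
Not part of the diff — a minimal usage sketch of the refactored constructor, assuming standard BoTorch components (`SingleTaskGP`, `fit_gpytorch_mll`, and the concrete subclass `qLowerBoundMaxValueEntropy`) and synthetic data. With `candidate_set` now on the base class, subclasses are built directly from a model and a discretization of the design space:

import torch
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Synthetic training data for a single-outcome model.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# `candidate_set` discretizes the design space; max values are sampled from
# the (joint) posterior over these points (plus the training inputs).
candidate_set = torch.rand(1000, 2, dtype=torch.double)
acqf = qLowerBoundMaxValueEntropy(
    model=model,
    candidate_set=candidate_set,
    num_mv_samples=10,  # default per this diff
    use_gumbel=True,  # Gumbel approximation to the max-value distribution
)
values = acqf(torch.rand(5, 1, 2, dtype=torch.double))  # `batch_shape x 1 x d` input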
@@ -151,106 +166,6 @@ def set_X_pending(self, X_pending: Tensor | None = None) -> None:
        self._sample_max_values(num_samples=self.num_mv_samples, X_pending=X_pending)
        self.X_pending = X_pending

-     # ------- Abstract methods that need to be implemented by subclasses ------- #
-
-     @abstractmethod
-     def _compute_information_gain(self, X: Tensor) -> Tensor:
-         r"""Compute the information gain at the design points `X`.
-
-         `num_fantasies = 1` for non-fantasized models.
-
-         Args:
-             X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
-                 with `1` `d`-dim design point each.
-
-         Returns:
-             A `num_fantasies x batch_shape`-dim Tensor of information gains at the
-             given design points `X` (`num_fantasies=1` for non-fantasized models).
-         """
-         pass  # pragma: no cover
-
-     @abstractmethod
-     def _sample_max_values(
-         self, num_samples: int, X_pending: Tensor | None = None
-     ) -> None:
-         r"""Draw samples from the posterior over maximum values.
-
-         These samples are used to compute Monte Carlo approximations of expectations
-         over the posterior over the function maximum. This function sets
-         `self.posterior_max_values`.
-
-         Args:
-             num_samples: The number of samples to draw.
-             X_pending: A `m x d`-dim Tensor of `m` design points that have been
-                 submitted for function evaluation but have not yet been evaluated.
-
-         Returns:
-             A `num_samples x num_fantasies` Tensor of posterior max value samples
-             (`num_fantasies=1` for non-fantasized models).
-         """
-         pass  # pragma: no cover
-
-
- class DiscreteMaxValueBase(MaxValueBase):
-     r"""Abstract base class for MES-like methods using discrete max posterior sampling.
-
-     This class provides basic functionality for sampling posterior maximum values from
-     a surrogate Gaussian process model using a discrete set of candidates. It supports
-     either exact (w.r.t. the candidate set) sampling, or using a Gumbel approximation.
-     """
-
-     def __init__(
-         self,
-         model: Model,
-         candidate_set: Tensor,
-         num_mv_samples: int = 10,
-         posterior_transform: PosteriorTransform | None = None,
-         use_gumbel: bool = True,
-         maximize: bool = True,
-         X_pending: Tensor | None = None,
-         train_inputs: Tensor | None = None,
-     ) -> None:
-         r"""Single-outcome MES-like acquisition functions based on discrete MV sampling.
-
-         Args:
-             model: A fitted single-outcome model.
-             candidate_set: A `n x d` Tensor including `n` candidate points to
-                 discretize the design space. Max values are sampled from the
-                 (joint) model posterior over these points.
-             num_mv_samples: Number of max value samples.
-             posterior_transform: A PosteriorTransform. If using a multi-output model,
-                 a PosteriorTransform that transforms the multi-output posterior into a
-                 single-output posterior is required.
-             use_gumbel: If True, use Gumbel approximation to sample the max values.
-             maximize: If True, consider the problem a maximization problem.
-             X_pending: A `m x d`-dim Tensor of `m` design points that have been
-                 submitted for function evaluation but have not yet been evaluated.
-             train_inputs: A `n_train x d` Tensor that the model has been fitted on.
-                 Not required if the model is an instance of a GPyTorch ExactGP model.
-         """
-         self.use_gumbel = use_gumbel
-
-         if train_inputs is None and hasattr(model, "train_inputs"):
-             train_inputs = model.train_inputs[0]
-         if train_inputs is not None:
-             if train_inputs.ndim > 2:
-                 raise NotImplementedError(
-                     "Batch GP models (e.g. fantasized models) "
-                     "are not yet supported by `MaxValueBase`"
-                 )
-             train_inputs = match_batch_shape(train_inputs, candidate_set)
-             candidate_set = torch.cat([candidate_set, train_inputs], dim=0)
-
-         self.candidate_set = candidate_set
-
-         super().__init__(
-             model=model,
-             num_mv_samples=num_mv_samples,
-             posterior_transform=posterior_transform,
-             maximize=maximize,
-             X_pending=X_pending,
-         )
-
    def _sample_max_values(
        self, num_samples: int, X_pending: Tensor | None = None
    ) -> None:
@@ -291,13 +206,30 @@ def _sample_max_values(
            self.posterior_max_values = sample_max_values(
                model=self.model,
                candidate_set=candidate_set,
-                 num_samples=self.num_mv_samples,
+                 num_samples=num_samples,
                posterior_transform=self.posterior_transform,
                maximize=self.maximize,
            )

+     # ------- Abstract methods that need to be implemented by subclasses ------- #
+
+     @abstractmethod
+     def _compute_information_gain(self, X: Tensor) -> Tensor:
+         r"""Compute the information gain at the design points `X`.
+
+         `num_fantasies = 1` for non-fantasized models.
+
+         Args:
+             X: A `batch_shape x 1 x d`-dim Tensor of `batch_shape` t-batches
+                 with `1` `d`-dim design point each.
+
+         Returns:
+             A `num_fantasies x batch_shape`-dim Tensor of information gains at the
+             given design points `X` (`num_fantasies=1` for non-fantasized models).
+         """
+
+
- class qMaxValueEntropy(DiscreteMaxValueBase, MCSamplerMixin):
+ class qMaxValueEntropy(MaxValueBase, MCSamplerMixin):
    r"""The acquisition function for Max-value Entropy Search.

    This acquisition function computes the mutual information of max values and
@@ -432,13 +364,14 @@ def _compute_information_gain(
            )
        # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
        mean_m = self.weight * posterior_m.mean.squeeze(-1)
-         # batch_shape x num_fantasies x (m) x (1 + num_trace_observations)
+         # batch_shape x num_fantasies x (m)
+         # x (1 + num_trace_observations) x (1 + num_trace_observations)
        variance_m = posterior_m.distribution.covariance_matrix
        check_no_nans(variance_m)

        # compute mean and std for fM|ym, x, Dt ~ N(u, s^2)
        samples_m = self.weight * self.get_posterior_samples(posterior_m).squeeze(-1)
-         # s_m x batch_shape x num_fantasies x (m) (1 + num_trace_observations)
+         # s_m x batch_shape x num_fantasies x (m) x (1 + num_trace) x (1 + num_trace)
        L = psd_safe_cholesky(variance_m)
        temp_term = torch.cholesky_solve(covar_mM.unsqueeze(-1), L).transpose(-2, -1)
        # equivalent to torch.matmul(covar_mM.unsqueeze(-2), torch.inverse(variance_m))
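
The equivalence claimed in the comment above is worth seeing in isolation. A standalone PyTorch sketch (variable names `A`, `b`, `L` are illustrative, not from this file): `cholesky_solve` reuses the already-computed factor and avoids forming an explicit inverse, which is cheaper and numerically safer for near-singular covariance matrices.

import torch

torch.manual_seed(0)
A = torch.randn(5, 5, dtype=torch.double)
A = A @ A.T + 5 * torch.eye(5, dtype=torch.double)  # well-conditioned PSD matrix
b = torch.randn(5, 1, dtype=torch.double)

L = torch.linalg.cholesky(A)  # lower-triangular factor, A = L @ L.T
x_chol = torch.cholesky_solve(b, L)  # solves A x = b using the factor
x_inv = torch.linalg.inv(A) @ b  # explicit inverse: same result, less stable

assert torch.allclose(x_chol, x_inv)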
@@ -515,7 +448,7 @@ def _compute_information_gain(
        return ig


- class qLowerBoundMaxValueEntropy(DiscreteMaxValueBase):
+ class qLowerBoundMaxValueEntropy(MaxValueBase):
    r"""The acquisition function for General-purpose Information-Based
    Bayesian Optimisation (GIBBON).
@@ -672,7 +605,7 @@ class qMultiFidelityMaxValueEntropy(qMaxValueEntropy):
    for a detailed discussion of the basic ideas on multi-fidelity MES
    (note that this implementation is somewhat different).

-     The model must be single-outcome, unless using a PosteriorTransform.
+     The model must be single-outcome.
    The batch case `q > 1` is supported through cyclic optimization and fantasies.

    Example:
@@ -757,7 +690,7 @@ def __init__(

        # resample max values after initializing self.project
        # so that the max value samples are at the highest fidelity
-         self._sample_max_values(self.num_mv_samples)
+         self._sample_max_values(num_samples=self.num_mv_samples)

    @property
    def cost_sampler(self):
@@ -846,7 +779,7 @@ def __init__(
        maximize: bool = True,
        cost_aware_utility: CostAwareUtility | None = None,
        project: Callable[[Tensor], Tensor] = lambda X: X,
-         expand: Callable[[Tensor], Tensor] = lambda X: X,
+         expand: Callable[[Tensor], Tensor] | None = None,
    ) -> None:
        r"""Single-outcome max-value entropy search acquisition function.
@@ -878,7 +811,12 @@ def __init__(
                a `batch_shape x (q + q_e)' x d`-dim output tensor, where the
                `q_e` additional points in each q-batch correspond to
                additional ("trace") observations.
+                 NOTE: This is currently not supported. It leads to wrong outputs.
        """
+         if expand is not None:
+             raise UnsupportedError(
+                 f"{self.__class__.__name__} does not support trace observations."
+             )
        super().__init__(
            model=model,
            candidate_set=candidate_set,
@@ -890,7 +828,6 @@ def __init__(
            maximize=maximize,
            cost_aware_utility=cost_aware_utility,
            project=project,
-             expand=expand,
        )

    def _compute_information_gain(
@@ -1000,7 +937,7 @@ def _sample_max_value_Gumbel(
        quantiles = torch.zeros(num_fantasies, 3, device=device, dtype=dtype)
        for i in range(num_fantasies):
            lo_, hi_ = lo[i], hi[i]
-             N = norm(mu[:, i].cpu().numpy(), sigma[:, i].cpu().numpy())
+             N = norm(mu[:, i].numpy(force=True), sigma[:, i].numpy(force=True))
            quantiles[i, :] = torch.tensor(
                [
                    brentq(lambda y: np.exp(np.sum(N.logcdf(y))) - p, lo_, hi_)
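
The truncated hunk above evaluates the exact CDF of the maximum over the candidate set — a product of Gaussian CDFs — and inverts it at p = 0.25, 0.5, 0.75 with `brentq`. For reference, a self-contained sketch of the surrounding quantile-matching idea from [Wang2017mves]_: fit a Gumbel distribution to those three quantiles and sample max values by inverse-CDF. The constants follow from the Gumbel CDF F(y) = exp(-exp(-(y - a) / b)), whose p-quantile is a - b * log(-log(p)); this is an illustration on toy data, not the file's exact implementation.

import math

import numpy as np
from scipy.optimize import brentq
from scipy.stats import norm

# Toy posterior: independent normals over a discrete candidate set.
mu = np.random.randn(100)
sigma = np.random.rand(100) + 0.5
N = norm(mu, sigma)

# CDF of the max of independent normals is the product of their CDFs;
# invert it at p = 0.25, 0.5, 0.75 with a bracketing root-finder.
lo, hi = mu.min() - 5 * sigma.max(), mu.max() + 5 * sigma.max()
q25, med, q75 = (
    brentq(lambda y: np.exp(np.sum(N.logcdf(y))) - p, lo, hi)
    for p in (0.25, 0.5, 0.75)
)

# Match a Gumbel to the three quantiles: solve the quantile equations for a, b.
b = (q75 - q25) / (math.log(math.log(4.0)) - math.log(math.log(4.0 / 3.0)))
a = med + b * math.log(math.log(2.0))

# Draw approximate max-value samples via inverse-CDF sampling.
u = np.random.rand(10)
max_value_samples = a - b * np.log(-np.log(u))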