@@ -74,7 +74,6 @@ def __init__(
         posterior_transform: PosteriorTransform | None = None,
         X_pending: Tensor | None = None,
         estimation_type: str = "LB",
-        maximize: bool = True,
         num_samples: int = 64,
     ) -> None:
         r"""Joint entropy search acquisition function.
@@ -91,11 +90,11 @@ def __init__(
                 [Tu2022joint]_. These are sampled identically, so this only controls
                 the fashion in which the GP is reshaped as a result of conditioning
                 on the optimum.
+            posterior_transform: PosteriorTransform to negate or scalarize the output.
             estimation_type: A string to determine which entropy
                 estimate is computed: "Lower Bound" ("LB") or "Monte Carlo" ("MC").
                 Lower Bound is recommended due to the relatively high variance
                 of the MC estimator.
-            maximize: If true, we consider a maximization problem.
             X_pending: A `m x d`-dim Tensor of `m` design points that have been
                 submitted for function evaluation, but have not yet been evaluated.
             num_samples: The number of Monte Carlo samples used for the Monte Carlo
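With `maximize` gone, a minimization problem is now expressed through the `posterior_transform` argument instead. A minimal sketch of the new call pattern, assuming this is BoTorch's `qJointEntropySearch` and `ScalarizedPosteriorTransform`; the toy data and placeholder optima are illustrative only and not part of the diff:

```python
import torch
from botorch.models import SingleTaskGP
from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.objective import ScalarizedPosteriorTransform

# Toy model; the "optimal" tensors are shape-correct placeholders standing in
# for samples that would normally come from optimizing posterior draws.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
optimal_inputs = torch.rand(16, 2, dtype=torch.double)   # num_samples x d
optimal_outputs = torch.rand(16, 1, dtype=torch.double)  # num_samples x 1

# Negating the single outcome replaces the removed maximize=False flag.
neg = ScalarizedPosteriorTransform(weights=torch.tensor([-1.0], dtype=torch.double))
acqf = qJointEntropySearch(
    model=model,
    optimal_inputs=optimal_inputs,
    optimal_outputs=optimal_outputs,
    posterior_transform=neg,
    estimation_type="LB",
)
values = acqf(torch.rand(5, 1, 2, dtype=torch.double))  # batch of q=1 candidates
```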
@@ -112,16 +111,13 @@ def __init__(
         # and three-dimensional otherwise.
         self.optimal_inputs = optimal_inputs.unsqueeze(-2)
         self.optimal_outputs = optimal_outputs.unsqueeze(-2)
+        self.optimal_output_values = (
+            posterior_transform.evaluate(self.optimal_outputs).unsqueeze(-1)
+            if posterior_transform
+            else self.optimal_outputs
+        )
         self.posterior_transform = posterior_transform
-        self.maximize = maximize
-
-        # The optima (can be maxima, can be minima) come in as the largest
-        # values if we optimize, or the smallest (likely substantially negative)
-        # if we minimize. Inside the acquisition function, however, we always
-        # want to consider MAX-values. As such, we need to flip them if
-        # we want to minimize.
-        if not self.maximize:
-            optimal_outputs = -optimal_outputs
+
         self.num_samples = optimal_inputs.shape[0]
         self.condition_noiseless = condition_noiseless
         self.initial_model = model
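For intuition, the new `optimal_output_values` reproduces the old sign flip whenever the transform negates the outcome. A small hedged check, assuming `ScalarizedPosteriorTransform.evaluate` scalarizes the outputs via its weights:

```python
import torch
from botorch.acquisition.objective import ScalarizedPosteriorTransform

optimal_outputs = torch.tensor([[0.12], [0.07]]).unsqueeze(-2)  # num_samples x 1 x 1
neg = ScalarizedPosteriorTransform(weights=torch.tensor([-1.0]))

# New path: evaluate() applies the (negating) scalarization; unsqueeze(-1)
# restores the trailing output dimension dropped by the scalarization.
optimal_output_values = neg.evaluate(optimal_outputs).unsqueeze(-1)

# Old path (removed): `if not self.maximize: optimal_outputs = -optimal_outputs`.
assert torch.allclose(optimal_output_values, -optimal_outputs)
```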
@@ -203,7 +199,9 @@ def _compute_lower_bound_information_gain(
             A `batch_shape`-dim Tensor of acquisition values at the given design
             points `X`.
         """
-        initial_posterior = self.initial_model.posterior(X, observation_noise=True)
+        initial_posterior = self.initial_model.posterior(
+            X, observation_noise=True, posterior_transform=self.posterior_transform
+        )
         # need to check if there is a two-dimensional batch shape -
         # the sampled optima appear in the dimension right after
         batch_shape = X.shape[:-2]
@@ -221,15 +219,17 @@ def _compute_lower_bound_information_gain(
 
         # Compute the mixture mean and variance
         posterior_m = self.conditional_model.posterior(
-            X.unsqueeze(MCMC_DIM), observation_noise=True
+            X.unsqueeze(MCMC_DIM),
+            observation_noise=True,
+            posterior_transform=self.posterior_transform,
         )
         noiseless_var = self.conditional_model.posterior(
-            X.unsqueeze(MCMC_DIM), observation_noise=False
+            X.unsqueeze(MCMC_DIM),
+            observation_noise=False,
+            posterior_transform=self.posterior_transform,
         ).variance
 
         mean_m = posterior_m.mean
-        if not self.maximize:
-            mean_m = -mean_m
         variance_m = posterior_m.variance
 
         check_no_nans(variance_m)
@@ -240,7 +240,7 @@ def _compute_lower_bound_information_gain(
             torch.zeros(1, device=X.device, dtype=X.dtype),
             torch.ones(1, device=X.device, dtype=X.dtype),
         )
-        normalized_mvs = (self.optimal_outputs - mean_m) / stdv
+        normalized_mvs = (self.optimal_output_values - mean_m) / stdv
         cdf_mvs = normal.cdf(normalized_mvs).clamp_min(CLAMP_LB)
         pdf_mvs = torch.exp(normal.log_prob(normalized_mvs))
 
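The normalized max-values with their CDF and PDF are the ingredients of the standard moments of a Normal truncated above at the sampled optimum. A sketch of that textbook formula, not necessarily the exact expression used further down in this method:

```python
import torch

normal = torch.distributions.Normal(torch.zeros(1), torch.ones(1))
mu, stdv, f_star = torch.tensor([0.3]), torch.tensor([0.5]), torch.tensor([0.9])

alpha = (f_star - mu) / stdv              # plays the role of normalized_mvs
cdf = normal.cdf(alpha).clamp_min(1e-10)  # cf. cdf_mvs
pdf = torch.exp(normal.log_prob(alpha))   # cf. pdf_mvs
ratio = pdf / cdf
# Variance of N(mu, stdv^2) truncated above at f_star:
truncated_var = stdv**2 * (1 - ratio * (alpha + ratio))
```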
@@ -294,7 +294,9 @@ def _compute_monte_carlo_information_gain(
             A `batch_shape`-dim Tensor of acquisition values at the given design
             points `X`.
         """
-        initial_posterior = self.initial_model.posterior(X, observation_noise=True)
+        initial_posterior = self.initial_model.posterior(
+            X, observation_noise=True, posterior_transform=self.posterior_transform
+        )
 
         batch_shape = X.shape[:-2]
         sample_dim = len(batch_shape)
@@ -311,15 +313,17 @@ def _compute_monte_carlo_information_gain(
 
         # Compute the mixture mean and variance
        posterior_m = self.conditional_model.posterior(
-            X.unsqueeze(MCMC_DIM), observation_noise=True
+            X.unsqueeze(MCMC_DIM),
+            observation_noise=True,
+            posterior_transform=self.posterior_transform,
         )
         noiseless_var = self.conditional_model.posterior(
-            X.unsqueeze(MCMC_DIM), observation_noise=False
+            X.unsqueeze(MCMC_DIM),
+            observation_noise=False,
+            posterior_transform=self.posterior_transform,
         ).variance
 
         mean_m = posterior_m.mean
-        if not self.maximize:
-            mean_m = -mean_m
         variance_m = posterior_m.variance.clamp_min(CLAMP_LB)
         conditional_samples, conditional_logprobs = self._compute_monte_carlo_variables(
             posterior_m