@@ -250,6 +250,11 @@ class qLogNoisyExpectedImprovement(
 
     where `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`.
 
+    For optimizing a batch of `q > 1` points using sequential greedy optimization,
+    the incremental improvement from the latest point is computed and returned by
+    default, i.e., the pending points are treated as `X_baseline`. The incremental
+    EI is often easier to optimize.
+
     Example:
         >>> model = SingleTaskGP(train_X, train_Y)
         >>> sampler = SobolQMCNormalSampler(1024)
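
For context, here is a minimal usage sketch (not part of this diff) of sequential greedy batch selection exercising the new flag. All APIs are standard BoTorch; `incremental=True` is the default this PR introduces:

```python
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# With incremental=True (the new default), each sequential greedy step
# optimizes the improvement of the newest point over X_baseline plus the
# points selected so far, rather than the joint batch improvement.
acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double)
candidates, _ = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=3,
    num_restarts=4,
    raw_samples=32,
    sequential=True,  # sequential greedy; X_pending is updated between steps
)
```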
@@ -273,6 +278,7 @@ def __init__(
         tau_max: float = TAU_MAX,
         tau_relu: float = TAU_RELU,
         marginalize_dim: int | None = None,
+        incremental: bool = True,
     ) -> None:
         r"""q-Noisy Expected Improvement.
 
@@ -312,6 +318,9 @@ def __init__(
             tau_relu: Temperature parameter controlling the sharpness of the smooth
                 approximations to ReLU.
             marginalize_dim: The dimension to marginalize over.
+            incremental: Whether to compute incremental EI over the pending points
+                or compute EI of the joint batch improvement (including pending
+                points).
 
         TODO: similar to qNEHVI, when we are using sequential greedy candidate
         selection, we could incorporate pending points X_baseline and compute
@@ -320,27 +329,34 @@ def __init__(
320329 """
321330 # TODO: separate out baseline variables initialization and other functions
322331 # in qNEI to avoid duplication of both code and work at runtime.
332+ self .incremental = incremental
333+
323334 super ().__init__ (
324335 model = model ,
325336 sampler = sampler ,
326337 objective = objective ,
327338 posterior_transform = posterior_transform ,
328- X_pending = X_pending ,
339+ # we set X_pending in init_baseline for incremental NEI
340+ X_pending = X_pending if not incremental else None ,
329341 constraints = constraints ,
330342 eta = eta ,
331343 fat = fat ,
332344 tau_max = tau_max ,
333345 )
334346 self .tau_relu = tau_relu
347+ self .prune_baseline = prune_baseline
348+ self .marginalize_dim = marginalize_dim
349+ if incremental :
350+ self .X_pending = None # required to initialize attribute for optimize_acqf
         self._init_baseline(
             model=model,
             X_baseline=X_baseline,
+            # This is ignored when incremental=False.
+            X_pending=X_pending,
             sampler=sampler,
             objective=objective,
             posterior_transform=posterior_transform,
-            prune_baseline=prune_baseline,
             cache_root=cache_root,
-            marginalize_dim=marginalize_dim,
         )
 
     def _sample_forward(self, obj: Tensor) -> Tensor:
@@ -364,26 +380,34 @@ def _init_baseline(
         self,
         model: Model,
         X_baseline: Tensor,
+        X_pending: Tensor | None = None,
         sampler: MCSampler | None = None,
         objective: MCAcquisitionObjective | None = None,
         posterior_transform: PosteriorTransform | None = None,
-        prune_baseline: bool = False,
         cache_root: bool = True,
-        marginalize_dim: int | None = None,
     ) -> None:
         CachedCholeskyMCSamplerMixin.__init__(
             self, model=model, cache_root=cache_root, sampler=sampler
         )
-        if prune_baseline:
+        if self.prune_baseline:
             X_baseline = prune_inferior_points(
                 model=model,
                 X=X_baseline,
                 objective=objective,
                 posterior_transform=posterior_transform,
-                marginalize_dim=marginalize_dim,
+                marginalize_dim=self.marginalize_dim,
                 constraints=self._constraints,
             )
-        self.register_buffer("X_baseline", X_baseline)
+        self.register_buffer("_X_baseline", X_baseline)
+        # full_X_baseline is the set of points that should be considered as the
+        # incumbent. For incremental EI, this contains the previously evaluated
+        # points (X_baseline) and the pending points (X_pending). For
+        # non-incremental EI, this contains only the previously evaluated points.
+        if X_pending is not None and self.incremental:
+            full_X_baseline = torch.cat([X_baseline, X_pending], dim=-2)
+        else:
+            full_X_baseline = X_baseline
+        self.register_buffer("_full_X_baseline", full_X_baseline)
         # registering buffers for _get_samples_and_objectives in the next `if` block
         self.register_buffer("baseline_samples", None)
         self.register_buffer("baseline_obj", None)
@@ -392,7 +416,7 @@ def _init_baseline(
             # set baseline samples
             with torch.no_grad():  # this is _get_samples_and_objectives(X_baseline)
                 posterior = self.model.posterior(
-                    X_baseline, posterior_transform=self.posterior_transform
+                    self.X_baseline, posterior_transform=self.posterior_transform
                 )
                 # Note: The root decomposition is cached in two different places. It
                 # may be confusing to have two different caches, but this is not
@@ -404,7 +428,9 @@ def _init_baseline(
                 # - self._baseline_L allows a root decomposition to be persisted
                 #   outside this method.
                 self.baseline_samples = self.get_posterior_samples(posterior)
-                self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline)
+                self.baseline_obj = self.objective(
+                    self.baseline_samples, X=self.X_baseline
+                )
 
             # We make a copy here because we will write an attribute `base_samples`
             # to `self.base_sampler.base_samples`, and we don't want to mutate
@@ -418,6 +444,46 @@ def _init_baseline(
             )
             self._baseline_L = self._compute_root_decomposition(posterior=posterior)
 
+    @property
+    def X_baseline(self) -> Tensor:
+        """Returns the set of points that should be considered as the incumbent.
+
+        For incremental EI, this contains the previously evaluated points
+        (X_baseline) and the pending points (X_pending). For non-incremental
+        EI, this contains the previously evaluated points (X_baseline).
+        """
+        return self._full_X_baseline
+
+    def set_X_pending(self, X_pending: Tensor | None = None) -> None:
+        r"""Informs the acquisition function about pending design points.
+
+        Here the pending points are concatenated with X_baseline and the
+        incremental NEI is computed.
+
+        Args:
+            X_pending: `n x d` Tensor with `n` `d`-dim design points that have
+                been submitted for evaluation but have not yet been evaluated.
+        """
+        if not self.incremental:
+            return super().set_X_pending(X_pending=X_pending)
+        if X_pending is None:
+            if not hasattr(self, "_full_X_baseline") or (
+                self._full_X_baseline.shape[-2] == self._X_baseline.shape[-2]
+            ):
+                return
+            else:
+                # reset pending points
+                X_pending = None
+        self._init_baseline(
+            model=self.model,
+            X_baseline=self._X_baseline,
+            X_pending=X_pending,
+            sampler=self.sampler,
+            objective=self.objective,
+            posterior_transform=self.posterior_transform,
+            cache_root=self._cache_root,
+        )
+
     def compute_best_f(self, obj: Tensor) -> Tensor:
         """Computes the best (feasible) noisy objective value.
 
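
To make the new `set_X_pending` behavior concrete, here is a hedged sketch of a hand-rolled sequential greedy loop, equivalent in spirit to `sequential=True` in the earlier sketch (it reuses `model`, `train_X`, and `bounds` from there). Under this diff, each `set_X_pending` call re-runs `_init_baseline`, folding the pending candidates into the incumbent set rather than storing them as a separate `X_pending` attribute:

```python
# Continues from the earlier sketch (model, train_X, bounds already defined).
acqf = qLogNoisyExpectedImprovement(model=model, X_baseline=train_X)

pending = None
selected = []
for _ in range(3):
    # With incremental=True, this concatenates `pending` onto X_baseline and
    # recomputes the baseline samples, so the next point is scored by its
    # incremental improvement over everything chosen so far.
    acqf.set_X_pending(pending)
    cand, _ = optimize_acqf(
        acq_function=acqf,
        bounds=bounds,
        q=1,
        num_restarts=4,
        raw_samples=32,
    )
    selected.append(cand)
    pending = torch.cat(selected, dim=-2)

X_next = torch.cat(selected, dim=-2)  # the q=3 batch, chosen greedily
```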