Commit d1d4c5c

sdaulton authored and facebook-github-bot committed
incremental qLogNEI
Summary: This diff adds an incremental qLogNEI that addresses many cases where the first candidate in the batch has positive EI (and satisfies the constraints) while subsequent arms violate the constraints (often severely). The issue appears to stem from optimizing the joint EI of the new candidate and the pending points w.r.t. the current incumbent(s). My hypothesis is that this makes the initialization strategy perform worse and choose bad starting points. Using sequential batch optimization and optimizing the incremental EI of the new arm relative to the pending points (and the current incumbent) avoids the issue by quantifying only the improvement of the arm currently being optimized.

TODO: add this for qNEI in a later diff; that seems low priority, since qLogNEI is the widely used variant.

Reviewed By: esantorella

Differential Revision: D70288526
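For context, a minimal usage sketch (not part of the commit; the toy data, bounds, and optimizer settings below are assumptions) showing how the new `incremental` flag pairs with sequential greedy optimization:

```python
# Sketch only: the data, bounds, and optimizer settings are illustrative
# assumptions; `incremental` is the flag added in this diff.
import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = (6.28 * train_X).sin().sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# With incremental=True (the new default), each arm selected during
# sequential greedy optimization is folded into the baseline, so the
# acquisition value only quantifies the improvement of the arm
# currently being optimized.
acqf = qLogNoisyExpectedImprovement(
    model=model, X_baseline=train_X, incremental=True
)
candidates, _ = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=3,  # with sequential=True, the three arms are chosen one at a time
    num_restarts=10,
    raw_samples=256,
    sequential=True,
)
```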
1 parent 78c04e2 commit d1d4c5c

File tree

5 files changed: +120 −10 lines

botorch/acquisition/input_constructors.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -650,6 +650,7 @@ def construct_inputs_qLogNEI(
     fat: bool = True,
     tau_max: float = TAU_MAX,
     tau_relu: float = TAU_RELU,
+    incremental: bool = True,
 ):
     r"""Construct kwargs for the `qLogNoisyExpectedImprovement` constructor.
 
@@ -684,6 +685,9 @@ def construct_inputs_qLogNEI(
             approximations to max.
         tau_relu: Temperature parameter controlling the sharpness of the smooth
             approximations to ReLU.
+        incremental: Whether to compute incremental EI over the pending points
+            or compute EI of the joint batch improvement (including pending
+            points).
 
     Returns:
         A dict mapping kwarg names of the constructor to values.
@@ -705,6 +709,7 @@ def construct_inputs_qLogNEI(
         "fat": fat,
         "tau_max": tau_max,
         "tau_relu": tau_relu,
+        "incremental": incremental,
     }
 
 
```
botorch/acquisition/logei.py

Lines changed: 76 additions & 10 deletions

```diff
@@ -250,6 +250,11 @@ class qLogNoisyExpectedImprovement(
 
     where `(Y, Y_baseline) ~ f((X, X_baseline)), X = (x_1,...,x_q)`.
 
+    For optimizing a batch of `q > 1` points using sequential greedy optimization,
+    the incremental improvement from the latest point is computed and returned by
+    default, i.e., the pending points are treated as `X_baseline`. Often, the
+    incremental EI is easier to optimize.
+
     Example:
         >>> model = SingleTaskGP(train_X, train_Y)
         >>> sampler = SobolQMCNormalSampler(1024)
@@ -273,6 +278,7 @@ def __init__(
         tau_max: float = TAU_MAX,
         tau_relu: float = TAU_RELU,
         marginalize_dim: int | None = None,
+        incremental: bool = True,
     ) -> None:
         r"""q-Noisy Expected Improvement.
 
@@ -312,6 +318,9 @@ def __init__(
             tau_relu: Temperature parameter controlling the sharpness of the smooth
                 approximations to ReLU.
             marginalize_dim: The dimension to marginalize over.
+            incremental: Whether to compute incremental EI over the pending points
+                or compute EI of the joint batch improvement (including pending
+                points).
 
         TODO: similar to qNEHVI, when we are using sequential greedy candidate
         selection, we could incorporate pending points X_baseline and compute
@@ -320,27 +329,34 @@ def __init__(
         """
         # TODO: separate out baseline variables initialization and other functions
         # in qNEI to avoid duplication of both code and work at runtime.
+        self.incremental = incremental
+
         super().__init__(
             model=model,
             sampler=sampler,
             objective=objective,
             posterior_transform=posterior_transform,
-            X_pending=X_pending,
+            # we set X_pending in _init_baseline for incremental NEI
+            X_pending=X_pending if not incremental else None,
             constraints=constraints,
             eta=eta,
             fat=fat,
             tau_max=tau_max,
         )
         self.tau_relu = tau_relu
+        self.prune_baseline = prune_baseline
+        self.marginalize_dim = marginalize_dim
+        if incremental:
+            self.X_pending = None  # required to initialize attribute for optimize_acqf
         self._init_baseline(
             model=model,
             X_baseline=X_baseline,
+            # This is ignored if incremental=False.
+            X_pending=X_pending,
             sampler=sampler,
             objective=objective,
             posterior_transform=posterior_transform,
-            prune_baseline=prune_baseline,
             cache_root=cache_root,
-            marginalize_dim=marginalize_dim,
         )
 
     def _sample_forward(self, obj: Tensor) -> Tensor:
@@ -364,26 +380,34 @@ def _init_baseline(
         self,
         model: Model,
         X_baseline: Tensor,
+        X_pending: Tensor | None = None,
         sampler: MCSampler | None = None,
         objective: MCAcquisitionObjective | None = None,
         posterior_transform: PosteriorTransform | None = None,
-        prune_baseline: bool = False,
         cache_root: bool = True,
-        marginalize_dim: int | None = None,
     ) -> None:
         CachedCholeskyMCSamplerMixin.__init__(
             self, model=model, cache_root=cache_root, sampler=sampler
         )
-        if prune_baseline:
+        if self.prune_baseline:
             X_baseline = prune_inferior_points(
                 model=model,
                 X=X_baseline,
                 objective=objective,
                 posterior_transform=posterior_transform,
-                marginalize_dim=marginalize_dim,
+                marginalize_dim=self.marginalize_dim,
                 constraints=self._constraints,
             )
-        self.register_buffer("X_baseline", X_baseline)
+        self.register_buffer("_X_baseline", X_baseline)
+        # full_X_baseline is the set of points that should be considered as the
+        # incumbent. For incremental EI, this contains the previously evaluated
+        # points (X_baseline) and pending points (X_pending). For non-incremental
+        # EI, this contains the previously evaluated points (X_baseline).
+        if X_pending is not None and self.incremental:
+            full_X_baseline = torch.cat([X_baseline, X_pending], dim=-2)
+        else:
+            full_X_baseline = X_baseline
+        self.register_buffer("_full_X_baseline", full_X_baseline)
         # registering buffers for _get_samples_and_objectives in the next `if` block
         self.register_buffer("baseline_samples", None)
         self.register_buffer("baseline_obj", None)
@@ -392,7 +416,7 @@ def _init_baseline(
         # set baseline samples
         with torch.no_grad():  # this is _get_samples_and_objectives(X_baseline)
             posterior = self.model.posterior(
-                X_baseline, posterior_transform=self.posterior_transform
+                self.X_baseline, posterior_transform=self.posterior_transform
             )
         # Note: The root decomposition is cached in two different places. It
         # may be confusing to have two different caches, but this is not
@@ -404,7 +428,9 @@ def _init_baseline(
         # - self._baseline_L allows a root decomposition to be persisted outside
         #   this method.
         self.baseline_samples = self.get_posterior_samples(posterior)
-        self.baseline_obj = self.objective(self.baseline_samples, X=X_baseline)
+        self.baseline_obj = self.objective(
+            self.baseline_samples, X=self.X_baseline
+        )
 
         # We make a copy here because we will write an attribute `base_samples`
         # to `self.base_sampler.base_samples`, and we don't want to mutate
@@ -418,6 +444,46 @@ def _init_baseline(
         )
         self._baseline_L = self._compute_root_decomposition(posterior=posterior)
 
+    @property
+    def X_baseline(self) -> Tensor:
+        """Returns the set of points that should be considered as the incumbent.
+
+        For incremental EI, this contains the previously evaluated points
+        (X_baseline) and pending points (X_pending). For non-incremental
+        EI, this contains the previously evaluated points (X_baseline).
+        """
+        return self._full_X_baseline
+
+    def set_X_pending(self, X_pending: Tensor | None = None) -> None:
+        r"""Informs the acquisition function about pending design points.
+
+        Here pending points are concatenated with X_baseline and incremental
+        NEI is computed.
+
+        Args:
+            X_pending: `n x d` Tensor with `n` `d`-dim design points that have
+                been submitted for evaluation but have not yet been evaluated.
+        """
+        if not self.incremental:
+            return super().set_X_pending(X_pending=X_pending)
+        if X_pending is None:
+            if not hasattr(self, "_full_X_baseline") or (
+                self._full_X_baseline.shape[-2] == self._X_baseline.shape[-2]
+            ):
+                return
+            else:
+                # reset pending points
+                X_pending = None
+        self._init_baseline(
+            model=self.model,
+            X_baseline=self._X_baseline,
+            X_pending=X_pending,
+            sampler=self.sampler,
+            objective=self.objective,
+            posterior_transform=self.posterior_transform,
+            cache_root=self._cache_root,
+        )
+
     def compute_best_f(self, obj: Tensor) -> Tensor:
         """Computes the best (feasible) noisy objective value.
 
```
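To make the semantics concrete: with baseline points `X_b` and pending points `X_p`, the non-incremental acquisition scores the joint improvement of the whole batch over the incumbent, while the incremental variant folds `X_p` into the incumbent and scores only the new arm `x`. Roughly, ignoring constraints and the log/smoothing transforms (the notation below is ours, not from the diff):

```latex
% Notation (ours): X_b = baseline points, X_p = pending points,
% (y)_+ = max(y, 0); expectations are over posterior samples of f.
\alpha_{\mathrm{joint}}(x \mid X_p)
  = \mathbb{E}\Big[\big(\max\{f(x),\, \max_{x' \in X_p} f(x')\}
    - \max_{x'' \in X_b} f(x'')\big)_+\Big]

\alpha_{\mathrm{inc}}(x \mid X_p)
  = \mathbb{E}\Big[\big(f(x) - \max_{x' \in X_b \cup X_p} f(x')\big)_+\Big]
```

On any fixed sample path the per-arm incremental improvements telescope to the joint batch improvement, so nothing is lost in total; the incremental form just stops crediting an arm for improvement already delivered by pending arms, which is what makes it better behaved under the sequential greedy scheme described in the summary.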
botorch/acquisition/multi_objective/parego.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -35,6 +35,7 @@ def __init__(
         cache_root: bool = True,
         tau_relu: float = TAU_RELU,
         tau_max: float = TAU_MAX,
+        incremental: bool = True,
     ) -> None:
         r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function
         utilizes qLogNEI to compute the expected improvement over Chebyshev
@@ -88,6 +89,9 @@ def __init__(
                 approximations to max.
             tau_relu: Temperature parameter controlling the sharpness of the smooth
                 approximations to ReLU.
+            incremental: Whether to compute incremental EI over the pending points
+                or compute EI of the joint batch improvement (including pending
+                points).
         """
         MultiObjectiveMCAcquisitionFunction.__init__(
             self,
@@ -134,6 +138,7 @@ def __init__(
             cache_root=cache_root,
             tau_max=tau_max,
             tau_relu=tau_relu,
+            incremental=incremental,
         )
         # Set these after __init__ calls so that they're not overwritten / deleted.
         # These are intended mainly for easier debugging & transparency.
```

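Since `qLogNParEGO` simply forwards the flag to its underlying `qLogNEI`, opting out of the new behavior looks like this (a sketch; the two-outcome toy model and shapes are assumptions):

```python
# Sketch: the two-outcome toy model is an assumption; `incremental` is the
# flag this diff threads through qLogNParEGO to the underlying qLogNEI.
import torch
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 2, dtype=torch.double)  # m = 2 outcomes
model = SingleTaskGP(train_X, train_Y)

# incremental=False recovers the previous behavior: EI of the joint batch
# improvement, including pending points.
acqf = qLogNParEGO(model=model, X_baseline=train_X, incremental=False)
val = acqf(torch.rand(1, 2, 2, dtype=torch.double))  # b=1 batch of q=2 arms
```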
test/acquisition/multi_objective/test_parego.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -26,6 +26,7 @@ def base_test_parego(
         with_scalarization_weights: bool = False,
         with_objective: bool = False,
         model: Model | None = None,
+        incremental: bool = True,
     ) -> None:
         if with_constraints:
             assert with_objective, "Objective must be specified if constraints are."
@@ -57,6 +58,7 @@ def base_test_parego(
             objective=objective,
             constraints=constraints,
             prune_baseline=True,
+            incremental=incremental,
         )
         self.assertEqual(acqf.Y_baseline.shape, torch.Size([3, 2]))
         # Scalarization weights should be set if given and sampled otherwise.
@@ -102,6 +104,9 @@ def test_parego_with_constraints_objective_weights(self) -> None:
             with_constraints=True, with_objective=True, with_scalarization_weights=True
         )
 
+    def test_parego_with_non_incremental_ei(self) -> None:
+        self.base_test_parego(incremental=False)
+
     def test_parego_with_ensemble_model(self) -> None:
         tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
         models = []
```

test/acquisition/test_logei.py

Lines changed: 29 additions & 0 deletions

```diff
@@ -404,6 +404,7 @@ def test_q_log_noisy_expected_improvement(self):
             "sampler": sampler,
             "prune_baseline": False,
             "cache_root": False,
+            "incremental": False,
         }
         # copy for log version
         log_acqf = qLogNoisyExpectedImprovement(**kwargs)
@@ -422,6 +423,34 @@ def test_q_log_noisy_expected_improvement(self):
         self.assertEqual(log_acqf.X_pending, X2)
         self.assertEqual(sum(issubclass(w.category, BotorchWarning) for w in ws), 1)
 
+        # test incremental
+        # Check that adding a pending point is equivalent to adding a point to
+        # X_baseline
+        for cache_root in (True, False):
+            kwargs = {
+                "model": mm_noisy_pending,
+                "X_baseline": X_baseline,
+                "sampler": sampler,
+                "prune_baseline": False,
+                "cache_root": cache_root,
+                "incremental": True,
+            }
+            log_acqf = qLogNoisyExpectedImprovement(**kwargs)
+            log_acqf.set_X_pending(X)
+            self.assertIsNone(log_acqf.X_pending)
+            af_val1 = log_acqf(X2)
+            kwargs = {
+                "model": mm_noisy_pending,
+                "X_baseline": torch.cat([X_baseline, X], dim=-2),
+                "sampler": sampler,
+                "prune_baseline": False,
+                "cache_root": cache_root,
+                "incremental": False,
+            }
+            log_acqf = qLogNoisyExpectedImprovement(**kwargs)
+            af_val2 = log_acqf(X2)
+            self.assertAllClose(af_val1.item(), af_val2.item())
+
     def test_q_noisy_expected_improvement_batch(self):
         for dtype in (torch.float, torch.double):
             # the event shape is `b x q x t` = 2 x 3 x 1
```
