
Commit bfd1d9e

TobyBoyne authored and meta-codesync[bot] committed
Improve best feasible objective (#3011)
Summary:

## Motivation

See #3009.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #3011

Test Plan: After feedback on the initial draft of the PR, I will implement tests similar to the problem structure in #3009, confirming that the modified code behaves as desired.

## Related PRs

None

Reviewed By: mpolson64

Differential Revision: D85342416

Pulled By: Balandat

fbshipit-source-id: 3b303cea7e4247d065fa5ca48091a6f222890450
1 parent 82536af commit bfd1d9e

File tree

2 files changed (+70, -21 lines)

botorch/acquisition/utils.py

Lines changed: 58 additions & 13 deletions
```diff
@@ -11,6 +11,7 @@
 from __future__ import annotations
 
 import math
+import warnings
 from collections.abc import Callable
 
 import torch
@@ -24,12 +25,16 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.fully_bayesian import MCMC_DIM
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
 from botorch.sampling.get_sampler import get_sampler
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.objective import compute_feasibility_indicator
+from botorch.utils.objective import (
+    compute_feasibility_indicator,
+    compute_smoothed_feasibility_indicator,
+)
 from botorch.utils.sampling import optimize_posterior_samples
 from botorch.utils.transforms import is_ensemble, normalize_indices
 from gpytorch.models import GP
```
```diff
@@ -150,6 +155,13 @@ def compute_best_feasible_objective(
             raise ValueError(
                 "Must specify `X_baseline` when no feasible observation exists."
             )
+        warnings.warn(
+            "When all training points are infeasible, it is better to use "
+            "q(Log)ProbabilityOfFeasibility.",
+            BotorchWarning,
+            stacklevel=2,
+        )
+
         infeasible_value = _estimate_objective_lower_bound(
             model=model,
             objective=objective,
```
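Since the new warning names `q(Log)ProbabilityOfFeasibility`, callers who deliberately start from an all-infeasible baseline may want to silence it. A minimal sketch (not part of the diff) using the standard `warnings` machinery:

```python
# Hypothetical call site: suppress the new all-infeasible warning when it is expected.
import warnings

from botorch.exceptions.warnings import BotorchWarning

with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        message=".*ProbabilityOfFeasibility",  # regex matched against the warning text
        category=BotorchWarning,
    )
    # best_f = compute_best_feasible_objective(...)  # call elided; see the diff above
```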
```diff
@@ -171,8 +183,9 @@ def _estimate_objective_lower_bound(
     posterior_transform: PosteriorTransform | None,
     X: Tensor,
 ) -> Tensor:
-    """Estimates a lower bound on the objective values by evaluating the model at convex
-    combinations of `X`, returning the 6-sigma lower bound of the computed statistics.
+    """Estimates a lower bound on the objective values by evaluating the model at
+    uniformly random points in the bounding box of `X`, returning the 6-sigma lower
+    bound of the computed statistics.
 
     Args:
         model: A fitted model.
@@ -183,19 +196,21 @@ def _estimate_objective_lower_bound(
     Returns:
         A `m`-dimensional Tensor of lower bounds of the objectives.
     """
-    convex_weights = torch.rand(
-        32,
-        X.shape[-2],
-        dtype=X.dtype,
-        device=X.device,
+    # we do not have access to `bounds` here, so we infer the bounding box
+    # from data, expanding by 10% in each direction
+    X_lb = X.min(dim=-2, keepdim=True).values
+    X_ub = X.max(dim=-2, keepdim=True).values
+    X_range = X_ub - X_lb
+    X_padding = 0.1 * X_range
+    uniform_samples = torch.rand(
+        *X.shape[:-2], 32, X.shape[-1], dtype=X.dtype, device=X.device
     )
-    weights_sum = convex_weights.sum(dim=0, keepdim=True)
-    convex_weights = convex_weights / weights_sum
+    X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
     # infeasible cost M is such that -M < min_x f(x), thus
     # 0 < min_x f(x) - (-M), so we should take -M as a lower
     # bound on the best feasible objective
     return -get_infeasible_cost(
-        X=convex_weights @ X,
+        X=X_samples,
         model=model,
         objective=objective,
         posterior_transform=posterior_transform,
```
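The new sampling scheme is easy to exercise on its own. The sketch below assumes only `torch`; `sample_bounding_box` is a name introduced here for illustration, not part of the PR. It reproduces the padded bounding-box logic from the hunk above:

```python
import torch


def sample_bounding_box(X: torch.Tensor, n: int = 32, pad_frac: float = 0.1) -> torch.Tensor:
    """Draw `n` uniform points from the bounding box of `X`, padded by `pad_frac`."""
    X_lb = X.min(dim=-2, keepdim=True).values  # per-dimension minimum
    X_ub = X.max(dim=-2, keepdim=True).values  # per-dimension maximum
    X_range = X_ub - X_lb
    pad = pad_frac * X_range
    # uniform samples in [0, 1], then affinely mapped to the padded box
    u = torch.rand(*X.shape[:-2], n, X.shape[-1], dtype=X.dtype, device=X.device)
    return X_lb - pad + u * (X_range + 2 * pad)


X = torch.rand(10, 3)  # 10 points in 3 dimensions
samples = sample_bounding_box(X)
assert samples.shape == (32, 3)
```

Unlike the previous convex combinations of `X`, these samples can land slightly outside the observed data, which is what makes the resulting bound a more conservative estimate of the objective range.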
```diff
@@ -235,7 +250,19 @@ def objective(Y: Tensor, X: Tensor | None = None):
             return Y.squeeze(-1)
 
     posterior = model.posterior(X, posterior_transform=posterior_transform)
-    lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X)
+    # We check both the upper and lower bound of the posterior, since the objective
+    # may be increasing or decreasing. For objectives that are neither (eg. absolute
+    # distance from a target), this should still provide a good bound.
+    six_stdv = 6 * posterior.variance.clamp_min(0).sqrt()
+    lb = torch.stack(
+        [
+            objective(posterior.mean - six_stdv, X=X),
+            objective(posterior.mean + six_stdv, X=X),
+        ],
+        dim=0,
+    )
+    lb = lb.min(dim=0).values
+
     if lb.ndim < posterior.mean.ndim:
         lb = lb.unsqueeze(-1)
     # Take outcome-wise min. Looping in to handle batched models.
```
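The two-sided bound can be sanity-checked with a toy decreasing objective, where the lower bound must come from the `mean + 6 * std` branch. A small illustrative snippet, not from the PR:

```python
import torch

mean = torch.tensor([0.5, -1.0])
std = torch.tensor([0.2, 0.3])


def objective(y):
    return -y  # a decreasing objective: its minimum over [mean - 6s, mean + 6s] is at the top


six_std = 6 * std
lb = torch.stack([objective(mean - six_std), objective(mean + six_std)]).min(dim=0).values
# For this decreasing objective, the bound equals -(mean + 6 * std), as expected.
assert torch.allclose(lb, -(mean + six_std))
```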
```diff
@@ -311,6 +338,7 @@ def _prune_inferior_shared_processing(
         samples=samples,
         marginalize_dim=marginalize_dim,
     )
+
     return max_points, obj_vals, infeas
 
 
@@ -374,7 +402,24 @@ def prune_inferior_points(
         sampler=sampler,
         marginalize_dim=marginalize_dim,
     )
-    if infeas.any():
+    if infeas.all():
+        # if no points are feasible, keep the point closest to being feasible
+        with torch.no_grad():
+            posterior = model.posterior(X=X, posterior_transform=posterior_transform)
+            if sampler is None:
+                sampler = get_sampler(
+                    posterior=posterior, sample_shape=torch.Size([num_samples])
+                )
+            samples = sampler(posterior)
+        # use the probability of feasibility as the objective for computing best points
+        obj_vals = compute_smoothed_feasibility_indicator(
+            constraints=constraints,
+            samples=samples,
+            eta=1e-3,
+            log=True,
+        )
+
+    elif infeas.any():
         # set infeasible points to worse than worst objective across all samples
         # Use clone() here to avoid deprecated `index_put_` on an expanded tensor
         obj_vals = obj_vals.clone()
```
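Conceptually, `compute_smoothed_feasibility_indicator` with `log=True` scores each point by a sigmoid relaxation of its constraint violations, so the least-violating point wins even when nothing is feasible. A rough standalone sketch of that idea follows; it is simplified and not BoTorch's exact implementation, which lives in `botorch.utils.objective`:

```python
import torch


def smoothed_log_feasibility(constraint_vals: torch.Tensor, eta: float = 1e-3) -> torch.Tensor:
    """constraint_vals: `... x q x m` tensor where values <= 0 mean feasible.
    Returns a `... x q` log-indicator, summed over the m constraints."""
    # sigmoid(-c / eta) -> 1 when c << 0 (feasible) and -> 0 when c >> 0 (infeasible)
    return torch.nn.functional.logsigmoid(-constraint_vals / eta).sum(dim=-1)


c = torch.tensor([[0.5, 1.0], [0.1, 0.2]])  # two points, two constraints, all infeasible
scores = smoothed_log_feasibility(c, eta=0.1)
# the second point violates the constraints less, so it receives the higher score
assert scores[1] > scores[0]
```

This is why the PR can still rank an all-infeasible candidate set: the smoothed indicator is a continuous stand-in for the hard feasibility check used in the `elif` branch.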

test/acquisition/test_utils.py

Lines changed: 12 additions & 8 deletions
```diff
@@ -32,6 +32,7 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models import SingleTaskGP
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultivariateNormal
```
```diff
@@ -154,14 +155,17 @@ def test_compute_best_feasible_objective(self):
             def objective(Y, X):
                 return Y.squeeze(-1) - 5.0
 
-            best_f = compute_best_feasible_objective(
-                samples=samples,
-                obj=obj,
-                constraints=[lambda X: torch.ones_like(X[..., 0])],
-                model=mm,
-                X_baseline=X,
-                objective=objective,
-            )
+            with self.assertWarnsRegex(
+                BotorchWarning, "ProbabilityOfFeasibility"
+            ):
+                best_f = compute_best_feasible_objective(
+                    samples=samples,
+                    obj=obj,
+                    constraints=[lambda X: torch.ones_like(X[..., 0])],
+                    model=mm,
+                    X_baseline=X,
+                    objective=objective,
+                )
             expected_best_f = torch.full(
                 sample_shape + batch_shape,
                 -get_infeasible_cost(X=X, model=mm, objective=objective).item(),
```
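The updated test wraps the call in `assertWarnsRegex`, which passes when a warning of the given category is raised and the regex is found in its message. A minimal self-contained example of the same pattern (hypothetical test, not from the PR):

```python
import unittest
import warnings


class TestWarnsPattern(unittest.TestCase):
    def test_warns_on_all_infeasible(self):
        # assertWarnsRegex searches the warning message with the given regex
        with self.assertWarnsRegex(UserWarning, "ProbabilityOfFeasibility"):
            warnings.warn(
                "When all training points are infeasible, it is better to use "
                "q(Log)ProbabilityOfFeasibility.",
                UserWarning,
            )


if __name__ == "__main__":
    unittest.main()
```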
