@@ -11,6 +11,7 @@
 from __future__ import annotations

 import math
+import warnings
 from collections.abc import Callable

 import torch
@@ -24,12 +25,16 @@
     DeprecationError,
     UnsupportedError,
 )
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.fully_bayesian import MCMC_DIM
 from botorch.models.model import Model
 from botorch.sampling.base import MCSampler
 from botorch.sampling.get_sampler import get_sampler
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.objective import compute_feasibility_indicator
+from botorch.utils.objective import (
+    compute_feasibility_indicator,
+    compute_smoothed_feasibility_indicator,
+)
 from botorch.utils.sampling import optimize_posterior_samples
 from botorch.utils.transforms import is_ensemble, normalize_indices
 from gpytorch.models import GP
@@ -150,6 +155,13 @@ def compute_best_feasible_objective(
             raise ValueError(
                 "Must specify `X_baseline` when no feasible observation exists."
             )
+        warnings.warn(
+            "When all training points are infeasible, it is better to use "
+            "q(Log)ProbabilityOfFeasibility.",
+            BotorchWarning,
+            stacklevel=2,
+        )
+
         infeasible_value = _estimate_objective_lower_bound(
             model=model,
             objective=objective,
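The new warning uses the `BotorchWarning` category, so callers that deliberately hit the all-infeasible branch (e.g. in tests) can silence it locally with the standard library filter. A minimal sketch; the silenced call is left as a placeholder:

```python
import warnings

from botorch.exceptions.warnings import BotorchWarning

# Silence only this warning category, and only within the block, so that
# other BoTorch diagnostics still surface normally.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=BotorchWarning)
    ...  # e.g. compute_best_feasible_objective(...) with all-infeasible data
```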
@@ -171,8 +183,9 @@ def _estimate_objective_lower_bound(
     posterior_transform: PosteriorTransform | None,
     X: Tensor,
 ) -> Tensor:
-    """Estimates a lower bound on the objective values by evaluating the model at convex
-    combinations of `X`, returning the 6-sigma lower bound of the computed statistics.
+    """Estimates a lower bound on the objective values by evaluating the model at
+    uniformly random points in the bounding box of `X`, returning the 6-sigma lower
+    bound of the computed statistics.

     Args:
         model: A fitted model.
@@ -183,19 +196,21 @@
     Returns:
         A `m`-dimensional Tensor of lower bounds of the objectives.
     """
-    convex_weights = torch.rand(
-        32,
-        X.shape[-2],
-        dtype=X.dtype,
-        device=X.device,
+    # we do not have access to `bounds` here, so we infer the bounding box
+    # from data, expanding by 10% in each direction
+    X_lb = X.min(dim=-2, keepdim=True).values
+    X_ub = X.max(dim=-2, keepdim=True).values
+    X_range = X_ub - X_lb
+    X_padding = 0.1 * X_range
+    uniform_samples = torch.rand(
+        *X.shape[:-2], 32, X.shape[-1], dtype=X.dtype, device=X.device
     )
-    weights_sum = convex_weights.sum(dim=0, keepdim=True)
-    convex_weights = convex_weights / weights_sum
+    X_samples = X_lb - X_padding + uniform_samples * (X_range + 2 * X_padding)
     # infeasible cost M is such that -M < min_x f(x), thus
     # 0 < min_x f(x) - (-M), so we should take -M as a lower
     # bound on the best feasible objective
     return -get_infeasible_cost(
-        X=convex_weights @ X,
+        X=X_samples,
         model=model,
         objective=objective,
         posterior_transform=posterior_transform,
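For illustration, a self-contained sketch of the padded-bounding-box sampling introduced above, using hypothetical shapes (`n = 20` points in `d = 3` dimensions); the arithmetic mirrors the hunk line by line:

```python
import torch

X = torch.rand(20, 3)  # hypothetical training inputs, n x d

# Infer the bounding box from the data and pad by 10% per side.
X_lb = X.min(dim=-2, keepdim=True).values  # 1 x d lower corner
X_ub = X.max(dim=-2, keepdim=True).values  # 1 x d upper corner
X_range = X_ub - X_lb
X_padding = 0.1 * X_range

# 32 points drawn uniformly from the padded box:
# lb - pad + u * (range + 2 * pad) spans [lb - pad, ub + pad].
u = torch.rand(*X.shape[:-2], 32, X.shape[-1], dtype=X.dtype, device=X.device)
X_samples = X_lb - X_padding + u * (X_range + 2 * X_padding)

assert X_samples.shape == torch.Size([32, 3])
assert (X_samples >= X_lb - X_padding).all()
assert (X_samples <= X_ub + X_padding).all()
```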
@@ -235,7 +250,19 @@ def objective(Y: Tensor, X: Tensor | None = None):
             return Y.squeeze(-1)

     posterior = model.posterior(X, posterior_transform=posterior_transform)
-    lb = objective(posterior.mean - 6 * posterior.variance.clamp_min(0).sqrt(), X=X)
+    # We check both the upper and lower bound of the posterior, since the objective
+    # may be increasing or decreasing. For objectives that are neither (e.g. absolute
+    # distance from a target), this should still provide a good bound.
+    six_stdv = 6 * posterior.variance.clamp_min(0).sqrt()
+    lb = torch.stack(
+        [
+            objective(posterior.mean - six_stdv, X=X),
+            objective(posterior.mean + six_stdv, X=X),
+        ],
+        dim=0,
+    )
+    lb = lb.min(dim=0).values
+
     if lb.ndim < posterior.mean.ndim:
         lb = lb.unsqueeze(-1)
     # Take outcome-wise min. Looping in to handle batched models.
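To see why both bounds matter: for a decreasing objective, the smallest objective value is attained at the posterior *upper* bound, so evaluating only `mean - 6 * sigma` (as the removed line did) would return too high a "lower" bound. A toy sketch with made-up posterior statistics:

```python
import torch

# Made-up posterior statistics: 5 points, 1 outcome, stdev 0.5 each.
mean = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0]])
six_stdv = torch.full_like(mean, 6 * 0.5)


def objective(Y, X=None):
    # A decreasing objective: larger Y means smaller objective value.
    return -Y.squeeze(-1)


# The min over both 6-sigma bounds recovers the true lower bound
# -(mean + 6 * sigma); mean - 6 * sigma alone would give -(mean - 6 * sigma).
lb = torch.stack(
    [objective(mean - six_stdv), objective(mean + six_stdv)],
    dim=0,
).min(dim=0).values
print(lb)  # tensor([-3., -4., -5., -6., -7.])
```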
@@ -311,6 +338,7 @@ def _prune_inferior_shared_processing(
         samples=samples,
         marginalize_dim=marginalize_dim,
     )
+
     return max_points, obj_vals, infeas

@@ -374,7 +402,24 @@ def prune_inferior_points(
         sampler=sampler,
         marginalize_dim=marginalize_dim,
     )
-    if infeas.any():
+    if infeas.all():
+        # if no points are feasible, keep the point closest to being feasible
+        with torch.no_grad():
+            posterior = model.posterior(X=X, posterior_transform=posterior_transform)
+            if sampler is None:
+                sampler = get_sampler(
+                    posterior=posterior, sample_shape=torch.Size([num_samples])
+                )
+            samples = sampler(posterior)
+        # use the probability of feasibility as the objective for computing best points
+        obj_vals = compute_smoothed_feasibility_indicator(
+            constraints=constraints,
+            samples=samples,
+            eta=1e-3,
+            log=True,
+        )
+
+    elif infeas.any():
         # set infeasible points to worse than worst objective across all samples
         # Use clone() here to avoid deprecated `index_put_` on an expanded tensor
         obj_vals = obj_vals.clone()
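The fallback above ranks points by their log smoothed probability of feasibility. A small usage sketch of `compute_smoothed_feasibility_indicator` with toy samples and a single constraint (negative values feasible, per the BoTorch constraint convention); the keyword arguments mirror the call in the hunk, and the shapes are assumptions for illustration:

```python
import torch

from botorch.utils.objective import compute_smoothed_feasibility_indicator

# Toy MC samples: 4 posterior samples of 5 candidate points with 2 outcomes.
samples = torch.randn(4, 5, 2)

# A single constraint on the second outcome; values <= 0 are feasible.
constraints = [lambda Y: Y[..., 1]]

# Log of the sigmoid-smoothed feasibility indicator, as used above; a small
# eta makes the soft indicator close to the hard 0/1 indicator.
log_feas = compute_smoothed_feasibility_indicator(
    constraints=constraints,
    samples=samples,
    eta=1e-3,
    log=True,
)
print(log_feas.shape)  # torch.Size([4, 5]): one score per sample and point
```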