Skip to content

Commit bc4b0c6

Browse files
sdaultonfacebook-github-bot
authored andcommitted
return feasible candidate if there is one and return_best_only=True (#2778)
Summary: Pull Request resolved: #2778 as title. If `return_best_only=True`, we only need one candidate that satisfies the parameter constraints. In that case, we shouldn't error out if we did find a feasible point. Reviewed By: Balandat Differential Revision: D71509592 fbshipit-source-id: 4a1dae0db371452a0e05752c6b0128d0b3cd6df6
1 parent 722baa3 commit bc4b0c6

File tree

4 files changed

+118
-60
lines changed

4 files changed

+118
-60
lines changed

botorch/generation/gen.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,12 @@
2020
import numpy.typing as npt
2121
import torch
2222
from botorch.acquisition import AcquisitionFunction
23-
from botorch.exceptions.errors import (
24-
CandidateGenerationError,
25-
OptimizationGradientError,
26-
)
23+
from botorch.exceptions.errors import OptimizationGradientError
2724
from botorch.exceptions.warnings import OptimizationWarning
2825
from botorch.generation.utils import _remove_fixed_features_from_optimization
2926
from botorch.logging import logger
3027
from botorch.optim.parameter_constraints import (
3128
_arrayify,
32-
evaluate_feasibility,
3329
make_scipy_bounds,
3430
make_scipy_linear_constraints,
3531
make_scipy_nonlinear_inequality_constraints,
@@ -264,23 +260,6 @@ def f(x):
264260
fixed_features=fixed_features,
265261
)
266262

267-
# SLSQP can sometimes fail to produce a feasible candidate. Check for
268-
# feasibility and error out if necessary.
269-
if not (
270-
is_feasible := evaluate_feasibility(
271-
X=candidates,
272-
inequality_constraints=inequality_constraints,
273-
equality_constraints=equality_constraints,
274-
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
275-
)
276-
).all():
277-
raise CandidateGenerationError(
278-
f"The {method} optimizer produced infeasible candidates. "
279-
f"{(~is_feasible).sum().item()} out of {is_feasible.numel()} batches "
280-
"of candidates were infeasible. Please make sure the constraints are "
281-
"satisfiable and relax them if needed. "
282-
)
283-
284263
clamped_candidates = columnwise_clamp(
285264
X=candidates, lower=lower_bounds, upper=upper_bounds, raise_on_violation=True
286265
)

botorch/optim/optimize.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
gen_one_shot_kg_initial_conditions,
3636
TGenInitialConditions,
3737
)
38+
from botorch.optim.parameter_constraints import evaluate_feasibility
3839
from botorch.optim.stopping import ExpMAStoppingCriterion
3940
from torch import Tensor
4041

@@ -354,6 +355,15 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> tuple[Tensor, Tensor
354355
),
355356
)
356357

358+
gen_kwargs = {}
359+
for constraint_name in [
360+
"inequality_constraints",
361+
"equality_constraints",
362+
"nonlinear_inequality_constraints",
363+
]:
364+
if (constraint := getattr(opt_inputs, constraint_name)) is not None:
365+
gen_kwargs[constraint_name] = constraint
366+
357367
def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
358368
batch_candidates_list: list[Tensor] = []
359369
batch_acq_values_list: list[Tensor] = []
@@ -370,15 +380,6 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
370380
upper_bounds = None if bounds[1].isinf().all() else bounds[1]
371381
gen_options = {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}
372382

373-
gen_kwargs = {}
374-
for constraint_name in [
375-
"inequality_constraints",
376-
"equality_constraints",
377-
"nonlinear_inequality_constraints",
378-
]:
379-
if (constraint := getattr(opt_inputs, constraint_name)) is not None:
380-
gen_kwargs[constraint_name] = constraint
381-
382383
for i, batched_ics_ in enumerate(batched_ics):
383384
# optimize using random restart optimization
384385
with warnings.catch_warnings(record=True) as ws:
@@ -471,7 +472,27 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
471472
]
472473
batch_acq_values = torch.cat(acq_values_list, dim=0)
473474

475+
# SLSQP can sometimes fail to produce a feasible candidate. Check for
476+
# feasibility and error out if necessary.
477+
is_feasible = evaluate_feasibility(
478+
X=batch_candidates,
479+
inequality_constraints=gen_kwargs.get("inequality_constraints"),
480+
equality_constraints=gen_kwargs.get("equality_constraints"),
481+
nonlinear_inequality_constraints=gen_kwargs.get(
482+
"nonlinear_inequality_constraints"
483+
),
484+
)
485+
infeasible = ~is_feasible
486+
if (opt_inputs.return_best_only and (not is_feasible.any())) or infeasible.all():
487+
raise CandidateGenerationError(
488+
f"The optimizer produced infeasible candidates. "
489+
f"{(~is_feasible).sum().item()} out of {is_feasible.numel()} batches "
490+
"of candidates were infeasible. Please make sure the constraints are "
491+
"satisfiable and relax them if needed. "
492+
)
474493
if opt_inputs.return_best_only:
494+
# filter for feasible candidates
495+
batch_acq_values[infeasible] = -float("inf")
475496
best = torch.argmax(batch_acq_values.view(-1), dim=0)
476497
batch_candidates = batch_candidates[best]
477498
batch_acq_values = batch_acq_values[best]

test/generation/test_gen.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111

1212
import torch
1313
from botorch.acquisition import qExpectedImprovement, qKnowledgeGradient
14-
from botorch.exceptions.errors import (
15-
CandidateGenerationError,
16-
OptimizationGradientError,
17-
)
14+
from botorch.exceptions.errors import OptimizationGradientError
1815
from botorch.exceptions.warnings import OptimizationWarning
1916
from botorch.fit import fit_gpytorch_mll
2017
from botorch.generation.gen import (
@@ -389,27 +386,6 @@ def test_gen_candidates_scipy_invalid_method(self) -> None:
389386
upper_bounds=1,
390387
)
391388

392-
def test_gen_candidates_scipy_infeasible_candidates(self) -> None:
393-
# Check for error when infeasible candidates are generated.
394-
ics = torch.rand(2, 3, 1, device=self.device)
395-
with mock.patch(
396-
"botorch.generation.gen.minimize_with_timeout",
397-
return_value=OptimizeResult(x=ics.view(-1).cpu().numpy()),
398-
), self.assertRaisesRegex(
399-
CandidateGenerationError, "infeasible candidates. 2 out of 2"
400-
):
401-
gen_candidates_scipy(
402-
initial_conditions=ics,
403-
acquisition_function=MockAcquisitionFunction(),
404-
inequality_constraints=[
405-
( # X[..., 0] >= 2.0, which is infeasible.
406-
torch.tensor([0], device=self.device),
407-
torch.tensor([1.0], device=self.device),
408-
2.0,
409-
)
410-
],
411-
)
412-
413389

414390
class TestRandomRestartOptimization(TestBaseCandidateGeneration):
415391
def test_random_restart_optimization(self):

test/optim/test_optimize.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def test_optimize_acqf_sequential(
341341
gcs_return_vals = [
342342
(
343343
torch.tensor(
344-
[[[1.1, 2.1, 3.1]]], device=self.device, dtype=dtype
344+
[[[1.1, 2.1, 4.0]]], device=self.device, dtype=dtype
345345
),
346346
torch.tensor([i], device=self.device, dtype=dtype),
347347
)
@@ -357,11 +357,20 @@ def test_optimize_acqf_sequential(
357357
if mock_gen_candidates is mock_gen_candidates_scipy:
358358
# x[2] * 4 >= 5
359359
inequality_constraints = [
360-
(torch.tensor([2]), torch.tensor([4]), torch.tensor(5))
360+
(
361+
torch.tensor([2], dtype=torch.long, device=self.device),
362+
torch.tensor([4], device=self.device),
363+
torch.tensor(5, device=self.device),
364+
)
361365
]
362366
equality_constraints = [
363-
(torch.tensor([0, 1]), torch.ones(2), torch.tensor(4.0))
367+
(
368+
torch.tensor([0, 1], dtype=torch.long, device=self.device),
369+
torch.ones(2, device=self.device),
370+
torch.tensor(16.0, device=self.device),
371+
)
364372
]
373+
equality_constraints = None
365374
# gen_candidates_torch does not support constraints
366375
else:
367376
inequality_constraints = None
@@ -1183,6 +1192,77 @@ def __call__(self, x, f):
11831192
self.assertEqual(f_obj(x2), 2.0)
11841193
self.assertEqual(f_np_wrapper.call_count, 2)
11851194

1195+
def _test_optimize_acqf_infeasible_candidates(
1196+
self, mock_gen_batch_initial_conditions, q, num_restarts, ics
1197+
):
1198+
mock_acq_function = MockAcquisitionFunction()
1199+
mock_gen_batch_initial_conditions.side_effect = [ics for _ in range(2)]
1200+
with mock.patch(
1201+
"botorch.generation.gen.minimize_with_timeout",
1202+
return_value=OptimizeResult(x=ics.view(-1).cpu().numpy()),
1203+
):
1204+
candidates, _ = optimize_acqf(
1205+
acq_function=mock_acq_function,
1206+
bounds=torch.tensor(
1207+
[[0.0], [3.0]], dtype=ics.dtype, device=self.device
1208+
),
1209+
q=q,
1210+
num_restarts=num_restarts,
1211+
raw_samples=1,
1212+
inequality_constraints=[
1213+
( # X[..., 0] >= 2.0, which is infeasible.
1214+
torch.tensor([0], device=self.device),
1215+
torch.tensor([1.0], dtype=ics.dtype, device=self.device),
1216+
2.0,
1217+
)
1218+
],
1219+
sequential=False,
1220+
gen_candidates=gen_candidates_scipy,
1221+
initial_conditions=ics,
1222+
acquisition_function=MockAcquisitionFunction(),
1223+
)
1224+
return candidates
1225+
1226+
@mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
1227+
def test_optimize_acqf_all_infeasible_candidates(
1228+
self, mock_gen_batch_initial_conditions
1229+
) -> None:
1230+
# Check for error when all batches of candidates are infeasible w.r.t
1231+
# parameter constraints.
1232+
q = 3
1233+
num_restarts = 2
1234+
for dtype in (torch.float, torch.double):
1235+
with self.assertRaisesRegex(
1236+
CandidateGenerationError, "infeasible candidates. 2 out of 2"
1237+
):
1238+
self._test_optimize_acqf_infeasible_candidates(
1239+
mock_gen_batch_initial_conditions=mock_gen_batch_initial_conditions,
1240+
q=q,
1241+
num_restarts=num_restarts,
1242+
ics=torch.rand(num_restarts, q, 1, dtype=dtype, device=self.device),
1243+
)
1244+
1245+
@mock.patch("botorch.optim.optimize.gen_batch_initial_conditions")
1246+
def test_optimize_acqf_some_infeasible_candidates(
1247+
self, mock_gen_batch_initial_conditions
1248+
) -> None:
1249+
# Check that no error is raised when the first batch of candidates
1250+
# contains points that are infeasible w.r.t parameter constraints, but
1251+
# second batch is feasible
1252+
q = 3
1253+
num_restarts = 2
1254+
1255+
for dtype in (torch.float, torch.double):
1256+
ics = torch.rand(num_restarts, q, 1, dtype=dtype, device=self.device)
1257+
ics[1, ..., 0] = 3.0 # make second batch feasible
1258+
candidates = self._test_optimize_acqf_infeasible_candidates(
1259+
mock_gen_batch_initial_conditions=mock_gen_batch_initial_conditions,
1260+
q=q,
1261+
num_restarts=num_restarts,
1262+
ics=ics,
1263+
)
1264+
self.assertTrue(torch.equal(candidates, ics[1]))
1265+
11861266

11871267
class TestAllOptimizers(BotorchTestCase):
11881268
def test_raises_with_negative_fixed_features(self) -> None:
@@ -1708,7 +1788,9 @@ def test_optimize_acqf_one_shot_large_q(self):
17081788
def test_optimize_acqf_mixed_ff_with_constraint(self):
17091789
mock_acq_function = MockAcquisitionFunction()
17101790
bounds = torch.stack([torch.zeros(3), 4 * torch.ones(3)])
1711-
ineq_constraints = [(torch.zeros(1), torch.ones(1), 1)] # x[0] >= 1
1791+
ineq_constraints = [
1792+
(torch.zeros(1, dtype=torch.long), torch.ones(1), 1)
1793+
] # x[0] >= 1
17121794
with self.assertWarnsRegex(
17131795
OptimizationWarning,
17141796
"Candidate generation failed for 1 combinations of `fixed_features`. "

0 commit comments

Comments
 (0)