Skip to content

Commit ccf278a

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Support inequality constraints & X_avoid support in optimize_acqf_discrete (#2593)
Summary: Pull Request resolved: #2593 These were previously handled in `Acquisition.optimize` in Ax. Pushing it down to BoTorch makes the functionality more broadly useful and helps simplify `Acquisition.optimize` (next diff). Reviewed By: Balandat Differential Revision: D64841997 fbshipit-source-id: 52ff836da1218216d6a378910be12e154206e6dc
1 parent 563cd95 commit ccf278a

File tree

2 files changed

+77
-8
lines changed

2 files changed

+77
-8
lines changed

botorch/optim/optimize.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,8 @@ def optimize_acqf_discrete(
10001000
choices: Tensor,
10011001
max_batch_size: int = 2048,
10021002
unique: bool = True,
1003+
X_avoid: Tensor | None = None,
1004+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
10031005
) -> tuple[Tensor, Tensor]:
10041006
r"""Optimize over a discrete set of points using batch evaluation.
10051007
@@ -1017,6 +1019,12 @@ def optimize_acqf_discrete(
10171019
a large training set.
10181020
unique: If True return unique choices, o/w choices may be repeated
10191021
(only relevant if `q > 1`).
1022+
X_avoid: An `n x d` tensor of candidates that we aren't allowed to pick.
1023+
These will be removed from the set of choices.
1024+
inequality constraints: A list of tuples (indices, coefficients, rhs),
1025+
with each tuple encoding an inequality constraint of the form
1026+
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
1027+
Infeasible points will be removed from the set of choices.
10201028
10211029
Returns:
10221030
A two-element tuple containing
@@ -1029,8 +1037,31 @@ def optimize_acqf_discrete(
10291037
"Discrete optimization is not supported for"
10301038
"one-shot acquisition functions."
10311039
)
1032-
if choices.numel() == 0:
1033-
raise InputDataError("`choices` must be non-emtpy.")
1040+
if X_avoid is not None and unique:
1041+
choices = _filter_invalid(X=choices, X_avoid=X_avoid)
1042+
if inequality_constraints is not None:
1043+
choices = _filter_infeasible(
1044+
X=choices, inequality_constraints=inequality_constraints
1045+
)
1046+
len_choices = len(choices)
1047+
if len_choices == 0:
1048+
message = "`choices` must be non-empty."
1049+
if X_avoid is not None or inequality_constraints is not None:
1050+
message += (
1051+
" No feasible points remain after removing `X_avoid` and "
1052+
"filtering out infeasible points."
1053+
)
1054+
raise InputDataError(message)
1055+
elif len_choices < q and unique:
1056+
warnings.warn(
1057+
(
1058+
f"Requested {q=} candidates from fully discrete search "
1059+
f"space, but only {len_choices} possible choices remain. "
1060+
),
1061+
OptimizationWarning,
1062+
stacklevel=2,
1063+
)
1064+
q = len_choices
10341065
choices_batched = choices.unsqueeze(-2)
10351066
if q > 1:
10361067
candidate_list, acq_value_list = [], []
@@ -1081,7 +1112,7 @@ def _generate_neighbors(
10811112
discrete_choices: list[Tensor],
10821113
X_avoid: Tensor,
10831114
inequality_constraints: list[tuple[Tensor, Tensor, float]],
1084-
):
1115+
) -> Tensor:
10851116
# generate all 1D perturbations
10861117
npts = sum([len(c) for c in discrete_choices])
10871118
X_loc = x.repeat(npts, 1)
@@ -1097,15 +1128,15 @@ def _generate_neighbors(
10971128

10981129
def _filter_infeasible(
10991130
X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]]
1100-
):
1131+
) -> Tensor:
11011132
"""Remove all points from `X` that don't satisfy the constraints."""
11021133
is_feasible = torch.ones(X.shape[0], dtype=torch.bool, device=X.device)
11031134
for inds, weights, bound in inequality_constraints:
11041135
is_feasible &= (X[..., inds] * weights).sum(dim=-1) >= bound
11051136
return X[is_feasible]
11061137

11071138

1108-
def _filter_invalid(X: Tensor, X_avoid: Tensor):
1139+
def _filter_invalid(X: Tensor, X_avoid: Tensor) -> Tensor:
11091140
"""Remove all occurences of `X_avoid` from `X`."""
11101141
return X[~(X == X_avoid.unsqueeze(-2)).all(dim=-1).any(dim=-2)]
11111142

test/optim/test_optimize.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
qHypervolumeKnowledgeGradient,
2323
)
2424
from botorch.exceptions import InputDataError, UnsupportedError
25+
from botorch.exceptions.warnings import OptimizationWarning
2526
from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch
2627
from botorch.models import SingleTaskGP
2728
from botorch.models.model_list_gp_regression import ModelListGP
@@ -1556,7 +1557,7 @@ def test_optimize_acqf_discrete(self):
15561557
mock_acq_function = SquaredAcquisitionFunction()
15571558
mock_acq_function.set_X_pending(None)
15581559
# ensure proper raising of errors if no choices
1559-
with self.assertRaisesRegex(InputDataError, "`choices` must be non-emtpy."):
1560+
with self.assertRaisesRegex(InputDataError, "`choices` must be non-empty."):
15601561
optimize_acqf_discrete(
15611562
acq_function=mock_acq_function,
15621563
q=q,
@@ -1613,14 +1614,51 @@ def test_optimize_acqf_discrete(self):
16131614
self.assertAllClose(acq_value, expected_acq_value)
16141615
self.assertAllClose(candidates, expected_candidates)
16151616

1616-
with self.assertRaises(UnsupportedError):
1617-
acqf = MockOneShotAcquisitionFunction()
1617+
acqf = MockOneShotAcquisitionFunction()
1618+
with self.assertRaisesRegex(UnsupportedError, "one-shot acquisition"):
16181619
optimize_acqf_discrete(
16191620
acq_function=acqf,
16201621
q=1,
16211622
choices=torch.tensor([[0.5], [0.2]]),
16221623
)
16231624

1625+
def test_optimize_acqf_discrete_X_avoid_and_constraints(self):
1626+
# Check that choices are filtered correctly using X_avoid and constraints.
1627+
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
1628+
mock_acq_function = SquaredAcquisitionFunction()
1629+
choices = torch.rand(2, 2, **tkwargs)
1630+
with self.assertRaisesRegex(InputDataError, "No feasible points"):
1631+
optimize_acqf_discrete(
1632+
acq_function=mock_acq_function,
1633+
q=1,
1634+
choices=choices,
1635+
X_avoid=choices,
1636+
)
1637+
with self.assertWarnsRegex(OptimizationWarning, "Requested q=2 candidates"):
1638+
candidates, _ = optimize_acqf_discrete(
1639+
acq_function=mock_acq_function,
1640+
q=2,
1641+
choices=choices,
1642+
X_avoid=choices[:1],
1643+
)
1644+
self.assertAllClose(candidates, choices[1:])
1645+
constraints = [
1646+
( # X[..., 0] >= 1.0
1647+
torch.tensor([0], dtype=torch.long, device=self.device),
1648+
torch.tensor([1.0], **tkwargs),
1649+
1.0,
1650+
)
1651+
]
1652+
choices[0, 0] = 1.0
1653+
with self.assertWarnsRegex(OptimizationWarning, "Requested q=2 candidates"):
1654+
candidates, _ = optimize_acqf_discrete(
1655+
acq_function=mock_acq_function,
1656+
q=2,
1657+
choices=choices,
1658+
inequality_constraints=constraints,
1659+
)
1660+
self.assertAllClose(candidates, choices[:1])
1661+
16241662
def test_optimize_acqf_discrete_local_search(self):
16251663
for q, dtype in itertools.product((1, 2), (torch.float, torch.double)):
16261664
tkwargs = {"device": self.device, "dtype": dtype}

0 commit comments

Comments
 (0)