Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,8 @@ def optimize_acqf_discrete(
choices: Tensor,
max_batch_size: int = 2048,
unique: bool = True,
X_avoid: Tensor | None = None,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
) -> tuple[Tensor, Tensor]:
r"""Optimize over a discrete set of points using batch evaluation.

Expand All @@ -1017,6 +1019,12 @@ def optimize_acqf_discrete(
a large training set.
unique: If True return unique choices, o/w choices may be repeated
(only relevant if `q > 1`).
X_avoid: An `n x d` tensor of candidates that we aren't allowed to pick.
These will be removed from the set of choices.
inequality constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
Infeasible points will be removed from the set of choices.

Returns:
A two-element tuple containing
Expand All @@ -1029,8 +1037,31 @@ def optimize_acqf_discrete(
"Discrete optimization is not supported for"
"one-shot acquisition functions."
)
if choices.numel() == 0:
raise InputDataError("`choices` must be non-emtpy.")
if X_avoid is not None and unique:
choices = _filter_invalid(X=choices, X_avoid=X_avoid)
if inequality_constraints is not None:
choices = _filter_infeasible(
X=choices, inequality_constraints=inequality_constraints
)
len_choices = len(choices)
if len_choices == 0:
message = "`choices` must be non-empty."
if X_avoid is not None or inequality_constraints is not None:
message += (
" No feasible points remain after removing `X_avoid` and "
"filtering out infeasible points."
)
raise InputDataError(message)
elif len_choices < q and unique:
warnings.warn(
(
f"Requested {q=} candidates from fully discrete search "
f"space, but only {len_choices} possible choices remain. "
),
OptimizationWarning,
stacklevel=2,
)
q = len_choices
choices_batched = choices.unsqueeze(-2)
if q > 1:
candidate_list, acq_value_list = [], []
Expand Down Expand Up @@ -1081,7 +1112,7 @@ def _generate_neighbors(
discrete_choices: list[Tensor],
X_avoid: Tensor,
inequality_constraints: list[tuple[Tensor, Tensor, float]],
):
) -> Tensor:
# generate all 1D perturbations
npts = sum([len(c) for c in discrete_choices])
X_loc = x.repeat(npts, 1)
Expand All @@ -1097,15 +1128,15 @@ def _generate_neighbors(

def _filter_infeasible(
X: Tensor, inequality_constraints: list[tuple[Tensor, Tensor, float]]
):
) -> Tensor:
"""Remove all points from `X` that don't satisfy the constraints."""
is_feasible = torch.ones(X.shape[0], dtype=torch.bool, device=X.device)
for inds, weights, bound in inequality_constraints:
is_feasible &= (X[..., inds] * weights).sum(dim=-1) >= bound
return X[is_feasible]


def _filter_invalid(X: Tensor, X_avoid: Tensor):
def _filter_invalid(X: Tensor, X_avoid: Tensor) -> Tensor:
"""Remove all occurences of `X_avoid` from `X`."""
return X[~(X == X_avoid.unsqueeze(-2)).all(dim=-1).any(dim=-2)]

Expand Down
44 changes: 41 additions & 3 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
qHypervolumeKnowledgeGradient,
)
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.generation.gen import gen_candidates_scipy, gen_candidates_torch
from botorch.models import SingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP
Expand Down Expand Up @@ -1556,7 +1557,7 @@ def test_optimize_acqf_discrete(self):
mock_acq_function = SquaredAcquisitionFunction()
mock_acq_function.set_X_pending(None)
# ensure proper raising of errors if no choices
with self.assertRaisesRegex(InputDataError, "`choices` must be non-emtpy."):
with self.assertRaisesRegex(InputDataError, "`choices` must be non-empty."):
optimize_acqf_discrete(
acq_function=mock_acq_function,
q=q,
Expand Down Expand Up @@ -1613,14 +1614,51 @@ def test_optimize_acqf_discrete(self):
self.assertAllClose(acq_value, expected_acq_value)
self.assertAllClose(candidates, expected_candidates)

with self.assertRaises(UnsupportedError):
acqf = MockOneShotAcquisitionFunction()
acqf = MockOneShotAcquisitionFunction()
with self.assertRaisesRegex(UnsupportedError, "one-shot acquisition"):
optimize_acqf_discrete(
acq_function=acqf,
q=1,
choices=torch.tensor([[0.5], [0.2]]),
)

def test_optimize_acqf_discrete_X_avoid_and_constraints(self):
# Check that choices are filtered correctly using X_avoid and constraints.
tkwargs: dict[str, Any] = {"device": self.device, "dtype": torch.double}
mock_acq_function = SquaredAcquisitionFunction()
choices = torch.rand(2, 2, **tkwargs)
with self.assertRaisesRegex(InputDataError, "No feasible points"):
optimize_acqf_discrete(
acq_function=mock_acq_function,
q=1,
choices=choices,
X_avoid=choices,
)
with self.assertWarnsRegex(OptimizationWarning, "Requested q=2 candidates"):
candidates, _ = optimize_acqf_discrete(
acq_function=mock_acq_function,
q=2,
choices=choices,
X_avoid=choices[:1],
)
self.assertAllClose(candidates, choices[1:])
constraints = [
( # X[..., 0] >= 1.0
torch.tensor([0], dtype=torch.long, device=self.device),
torch.tensor([1.0], **tkwargs),
1.0,
)
]
choices[0, 0] = 1.0
with self.assertWarnsRegex(OptimizationWarning, "Requested q=2 candidates"):
candidates, _ = optimize_acqf_discrete(
acq_function=mock_acq_function,
q=2,
choices=choices,
inequality_constraints=constraints,
)
self.assertAllClose(candidates, choices[:1])

def test_optimize_acqf_discrete_local_search(self):
for q, dtype in itertools.product((1, 2), (torch.float, torch.double)):
tkwargs = {"device": self.device, "dtype": dtype}
Expand Down
Loading