Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
gen_one_shot_kg_initial_conditions,
TGenInitialConditions,
)
from botorch.optim.parameter_constraints import evaluate_feasibility
from botorch.optim.parameter_constraints import (
evaluate_feasibility,
project_to_feasible_space_via_slsqp,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from torch import Tensor

Expand Down Expand Up @@ -513,15 +516,47 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:

# SLSQP can sometimes fail to produce a feasible candidate. Check for
# feasibility and error out if necessary.
# if there are equality constraints, project the candidate to the feasible set
equality_constraints = gen_kwargs.get("equality_constraints")
inequality_constraints = gen_kwargs.get("inequality_constraints")
nonlinear_inequality_constraints = gen_kwargs.get(
"nonlinear_inequality_constraints"
)
is_feasible = evaluate_feasibility(
X=batch_candidates,
inequality_constraints=gen_kwargs.get("inequality_constraints"),
equality_constraints=gen_kwargs.get("equality_constraints"),
nonlinear_inequality_constraints=gen_kwargs.get(
"nonlinear_inequality_constraints"
),
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
)
infeasible = ~is_feasible
if nonlinear_inequality_constraints is None and infeasible.any():
projected_candidates = project_to_feasible_space_via_slsqp(
X=batch_candidates[infeasible],
bounds=opt_inputs.bounds,
equality_constraints=equality_constraints,
inequality_constraints=inequality_constraints,
)
if opt_inputs.post_processing_func is not None:
projected_candidates = opt_inputs.post_processing_func(projected_candidates)
batch_candidates[infeasible] = projected_candidates
# recompute AF values for projected points
with torch.no_grad():
batch_acq_values[infeasible] = torch.cat(
[
opt_inputs.acq_function(cand)
for cand in projected_candidates.split(batch_limit, dim=0)
],
dim=0,
)
# re-evaluate feasibility
is_feasible = evaluate_feasibility(
X=batch_candidates,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
)
infeasible = ~is_feasible

if (opt_inputs.return_best_only and (not is_feasible.any())) or infeasible.all():
raise CandidateGenerationError(
f"The optimizer produced infeasible candidates. "
Expand Down
106 changes: 102 additions & 4 deletions botorch/optim/parameter_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,32 @@
import numpy.typing as npt
import torch
from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError
from scipy.optimize import Bounds
from botorch.optim.utils import columnwise_clamp
from scipy.optimize import Bounds, minimize
from torch import Tensor


ScipyConstraintDict = dict[
str, Union[str, Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]
]
CONST_TOL = 1e-6


def get_constraint_tolerance(dtype: torch.dtype) -> float:
r"""Get the constraint tolerance for a given dtype.

Args:
dtype: The dtype to use.

Returns:
The constraint tolerance for the given dtype.
"""
if dtype == torch.double:
return 1e-8
elif dtype == torch.float:
return 1e-6
elif dtype == torch.half:
return 1e-4
raise ValueError(f"Unsupported dtype {dtype}.")


def make_scipy_bounds(
Expand Down Expand Up @@ -513,7 +531,7 @@ def nonlinear_constraint_is_feasible(
nonlinear_inequality_constraint: Callable,
is_intrapoint: bool,
x: Tensor,
tolerance: float = CONST_TOL,
tolerance: float | None = None,
) -> Tensor:
"""Checks if a nonlinear inequality constraint is fulfilled (within tolerance).

Expand All @@ -533,6 +551,8 @@ def nonlinear_constraint_is_feasible(
A boolean tensor of shape (batch) indicating if the constraint is
satified by the corresponding batch of `x`.
"""
if tolerance is None:
tolerance = get_constraint_tolerance(dtype=x.dtype)

def check_x(x: Tensor) -> bool:
return _arrayify(nonlinear_inequality_constraint(x)).item() >= -tolerance
Expand Down Expand Up @@ -615,7 +635,7 @@ def evaluate_feasibility(
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
tolerance: float = CONST_TOL,
tolerance: float | None = None,
) -> Tensor:
r"""Evaluate feasibility of candidate points (within a tolerance).

Expand Down Expand Up @@ -657,6 +677,9 @@ def evaluate_feasibility(
A boolean tensor of shape `batch` indicating if the corresponding candidate of
shape `q x d` is feasible.
"""
if tolerance is None:
tolerance = get_constraint_tolerance(dtype=X.dtype)

is_feasible = torch.ones(X.shape[:-2], device=X.device, dtype=torch.bool)
if inequality_constraints is not None:
for idx, coef, rhs in inequality_constraints:
Expand Down Expand Up @@ -691,3 +714,78 @@ def evaluate_feasibility(
tolerance=tolerance,
)
return is_feasible


def project_to_feasible_space_via_slsqp(
X: Tensor,
bounds: Tensor,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
) -> Tensor:
"""Project X onto the feasible space by solving a quadratic program.

This uses SLSQP with gradients to solve the quadratic program.
NOTE: A proper specialized QP solver would be a better choice here,
but we'd like to avoid adding dependency on additional packages.
SLSQP should be able to solve this reliably and quickly since the
dimension is typically low and the number of constraints is typically
limited.

Args:
X: A `(batch_shape x) n x d`-dim tensor of inptus.
bounds: A `2 x d`-dim tensor of lower and upper bounds.
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
`coefficients` should be torch tensors. See the docstring of
`make_scipy_linear_constraints` for an example.
equality_constraints: A list of tuples (indices, coefficients, rhs).

Returns:
A `(batch_shape x) n x d`-dim tensor of projected values.
"""
if inequality_constraints is None and equality_constraints is None:
return X
bounds_scipy = make_scipy_bounds(
X=X, lower_bounds=bounds[0], upper_bounds=bounds[1]
)
constraints = make_scipy_linear_constraints(
shapeX=X.shape,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)
# Define squared distance objective
X_np = X.flatten().detach().cpu().numpy()

def objective(x: np.ndarray):
return 0.5 * np.sum((x - X_np) ** 2)

def grad_objective(x: np.ndarray):
return x - X_np

x0 = (
columnwise_clamp(X=X, lower=bounds[0], upper=bounds[1], raise_on_violation=True)
.detach()
.cpu()
.numpy()
.flatten()
)
# NOTE: A proper specialized QP solver would be a better choice here,
# but we'd like to avoid adding dependency on additional packages.
# SLSQP should be able to solve this reliably and quickly since the
# dimension is typically low and the number of constraints is typically
# limited.
result = minimize(
fun=objective,
x0=x0,
method="SLSQP",
jac=grad_objective,
bounds=bounds_scipy,
constraints=constraints,
tol=get_constraint_tolerance(dtype=X.dtype),
)

if not result.success:
raise RuntimeError(f"Optimization failed: {result.message}")

return torch.from_numpy(result.x).to(X).view(X.shape)
Loading
Loading