Skip to content

Commit 544973a

Browse files
sdaultonfacebook-github-bot
authored andcommitted
project point to feasible space via quadratic programming
Summary: see title. This is particularly useful for resolving numerical issues with Ax when it checks parameter constraints. Differential Revision: D82328877
1 parent eba2dce commit 544973a

File tree

4 files changed

+610
-10
lines changed

4 files changed

+610
-10
lines changed

botorch/optim/optimize.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535
gen_one_shot_kg_initial_conditions,
3636
TGenInitialConditions,
3737
)
38-
from botorch.optim.parameter_constraints import evaluate_feasibility
38+
from botorch.optim.parameter_constraints import (
39+
evaluate_feasibility,
40+
project_to_feasible_space_via_slsqp,
41+
)
3942
from botorch.optim.stopping import ExpMAStoppingCriterion
4043
from torch import Tensor
4144

@@ -513,15 +516,47 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]:
513516

514517
# SLSQP can sometimes fail to produce a feasible candidate. Check for
515518
# feasibility and error out if necessary.
519+
# if there are equality constraints, project the candidate to the feasible set
520+
equality_constraints = gen_kwargs.get("equality_constraints")
521+
inequality_constraints = gen_kwargs.get("inequality_constraints")
522+
nonlinear_inequality_constraints = gen_kwargs.get(
523+
"nonlinear_inequality_constraints"
524+
)
516525
is_feasible = evaluate_feasibility(
517526
X=batch_candidates,
518-
inequality_constraints=gen_kwargs.get("inequality_constraints"),
519-
equality_constraints=gen_kwargs.get("equality_constraints"),
520-
nonlinear_inequality_constraints=gen_kwargs.get(
521-
"nonlinear_inequality_constraints"
522-
),
527+
inequality_constraints=inequality_constraints,
528+
equality_constraints=equality_constraints,
529+
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
523530
)
524531
infeasible = ~is_feasible
532+
if nonlinear_inequality_constraints is None and infeasible.any():
533+
projected_candidates = project_to_feasible_space_via_slsqp(
534+
X=batch_candidates[infeasible],
535+
bounds=opt_inputs.bounds,
536+
equality_constraints=equality_constraints,
537+
inequality_constraints=inequality_constraints,
538+
)
539+
if opt_inputs.post_processing_func is not None:
540+
projected_candidates = opt_inputs.post_processing_func(projected_candidates)
541+
batch_candidates[infeasible] = projected_candidates
542+
# recompute AF values for projected points
543+
with torch.no_grad():
544+
batch_acq_values[infeasible] = torch.cat(
545+
[
546+
opt_inputs.acq_function(cand)
547+
for cand in projected_candidates.split(batch_limit, dim=0)
548+
],
549+
dim=0,
550+
)
551+
# re-evaluate feasibility
552+
is_feasible = evaluate_feasibility(
553+
X=batch_candidates,
554+
inequality_constraints=inequality_constraints,
555+
equality_constraints=equality_constraints,
556+
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
557+
)
558+
infeasible = ~is_feasible
559+
525560
if (opt_inputs.return_best_only and (not is_feasible.any())) or infeasible.all():
526561
raise CandidateGenerationError(
527562
f"The optimizer produced infeasible candidates. "

botorch/optim/parameter_constraints.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
import numpy.typing as npt
1919
import torch
2020
from botorch.exceptions.errors import CandidateGenerationError, UnsupportedError
21-
from scipy.optimize import Bounds
21+
from botorch.optim.utils import columnwise_clamp
22+
from scipy.optimize import Bounds, minimize
2223
from torch import Tensor
2324

2425

2526
ScipyConstraintDict = dict[
2627
str, Union[str, Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]]
2728
]
28-
CONST_TOL = 1e-6
29+
CONST_TOL = 1e-8
2930

3031

3132
def make_scipy_bounds(
@@ -691,3 +692,69 @@ def evaluate_feasibility(
691692
tolerance=tolerance,
692693
)
693694
return is_feasible
695+
696+
697+
def project_to_feasible_space_via_slsqp(
698+
X: Tensor,
699+
bounds: Tensor,
700+
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
701+
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
702+
) -> Tensor:
703+
"""Project X onto the feasible space by solving a quadratic program.
704+
705+
This uses SLSQP with gradients to solve the quadratic program.
706+
707+
Args:
708+
X: A `(batch_shape x) n x d`-dim tensor of inptus.
709+
bounds: A `2 x d`-dim tensor of lower and upper bounds.
710+
inequality_constraints: A list of tuples (indices, coefficients, rhs),
711+
with each tuple encoding an inequality constraint of the form
712+
`sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
713+
`coefficients` should be torch tensors. See the docstring of
714+
`make_scipy_linear_constraints` for an example. Only intra-point
715+
constraints are supported and `indices` should be a 1-d tensor.
716+
equality_constraints: A list of tuples (indices, coefficients, rhs).
717+
718+
Returns:
719+
A `(batch_shape x) n x d`-dim tensor of projected values.
720+
"""
721+
if inequality_constraints is None and equality_constraints is None:
722+
return X
723+
bounds_scipy = make_scipy_bounds(
724+
X=X, lower_bounds=bounds[0], upper_bounds=bounds[1]
725+
)
726+
constraints = make_scipy_linear_constraints(
727+
shapeX=X.shape,
728+
inequality_constraints=inequality_constraints,
729+
equality_constraints=equality_constraints,
730+
)
731+
# Define squared distance objective
732+
X_np = X.flatten().detach().cpu().numpy()
733+
734+
def objective(x: np.ndarray):
735+
return 0.5 * np.sum((x - X_np) ** 2)
736+
737+
def grad_objective(x: np.ndarray):
738+
return x - X_np
739+
740+
x0 = (
741+
columnwise_clamp(X=X, lower=bounds[0], upper=bounds[1], raise_on_violation=True)
742+
.detach()
743+
.cpu()
744+
.numpy()
745+
.flatten()
746+
)
747+
result = minimize(
748+
fun=objective,
749+
x0=x0,
750+
method="SLSQP",
751+
jac=grad_objective,
752+
bounds=bounds_scipy,
753+
constraints=constraints,
754+
tol=CONST_TOL,
755+
)
756+
757+
if not result.success:
758+
raise RuntimeError(f"Optimization failed: {result.message}")
759+
760+
return torch.from_numpy(result.x).to(X).view(X.shape)

0 commit comments

Comments
 (0)