|
35 | 35 | gen_one_shot_kg_initial_conditions, |
36 | 36 | TGenInitialConditions, |
37 | 37 | ) |
38 | | -from botorch.optim.parameter_constraints import evaluate_feasibility |
| 38 | +from botorch.optim.parameter_constraints import ( |
| 39 | + evaluate_feasibility, |
| 40 | + project_to_feasible_space_via_slsqp, |
| 41 | +) |
39 | 42 | from botorch.optim.stopping import ExpMAStoppingCriterion |
40 | 43 | from torch import Tensor |
41 | 44 |
|
@@ -513,15 +516,47 @@ def _optimize_batch_candidates() -> tuple[Tensor, Tensor, list[Warning]]: |
513 | 516 |
|
514 | 517 | # SLSQP can sometimes fail to produce a feasible candidate. Check for |
515 | 518 | # feasibility and error out if necessary. |
| 519 | + # if there are equality constraints, project the candidate to the feasible set |
| 520 | + equality_constraints = gen_kwargs.get("equality_constraints") |
| 521 | + inequality_constraints = gen_kwargs.get("inequality_constraints") |
| 522 | + nonlinear_inequality_constraints = gen_kwargs.get( |
| 523 | + "nonlinear_inequality_constraints" |
| 524 | + ) |
516 | 525 | is_feasible = evaluate_feasibility( |
517 | 526 | X=batch_candidates, |
518 | | - inequality_constraints=gen_kwargs.get("inequality_constraints"), |
519 | | - equality_constraints=gen_kwargs.get("equality_constraints"), |
520 | | - nonlinear_inequality_constraints=gen_kwargs.get( |
521 | | - "nonlinear_inequality_constraints" |
522 | | - ), |
| 527 | + inequality_constraints=inequality_constraints, |
| 528 | + equality_constraints=equality_constraints, |
| 529 | + nonlinear_inequality_constraints=nonlinear_inequality_constraints, |
523 | 530 | ) |
524 | 531 | infeasible = ~is_feasible |
| 532 | + if nonlinear_inequality_constraints is None and infeasible.any(): |
| 533 | + projected_candidates = project_to_feasible_space_via_slsqp( |
| 534 | + X=batch_candidates[infeasible], |
| 535 | + bounds=opt_inputs.bounds, |
| 536 | + equality_constraints=equality_constraints, |
| 537 | + inequality_constraints=inequality_constraints, |
| 538 | + ) |
| 539 | + if opt_inputs.post_processing_func is not None: |
| 540 | + projected_candidates = opt_inputs.post_processing_func(projected_candidates) |
| 541 | + batch_candidates[infeasible] = projected_candidates |
| 542 | + # recompute AF values for projected points |
| 543 | + with torch.no_grad(): |
| 544 | + batch_acq_values[infeasible] = torch.cat( |
| 545 | + [ |
| 546 | + opt_inputs.acq_function(cand) |
| 547 | + for cand in projected_candidates.split(batch_limit, dim=0) |
| 548 | + ], |
| 549 | + dim=0, |
| 550 | + ) |
| 551 | + # re-evaluate feasibility |
| 552 | + is_feasible = evaluate_feasibility( |
| 553 | + X=batch_candidates, |
| 554 | + inequality_constraints=inequality_constraints, |
| 555 | + equality_constraints=equality_constraints, |
| 556 | + nonlinear_inequality_constraints=nonlinear_inequality_constraints, |
| 557 | + ) |
| 558 | + infeasible = ~is_feasible |
| 559 | + |
525 | 560 | if (opt_inputs.return_best_only and (not is_feasible.any())) or infeasible.all(): |
526 | 561 | raise CandidateGenerationError( |
527 | 562 | f"The optimizer produced infeasible candidates. " |
|
0 commit comments