
Commit 06b6ed0

sdaulton authored and facebook-github-bot committed
fix input constructor for LogProbabilityOfFeasibility
Summary:
X-link: facebook/Ax#4080

Constraints were being passed in the wrong format: a list of callables rather than a dictionary mapping output indices to bounds.

This also removes `convert_to_target_pre_hook`, which only works for positional args and does not work for kwargs; previously `LogProbabilityOfFeasibility` would error out if called with a kwarg, e.g. `acqf(X=X)`. The dtype and device are set in `_mean_and_sigma` for these acquisition functions anyway.

Differential Revision: D79281274
1 parent c8edbe5 · commit 06b6ed0
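To make the format mismatch concrete, here is an illustrative sketch (not part of the commit) of the two constraint conventions involved; the dictionary form is what the analytic `LogProbabilityOfFeasibility` expects, while the list-of-callables form is the convention used by the Monte Carlo acquisition functions:

# Dictionary format expected by the analytic LogProbabilityOfFeasibility:
# {output_index: (lower, upper)}, with None interpreted as -Inf / Inf.
constraints_dict = {1: (None, 2.0)}  # outcome 1 must satisfy f_1(x) <= 2.0

# List-of-callables format used by MC acquisition functions, which was
# previously being passed in by mistake; negative values imply feasibility.
constraints_callables = [lambda Y: Y[..., 1] - 2.0]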

File tree

botorch/acquisition/analytic.py
botorch/acquisition/input_constructors.py
botorch/utils/transforms.py
test/acquisition/test_input_constructors.py

4 files changed: +53 -29 lines changed

botorch/acquisition/analytic.py

Lines changed: 2 additions & 8 deletions
@@ -37,7 +37,6 @@
 from botorch.utils.safe_math import log1mexp, logmeanexp
 from botorch.utils.transforms import (
     average_over_ensemble_models,
-    convert_to_target_pre_hook,
     t_batch_mode_transform,
 )
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
@@ -100,7 +99,7 @@ def _mean_and_sigma(
             posterior. Removes the last two dimensions if they have size one. Only
             returns a single tensor of means if compute_sigma is True.
         """
-        self.to(device=X.device)  # ensures buffers / parameters are on the same device
+        self.to(X)  # ensures buffers / parameters are on the same device and dtype
         posterior = self.model.posterior(
             X=X, posterior_transform=self.posterior_transform
         )
@@ -584,7 +583,6 @@ def __init__(
         self.objective_index = objective_index
         self.register_buffer("best_f", torch.as_tensor(best_f))
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
-        self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
@@ -638,9 +636,7 @@ class LogProbabilityOfFeasibility(
     _log: bool = True

     def __init__(
-        self,
-        model: Model,
-        constraints: dict[int, tuple[float | None, float | None]],
+        self, model: Model, constraints: dict[int, tuple[float | None, float | None]]
     ) -> None:
         r"""Analytic Log Probability of Feasibility.

@@ -654,7 +650,6 @@ def __init__(
         AcquisitionFunction.__init__(self, model=model)
         self.posterior_transform = None
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
-        self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
@@ -730,7 +725,6 @@ def __init__(
         self.objective_index = objective_index
         self.register_buffer("best_f", torch.as_tensor(best_f))
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
-        self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
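
The practical effect of replacing the forward pre-hook with `self.to(X)` is that the acquisition function can now be called with `X` as a keyword argument. A minimal sketch, assuming an (unfitted) `SingleTaskGP` surrogate; illustrative only, not part of the commit:

import torch

from botorch.acquisition.analytic import LogProbabilityOfFeasibility
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = torch.rand(8, 1, dtype=torch.float64)
model = SingleTaskGP(train_X, train_Y)

# Constrain the (single) outcome to be at most 0.5.
acqf = LogProbabilityOfFeasibility(model=model, constraints={0: (None, 0.5)})

# `self.to(X)` inside `_mean_and_sigma` moves buffers / parameters to X's
# device and dtype, so calling with a kwarg no longer errors out.
X = torch.rand(4, 1, 2, dtype=torch.float64)
log_pof = acqf(X=X)  # one value per t-batch, shape (4,)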

botorch/acquisition/input_constructors.py

Lines changed: 33 additions & 13 deletions
@@ -97,7 +97,7 @@
     get_optimal_samples,
     project_to_target_fidelity,
 )
-from botorch.exceptions.errors import UnsupportedError
+from botorch.exceptions.errors import BotorchError, UnsupportedError
 from botorch.models.cost import AffineFidelityCostModel
 from botorch.models.deterministic import FixedSingleSampleModel
 from botorch.models.gpytorch import GPyTorchModel
@@ -226,6 +226,12 @@ def allow_only_specific_variable_kwargs(f: Callable[..., T]) -> Callable[..., T]
         # Used in input constructors for some lookahead acquisition functions
         # such as qKnowledgeGradient.
         "bounds",
+        # Needed for LogProbabilityOfFeasibility
+        "constraints_tuple",
+        "posterior_transform",
+        # not used by analytic acquisition functions
+        "objective",
+        "constraints",
     }

     def g(*args: Any, **kwargs: Any) -> T:
@@ -338,28 +344,42 @@ def construct_inputs_best_f(
     }


-@acqf_input_constructor(
-    LogProbabilityOfFeasibility,
-)
+@acqf_input_constructor(LogProbabilityOfFeasibility)
 def construct_inputs_pof(
-    model: Model,
-    constraints: dict[int, tuple[float | None, float | None]],
+    model: Model, constraints_tuple: tuple[Tensor, Tensor]
 ) -> dict[str, Any]:
     r"""Construct kwargs for the log probability of feasibility acquisition function.

     Args:
         model: The model to be used in the acquisition function.
-        constraints: A dictionary of the form `{i: [lower, upper]}`, where `i` is the
-            output index, and `lower` and `upper` are lower and upper bounds on that
-            output (resp. interpreted as -Inf / Inf if None).
+        constraints_tuple: A tuple of `(A, b)`. For `k` outcome constraints
+            and `m` outputs at `f(x)`, `A` is `k x m` and `b` is `k x 1` such
+            that `A f(x) <= b`.
+

     Returns:
         A dict mapping kwarg names of the constructor to values.
     """
-    return {
-        "model": model,
-        "constraints": constraints,
-    }
+    # Construct a dictionary of the form `{i: [lower, upper]}`,
+    # where `i` is the output index, and `lower` and `upper` are
+    # lower and upper bounds on that output (resp. interpreted
+    # as -Inf / Inf if None).
+    weights, bounds = constraints_tuple
+    constraints_dict = {}
+    for w, b in zip(weights, bounds):
+        nonzero_w = w.nonzero()
+        if nonzero_w.numel() != 1:
+            raise BotorchError(
+                "LogProbabilityOfFeasibility only supports constraints on single"
+                " outcomes."
+            )
+        i = nonzero_w.item()
+        w_i = w[i]
+        is_ub = torch.sign(w_i) == 1.0
+        b = b.item()
+        bounds = (None, b / w_i) if is_ub else (b / w_i, None)
+        constraints_dict[i] = bounds
+    return {"model": model, "constraints": constraints_dict}


 @acqf_input_constructor(UpperConfidenceBound)
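
End to end, the new input constructor turns the `(A, b)` outcome-constraint tuple into the dictionary the acquisition function expects; a sketch mirroring the updated unit test (the values and the stand-in `SingleTaskGP` model are illustrative, not part of the commit):

import torch

from botorch.acquisition.analytic import LogProbabilityOfFeasibility
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.models import SingleTaskGP

# Two-output model so that outcome index 1 exists.
model = SingleTaskGP(
    torch.rand(8, 2, dtype=torch.float64), torch.rand(8, 2, dtype=torch.float64)
)
c = get_acqf_input_constructor(LogProbabilityOfFeasibility)

# A f(x) <= b with A = [[0, 1]] and b = [[2]] constrains outcome 1 only,
# i.e. f_1(x) <= 2, which becomes {1: (None, 2.0)} in dictionary form.
A = torch.tensor([[0.0, 1.0]])
b = torch.tensor([[2.0]])
kwargs = c(model=model, constraints_tuple=(A, b))
# kwargs["constraints"] maps output index 1 to an upper bound of 2.0.
acqf = LogProbabilityOfFeasibility(**kwargs)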

botorch/utils/transforms.py

Lines changed: 0 additions & 5 deletions
@@ -422,8 +422,3 @@ def match_batch_shape(X: Tensor, Y: Tensor) -> Tensor:

     """
     return X.expand(X.shape[: -(Y.dim())] + Y.shape[:-2] + X.shape[-2:])
-
-
-def convert_to_target_pre_hook(module, *args):
-    r"""Pre-hook for automatically calling `.to(X)` on module prior to `forward`"""
-    module.to(args[0][0])
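
For context on why the deleted helper could not handle keyword arguments: `register_forward_pre_hook` (without `with_kwargs=True`) passes only the positional inputs to the hook. A standalone sketch using plain `torch.nn` rather than BoTorch:

import torch
from torch import nn

def to_target_pre_hook(module, args):
    # `args` holds only the positional inputs; it is empty when forward is
    # called with keyword arguments only, so `args[0]` raises an IndexError.
    module.to(args[0])

lin = nn.Linear(2, 1)
lin.register_forward_pre_hook(to_target_pre_hook)

X = torch.rand(3, 2)
lin(X)          # fine: args == (X,)
# lin(input=X)  # would fail with IndexError inside the hook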

test/acquisition/test_input_constructors.py

Lines changed: 18 additions & 3 deletions
@@ -110,7 +110,7 @@
     expand_trace_observations,
     project_to_target_fidelity,
 )
-from botorch.exceptions.errors import UnsupportedError
+from botorch.exceptions.errors import BotorchError, UnsupportedError
 from botorch.models import MultiTaskGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
 from botorch.models.deterministic import FixedSingleSampleModel
 from botorch.models.model_list_gp_regression import ModelListGP
@@ -636,11 +636,26 @@ def constraint(Y: Tensor) -> Tensor:
     def test_construct_inputs_LogPOF(self) -> None:
         c = get_acqf_input_constructor(LogProbabilityOfFeasibility)
         mock_model = self.mock_model
-        constraints = {1: [None, 0]}
-        kwargs = c(model=mock_model, constraints=constraints)
+        constraints_tuple = [torch.tensor([[0.0, 1.0]]), torch.tensor([[2.0]])]
+        constraints = {1: (None, 2.0)}
+        kwargs = c(model=mock_model, constraints_tuple=constraints_tuple)
         self.assertEqual(set(kwargs.keys()), {"model", "constraints"})
         self.assertIs(kwargs["model"], mock_model)
         self.assertEqual(kwargs["constraints"], constraints)
+        constraints_tuple = [torch.tensor([[0.0, -1.0]]), torch.tensor([[-2.0]])]
+        kwargs = c(model=mock_model, constraints_tuple=constraints_tuple)
+        constraints = {1: (2.0, None)}
+        self.assertEqual(kwargs["constraints"], constraints)
+        # test that constraints on multiple outcomes raise an exception
+        with self.assertRaisesRegex(
+            BotorchError,
+            "LogProbabilityOfFeasibility only supports constraints on single"
+            " outcomes.",
+        ):
+            c(
+                model=mock_model,
+                constraints_tuple=[torch.tensor([[1.0, 1.0]]), torch.tensor([[2.0]])],
+            )

     def test_construct_inputs_qEI(self) -> None:
         c = get_acqf_input_constructor(qExpectedImprovement)

0 commit comments
