10 changes: 2 additions & 8 deletions botorch/acquisition/analytic.py
@@ -37,7 +37,6 @@
 from botorch.utils.safe_math import log1mexp, logmeanexp
 from botorch.utils.transforms import (
     average_over_ensemble_models,
-    convert_to_target_pre_hook,
     t_batch_mode_transform,
 )
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
@@ -100,7 +99,7 @@ def _mean_and_sigma(
             posterior. Removes the last two dimensions if they have size one. Only
             returns a single tensor of means if compute_sigma is True.
         """
-        self.to(device=X.device)  # ensures buffers / parameters are on the same device
+        self.to(X)  # ensures buffers / parameters are on the same device and dtype
         posterior = self.model.posterior(
             X=X, posterior_transform=self.posterior_transform
         )
@@ -584,7 +583,6 @@ def __init__(
         self.objective_index = objective_index
         self.register_buffer("best_f", torch.as_tensor(best_f))
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
-        self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
@@ -638,9 +636,7 @@ class LogProbabilityOfFeasibility(
     _log: bool = True

     def __init__(
-        self,
-        model: Model,
-        constraints: dict[int, tuple[float | None, float | None]],
+        self, model: Model, constraints: dict[int, tuple[float | None, float | None]]
     ) -> None:
         r"""Analytic Log Probability of Feasibility.

@@ -654,7 +650,6 @@ def __init__(
         AcquisitionFunction.__init__(self, model=model)
         self.posterior_transform = None
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
-        self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
@@ -730,7 +725,6 @@ def __init__(
         self.objective_index = objective_index
         self.register_buffer("best_f", torch.as_tensor(best_f))
         ConstrainedAnalyticAcquisitionFunctionMixin.__init__(self, constraints)
-        self.register_forward_pre_hook(convert_to_target_pre_hook)

     @t_batch_mode_transform(expected_q=1)
     @average_over_ensemble_models
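Note on the `self.to(X)` change above: `Module.to(tensor)` matches the tensor's device *and* dtype, which is what lets this single explicit call replace the `convert_to_target_pre_hook` registrations deleted in this file. A minimal standalone sketch of the difference (plain PyTorch, not BoTorch code):

```python
import torch
from torch import nn

m = nn.Module()
m.register_buffer("best_f", torch.tensor(0.5))  # float32 by default

X = torch.rand(4, 2, dtype=torch.float64)
m.to(device=X.device)  # old behavior: moves the buffer's device only
m.to(X)                # new behavior: also casts best_f to X's dtype
assert m.best_f.dtype == torch.float64
```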
47 changes: 34 additions & 13 deletions botorch/acquisition/input_constructors.py
@@ -97,7 +97,7 @@
     get_optimal_samples,
     project_to_target_fidelity,
 )
-from botorch.exceptions.errors import UnsupportedError
+from botorch.exceptions.errors import BotorchError, UnsupportedError
 from botorch.models.cost import AffineFidelityCostModel
 from botorch.models.deterministic import FixedSingleSampleModel
 from botorch.models.gpytorch import GPyTorchModel
@@ -226,6 +226,13 @@ def allow_only_specific_variable_kwargs(f: Callable[..., T]) -> Callable[..., T]
         # Used in input constructors for some lookahead acquisition functions
         # such as qKnowledgeGradient.
         "bounds",
+        # Needed for LogProbabilityOfFeasibility
+        # and LogConstrainedExpectedImprovement
+        "constraints_tuple",
+        "posterior_transform",
+        # not used by analytic acquisition functions
+        "objective",
+        "constraints",
     }

     def g(*args: Any, **kwargs: Any) -> T:
@@ -338,28 +345,42 @@ def construct_inputs_best_f(
     }


-@acqf_input_constructor(
-    LogProbabilityOfFeasibility,
-)
+@acqf_input_constructor(LogProbabilityOfFeasibility)
 def construct_inputs_pof(
-    model: Model,
-    constraints: dict[int, tuple[float | None, float | None]],
+    model: Model, constraints_tuple: tuple[Tensor, Tensor]
 ) -> dict[str, Any]:
     r"""Construct kwargs for the log probability of feasibility acquisition function.

     Args:
         model: The model to be used in the acquisition function.
-        constraints: A dictionary of the form `{i: [lower, upper]}`, where `i` is the
-            output index, and `lower` and `upper` are lower and upper bounds on that
-            output (resp. interpreted as -Inf / Inf if None).
+        constraints_tuple: A tuple of `(A, b)`. For `k` outcome constraints
+            and `m` outputs at `f(x)`, `A` is `k x m` and `b` is `k x 1` such
+            that `A f(x) <= b`.
+

     Returns:
         A dict mapping kwarg names of the constructor to values.
     """
-    return {
-        "model": model,
-        "constraints": constraints,
-    }
+    # Construct a dictionary of the form `{i: [lower, upper]}`,
+    # where `i` is the output index, and `lower` and `upper` are
+    # lower and upper bounds on that output (resp. interpreted
+    # as -Inf / Inf if None).
+    weights, bounds = constraints_tuple
+    constraints_dict = {}
+    for w, b in zip(weights, bounds):
+        nonzero_w = w.nonzero()
+        if nonzero_w.numel() != 1:
+            raise BotorchError(
+                "LogProbabilityOfFeasibility only supports constraints on single"
+                " outcomes."
+            )
+        i = nonzero_w.item()
+        w_i = w[i]
+        is_ub = torch.sign(w_i) == 1.0
+        b = b.item()
+        bounds = (None, b / w_i) if is_ub else (b / w_i, None)
+        constraints_dict[i] = bounds
+    return {"model": model, "constraints": constraints_dict}


 @acqf_input_constructor(UpperConfidenceBound)
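For concreteness, a usage sketch of the new constructor (assuming this PR is applied; the model here is just an illustrative two-outcome GP): each row of `A` may touch exactly one outcome, and the sign of the weight decides whether `b / w` becomes an upper or a lower bound.

```python
import torch
from botorch.acquisition.analytic import LogProbabilityOfFeasibility
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.models import SingleTaskGP

model = SingleTaskGP(torch.rand(8, 2), torch.rand(8, 2))  # two outcomes
c = get_acqf_input_constructor(LogProbabilityOfFeasibility)

# [[0, 1]] @ f(x) <= [2] encodes f_1(x) <= 2, an upper bound on outcome 1:
kwargs = c(
    model=model,
    constraints_tuple=(torch.tensor([[0.0, 1.0]]), torch.tensor([[2.0]])),
)
# kwargs["constraints"] maps outcome 1 to (None, 2.0); flipping the sign,
# [[0, -1]] @ f(x) <= [-2], would instead yield the lower bound (2.0, None).
acqf = LogProbabilityOfFeasibility(**kwargs)
```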
5 changes: 0 additions & 5 deletions botorch/utils/transforms.py
@@ -422,8 +422,3 @@ def match_batch_shape(X: Tensor, Y: Tensor) -> Tensor:

     """
     return X.expand(X.shape[: -(Y.dim())] + Y.shape[:-2] + X.shape[-2:])
-
-
-def convert_to_target_pre_hook(module, *args):
-    r"""Pre-hook for automatically calling `.to(X)` on module prior to `forward`"""
-    module.to(args[0][0])
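For context on what is being removed: a forward pre-hook runs before every `forward` call and receives the module plus its positional arguments, so `args[0][0]` here was the input `X`. A minimal standalone sketch of the old implicit behavior (the hook body is the deleted helper verbatim; the `nn.Linear` is just a demo module):

```python
import torch
from torch import nn

def convert_to_target_pre_hook(module, *args):
    # args[0] is the tuple of positional forward() arguments; args[0][0] is X.
    module.to(args[0][0])

lin = nn.Linear(2, 1)  # float32 parameters by default
lin.register_forward_pre_hook(convert_to_target_pre_hook)
X = torch.rand(3, 2, dtype=torch.float64)
out = lin(X)  # the hook casts lin to float64 first, so the matmul succeeds
```

The PR trades this implicit per-forward conversion for the explicit `self.to(X)` in `_mean_and_sigma` above.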
56 changes: 48 additions & 8 deletions test/acquisition/test_input_constructors.py
@@ -110,7 +110,7 @@
     expand_trace_observations,
     project_to_target_fidelity,
 )
-from botorch.exceptions.errors import UnsupportedError
+from botorch.exceptions.errors import BotorchError, UnsupportedError
 from botorch.models import MultiTaskGP, SaasFullyBayesianSingleTaskGP, SingleTaskGP
 from botorch.models.deterministic import FixedSingleSampleModel
 from botorch.models.model_list_gp_regression import ModelListGP
@@ -636,11 +636,26 @@ def constraint(Y: Tensor) -> Tensor:
     def test_construct_inputs_LogPOF(self) -> None:
         c = get_acqf_input_constructor(LogProbabilityOfFeasibility)
         mock_model = self.mock_model
-        constraints = {1: [None, 0]}
-        kwargs = c(model=mock_model, constraints=constraints)
+        constraints_tuple = [torch.tensor([[0.0, 1.0]]), torch.tensor([[2.0]])]
+        constraints = {1: (None, 2.0)}
+        kwargs = c(model=mock_model, constraints_tuple=constraints_tuple)
         self.assertEqual(set(kwargs.keys()), {"model", "constraints"})
         self.assertIs(kwargs["model"], mock_model)
         self.assertEqual(kwargs["constraints"], constraints)
+        constraints_tuple = [torch.tensor([[0.0, -1.0]]), torch.tensor([[-2.0]])]
+        kwargs = c(model=mock_model, constraints_tuple=constraints_tuple)
+        constraints = {1: (2.0, None)}
+        self.assertEqual(kwargs["constraints"], constraints)
+        # test that constraints on multiple outcomes raise an exception
+        with self.assertRaisesRegex(
+            BotorchError,
+            "LogProbabilityOfFeasibility only supports constraints on single"
+            " outcomes.",
+        ):
+            c(
+                model=mock_model,
+                constraints_tuple=[torch.tensor([[1.0, 1.0]]), torch.tensor([[2.0]])],
+            )

     def test_construct_inputs_qEI(self) -> None:
         c = get_acqf_input_constructor(qExpectedImprovement)
@@ -1781,6 +1796,9 @@ class TestInstantiationFromInputConstructor(InputConstructorBaseTestCase):
     def setUp(self, suppress_input_warnings: bool = True) -> None:
         super().setUp(suppress_input_warnings=suppress_input_warnings)
         # {key: (list of acquisition functions, arguments they accept)}
+        constraints_tuple_dict = {
+            "constraints_tuple": (torch.tensor([[0.0, 1.0]]), torch.tensor([[2.0]])),
+        }
         self.cases = {
             "PosteriorMean-type": (
                 [
@@ -1789,7 +1807,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
                     qUpperConfidenceBound,
                     qLowerConfidenceBound,
                 ],
-                {"model": self.mock_model},
+                {"model": self.mock_model, **constraints_tuple_dict},
             ),
         }
         st_soo_model = SingleTaskGP(
@@ -1811,20 +1829,32 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
                 qLogNoisyExpectedImprovement,
                 qProbabilityOfImprovement,
             ],
-            {"model": st_soo_model, "training_data": self.blockX_blockY},
+            {
+                "model": st_soo_model,
+                "training_data": self.blockX_blockY,
+                **constraints_tuple_dict,
+            },
         )

         self.cases["LogPoF"] = (
             [LogProbabilityOfFeasibility],
-            {"model": st_soo_model, "constraints": {0: [-5, 5]}},
+            {
+                "model": st_soo_model,
+                "constraints": {0: [-5, 5]},
+                **constraints_tuple_dict,
+            },
         )

         def constraint(X: Tensor) -> Tensor:
             return X[..., 0].abs() - 5

         self.cases["qLogPoF"] = (
             [qLogProbabilityOfFeasibility],
-            {"model": st_soo_model, "constraints": [constraint]},
+            {
+                "model": st_soo_model,
+                "constraints": [constraint],
+                **constraints_tuple_dict,
+            },
         )

         bounds = torch.ones((1, 2))
@@ -1835,6 +1865,7 @@ def constraint(X: Tensor) -> Tensor:
                 "model": kg_model,
                 "training_data": self.blockX_blockY,
                 "bounds": bounds,
+                **constraints_tuple_dict,
             },
         )
         self.cases["MF look-ahead"] = (
@@ -1845,6 +1876,7 @@ def constraint(X: Tensor) -> Tensor:
                 "bounds": bounds,
                 "target_fidelities": {0: 0.987},
                 "num_fantasies": 30,
+                **constraints_tuple_dict,
             },
         )
         bounds = torch.ones((2, 2))
@@ -1857,6 +1889,7 @@ def constraint(X: Tensor) -> Tensor:
                 "bounds": bounds,
                 "target_fidelities": {0: 0.987},
                 "num_fantasies": 30,
+                **constraints_tuple_dict,
             },
         )

@@ -1877,6 +1910,7 @@ def constraint(X: Tensor) -> Tensor:
                 "model": st_moo_model,
                 "objective_thresholds": objective_thresholds,
                 "training_data": self.blockX_blockY,
+                **constraints_tuple_dict,
             },
         )

@@ -1893,6 +1927,7 @@ def constraint(X: Tensor) -> Tensor:
                 "training_data": self.blockX_blockY,
                 "bounds": bounds,
                 "objective_thresholds": objective_thresholds,
+                **constraints_tuple_dict,
             },
         )
         self.cases["MF HV Look-ahead"] = (
@@ -1904,6 +1939,7 @@ def constraint(X: Tensor) -> Tensor:
                 "target_fidelities": {0: 0.987},
                 "num_fantasies": 30,
                 "objective_thresholds": objective_thresholds,
+                **constraints_tuple_dict,
             },
         )

@@ -1913,13 +1949,14 @@ def constraint(X: Tensor) -> Tensor:

         self.cases["EUBO"] = (
             [AnalyticExpectedUtilityOfBestOption, qExpectedUtilityOfBestOption],
-            {"model": st_moo_model, "pref_model": pref_model},
+            {"model": st_moo_model, "pref_model": pref_model, **constraints_tuple_dict},
         )
         self.cases["qJES"] = (
             [qJointEntropySearch],
             {
                 "model": SingleTaskGP(self.blockX_blockY[0].X, self.blockX_blockY[0].Y),
                 "bounds": self.bounds,
+                **constraints_tuple_dict,
             },
         )
         self.cases["qSimpleRegret"] = (
@@ -1928,6 +1965,7 @@ def constraint(X: Tensor) -> Tensor:
                 "model": SingleTaskGP(self.blockX_blockY[0].X, self.blockX_blockY[0].Y),
                 "training_data": self.blockX_blockY,
                 "objective": LinearMCObjective(torch.rand(2)),
+                **constraints_tuple_dict,
             },
         )
         self.cases["BayesianActiveLearning"] = (
@@ -1936,6 +1974,7 @@ def constraint(X: Tensor) -> Tensor:
                 "model": SaasFullyBayesianSingleTaskGP(
                     self.blockX_blockY[0].X, self.blockX_blockY[0].Y
                 ),
+                **constraints_tuple_dict,
             },
         )
         self.cases["ActiveLearning"] = (
@@ -1944,6 +1983,7 @@ def constraint(X: Tensor) -> Tensor:
                 "model": SingleTaskGP(self.blockX_blockY[0].X, self.blockX_blockY[0].Y),
                 "training_data": self.blockX_blockY,
                 "bounds": self.bounds,
+                **constraints_tuple_dict,
             },
         )
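(The reason every case can unconditionally splat `**constraints_tuple_dict` is the allow-list change in `input_constructors.py` above: `allow_only_specific_variable_kwargs` now tolerates `constraints_tuple` as a variable kwarg, so constructors that don't declare it simply ignore it.)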
2 changes: 1 addition & 1 deletion test/test_utils/test_mock.py
@@ -140,7 +140,7 @@ def test_decorator(self) -> None:
             warnings.simplefilter("ignore", category=BadInitialCandidatesWarning)
             cand, value = optimize_acqf(
                 acq_function=acqf,
-                bounds=torch.tensor([[-2.0], [2.0]]),
+                bounds=torch.tensor([[-2.0], [2.0]], dtype=torch.double),
                 q=1,
                 num_restarts=32,
                 raw_samples=16,