79 changes: 79 additions & 0 deletions botorch/acquisition/input_constructors.py
@@ -76,6 +76,7 @@
    qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.objective import IdentityMCMultiOutputObjective
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
from botorch.acquisition.objective import (
    ConstrainedMCObjective,
@@ -1115,6 +1116,84 @@ def construct_inputs_qLogNEHVI(
    }


@acqf_input_constructor(qLogNParEGO)
def construct_inputs_qLogNParEGO(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    scalarization_weights: Optional[Tensor] = None,
    objective: Optional[MCMultiOutputObjective] = None,
    X_pending: Optional[Tensor] = None,
    sampler: Optional[MCSampler] = None,
    X_baseline: Optional[Tensor] = None,
    prune_baseline: Optional[bool] = True,
    cache_root: Optional[bool] = True,
    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
    eta: Union[Tensor, float] = 1e-3,
    fat: bool = True,
    tau_max: float = TAU_MAX,
    tau_relu: float = TAU_RELU,
):
r"""Construct kwargs for the `qLogNoisyExpectedImprovement` constructor.

    Args:
        model: The model to be used in the acquisition function.
        training_data: Dataset(s) used to train the model.
        scalarization_weights: A `m`-dim Tensor of weights to be used in the
            Chebyshev scalarization. If omitted, the weights are sampled from
            the unit simplex.
        objective: The `MCMultiOutputObjective` under which the samples are
            evaluated before applying the Chebyshev scalarization.
            Defaults to `IdentityMCMultiOutputObjective()`.
        X_pending: A `m x d`-dim Tensor of `m` design points that have been
            submitted for function evaluation but have not yet been evaluated.
            Concatenated into `X` upon forward call.
        sampler: The sampler used to draw base samples. If omitted, uses
            the acquisition function's default sampler.
        X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
            that have already been observed. These points are considered as
            the potential best design point. If omitted, checks that all
            training_data have the same input features and takes the first `X`.
        prune_baseline: If True, remove points in `X_baseline` that are
            highly unlikely to be the best point. This can significantly
            improve performance and is generally recommended.
        cache_root: A boolean indicating whether to cache the root
            decomposition over `X_baseline` and use low-rank updates.
        constraints: A list of constraint callables which map a Tensor of posterior
            samples of dimension `sample_shape x batch-shape x q x m` to a
            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
            are considered satisfied if the output is less than zero.
        eta: Temperature parameter(s) governing the smoothness of the sigmoid
            approximation to the constraint indicators. For more details on this
            parameter, see the docs of `compute_smoothed_feasibility_indicator`.
        fat: Toggles the use of fat-tailed non-linearities to smoothly approximate
            the constraint indicator function.
        tau_max: Temperature parameter controlling the sharpness of the smooth
            approximations to max.
        tau_relu: Temperature parameter controlling the sharpness of the smooth
            approximations to ReLU.

    Returns:
        A dict mapping kwarg names of the constructor to values.
    """
    base_inputs = construct_inputs_qLogNEI(
        model=model,
        training_data=training_data,
        objective=objective,
        X_pending=X_pending,
        sampler=sampler,
        X_baseline=X_baseline,
        prune_baseline=prune_baseline,
        cache_root=cache_root,
        constraints=constraints,
        eta=eta,
        fat=fat,
        tau_max=tau_max,
        tau_relu=tau_relu,
    )
    # `qLogNParEGO` does not accept a `posterior_transform`, so drop it if present.
    base_inputs.pop("posterior_transform", None)
    return {
        **base_inputs,
        "scalarization_weights": scalarization_weights,
    }


@acqf_input_constructor(qMaxValueEntropy)
def construct_inputs_qMES(
    model: Model,
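Aside (not part of the diff): a minimal sketch of how this constructor would typically be exercised through the input-constructor registry. The toy model and dataset are illustrative assumptions, and depending on the BoTorch version `SupervisedDataset` may or may not require the `feature_names` / `outcome_names` arguments.

# Hypothetical usage sketch: resolve the constructor registered via
# @acqf_input_constructor(qLogNParEGO) and build the acquisition function.
import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 2, dtype=torch.double)  # two objectives
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
dataset = SupervisedDataset(
    X=train_X,
    Y=train_Y,
    feature_names=["x1", "x2"],
    outcome_names=["y1", "y2"],
)

input_constructor = get_acqf_input_constructor(qLogNParEGO)
kwargs = input_constructor(model=model, training_data=dataset)
acqf = qLogNParEGO(**kwargs)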
147 changes: 147 additions & 0 deletions botorch/acquisition/multi_objective/parego.py
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, List, Optional, Union

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement, TAU_MAX, TAU_RELU
from botorch.acquisition.multi_objective.monte_carlo import (
    MultiObjectiveMCAcquisitionFunction,
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.acquisition.objective import GenericMCObjective
from botorch.models.model import Model
from botorch.posteriors.fully_bayesian import MCMC_DIM
from botorch.sampling.base import MCSampler
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from botorch.utils.sampling import sample_simplex
from botorch.utils.transforms import is_ensemble
from torch import Tensor


class qLogNParEGO(qLogNoisyExpectedImprovement, MultiObjectiveMCAcquisitionFunction):
    def __init__(
        self,
        model: Model,
        X_baseline: Tensor,
        scalarization_weights: Optional[Tensor] = None,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCMultiOutputObjective] = None,
        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
        X_pending: Optional[Tensor] = None,
        eta: Union[Tensor, float] = 1e-3,
        fat: bool = True,
        prune_baseline: bool = False,
        cache_root: bool = True,
        tau_relu: float = TAU_RELU,
        tau_max: float = TAU_MAX,
    ) -> None:
r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function
utilizes qLogNEI to compute the expected improvement over Chebyshev
scalarization of the objectives.

This is adapted from qNParEGO proposed in [Daulton2020qehvi]_ to utilize
log-improvement acquisition functions of [Ament2023logei]_. See [Knowles2005]_
for the original ParEGO algorithm.

This implementation assumes maximization of all objectives. If any of the model
outputs are to be minimized, either an `objective` should be used to negate the
model outputs or the `scalarization_weights` should be provided with negative
weights for the outputs to be minimized.

Args:
model: A fitted multi-output model, producing outputs for `m` objectives
and any number of outcome constraints.
NOTE: The model posterior must have a `mean` attribute.
X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
that have already been observed. These points are considered as
the potential best design point.
scalarization_weights: A `m`-dim Tensor of weights to be used in the
Chebyshev scalarization. If omitted, samples from the unit simplex.
sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
more details.
objective: The MultiOutputMCAcquisitionObjective under which the samples are
evaluated before applying Chebyshev scalarization.
Defaults to `IdentityMultiOutputObjective()`.
constraints: A list of constraint callables which map a Tensor of posterior
samples of dimension `sample_shape x batch-shape x q x m'`-dim to a
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
are satisfied if `constraint(samples) < 0`.
X_pending: A `batch_shape x q' x d`-dim Tensor of `q'` design points
that have points that have been submitted for function evaluation
but have not yet been evaluated. Concatenated into `X` upon
forward call. Copied and set to have no gradient.
eta: Temperature parameter(s) governing the smoothness of the sigmoid
approximation to the constraint indicators. See the docs of
`compute_(log_)smoothed_constraint_indicator` for details.
fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
approximation to the ReLU.
prune_baseline: If True, remove points in `X_baseline` that are
highly unlikely to be the best point. This can significantly
improve performance and is generally recommended. In order to
customize pruning parameters, instead manually call
`botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
before instantiating the acquisition function.
cache_root: A boolean indicating whether to cache the root
decomposition over `X_baseline` and use low-rank updates.
tau_max: Temperature parameter controlling the sharpness of the smooth
approximations to max.
tau_relu: Temperature parameter controlling the sharpness of the smooth
approximations to ReLU.
"""
        MultiObjectiveMCAcquisitionFunction.__init__(
            self,
            model=model,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
            eta=eta,
        )
        org_objective = self.objective
        # Create the composite objective.
        with torch.no_grad():
            Y_baseline = org_objective(model.posterior(X_baseline).mean)
        if is_ensemble(model):
            Y_baseline = torch.mean(Y_baseline, dim=MCMC_DIM)
        scalarization_weights = (
            scalarization_weights
            if scalarization_weights is not None
            else sample_simplex(
                d=Y_baseline.shape[-1], device=X_baseline.device, dtype=X_baseline.dtype
            ).view(-1)
        )
        chebyshev_scalarization = get_chebyshev_scalarization(
            weights=scalarization_weights,
            Y=Y_baseline,
        )
        composite_objective = GenericMCObjective(
            objective=lambda samples, X=None: chebyshev_scalarization(
                org_objective(samples=samples, X=X), X=X
            ),
        )
        qLogNoisyExpectedImprovement.__init__(
            self,
            model=model,
            X_baseline=X_baseline,
            sampler=sampler,
            # This overwrites self.objective with the composite objective.
            objective=composite_objective,
            X_pending=X_pending,
            constraints=constraints,
            eta=eta,
            fat=fat,
            prune_baseline=prune_baseline,
            cache_root=cache_root,
            tau_max=tau_max,
            tau_relu=tau_relu,
        )
        # Set these after __init__ calls so that they're not overwritten / deleted.
        # These are intended mainly for easier debugging & transparency.
        self._org_objective: MCMultiOutputObjective = org_objective
        self.chebyshev_scalarization: Callable[[Tensor, Optional[Tensor]], Tensor] = (
            chebyshev_scalarization
        )
        self.scalarization_weights: Tensor = scalarization_weights
        self.Y_baseline: Tensor = Y_baseline
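
For illustration (not part of the diff): a minimal usage sketch mirroring the unit tests below; the toy tensors are assumptions. Note that `get_chebyshev_scalarization` builds an augmented Chebyshev scalarization normalized using the baseline outcomes `Y`; see its docstring for the exact form.

# Minimal sketch with a toy two-objective SingleTaskGP.
import torch
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models.gp_regression import SingleTaskGP

tkwargs = {"dtype": torch.double}
model = SingleTaskGP(
    train_X=torch.rand(5, 2, **tkwargs),
    train_Y=torch.rand(5, 2, **tkwargs),  # m = 2 objectives
)
acqf = qLogNParEGO(
    model=model,
    X_baseline=torch.rand(3, 2, **tkwargs),
    prune_baseline=True,
)
# With no weights given, they are sampled from the unit simplex and sum to one.
assert torch.isclose(
    acqf.scalarization_weights.sum(), torch.tensor(1.0, **tkwargs)
)
# A batch of 32 candidate sets with q = 5 points each yields 32 values.
values = acqf(torch.rand(32, 5, 2, **tkwargs))
assert values.shape == torch.Size([32])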
5 changes: 5 additions & 0 deletions sphinx/source/acquisition.rst
@@ -108,6 +108,11 @@ Multi-Objective Predictive Entropy Search Acquisition Functions
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search
    :members:

ParEGO: Multi-Objective Acquisition Function with Chebyshev Scalarization
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.parego
    :members:

The One-Shot Knowledge Gradient
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.knowledge_gradient
121 changes: 121 additions & 0 deletions test/acquisition/multi_objective/test_parego.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.multi_objective.objective import (
    IdentityMCMultiOutputObjective,
    WeightedMCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.utils.testing import BotorchTestCase


class TestqLogNParEGO(BotorchTestCase):
    def base_test_parego(
        self,
        with_constraints: bool = False,
        with_scalarization_weights: bool = False,
        with_objective: bool = False,
        model: Optional[Model] = None,
    ) -> None:
        if with_constraints:
            assert with_objective, "Objective must be specified if constraints are."
        tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
        num_objectives = 2
        num_constraints = 1 if with_constraints else 0
        num_outputs = num_objectives + num_constraints
        model = model or SingleTaskGP(
            train_X=torch.rand(5, 2, **tkwargs),
            train_Y=torch.rand(5, num_outputs, **tkwargs),
        )
        scalarization_weights = (
            torch.rand(num_objectives, **tkwargs)
            if with_scalarization_weights
            else None
        )
        objective = (
            WeightedMCMultiOutputObjective(
                weights=torch.tensor([2.0, -0.5], **tkwargs), outcomes=[0, 1]
            )
            if with_objective
            else None
        )
        constraints = [lambda samples: samples[..., -1]] if with_constraints else None
        acqf = qLogNParEGO(
            model=model,
            X_baseline=torch.rand(3, 2, **tkwargs),
            scalarization_weights=scalarization_weights,
            objective=objective,
            constraints=constraints,
            prune_baseline=True,
        )
        self.assertEqual(acqf.Y_baseline.shape, torch.Size([3, 2]))
        # Scalarization weights should be set if given and sampled otherwise.
        if scalarization_weights is not None:
            self.assertIs(acqf.scalarization_weights, scalarization_weights)
        else:
            self.assertEqual(
                acqf.scalarization_weights.shape, torch.Size([num_objectives])
            )
            # Should sum to 1 since they're sampled from simplex.
            self.assertAlmostEqual(acqf.scalarization_weights.sum().item(), 1.0)
        # Original objective should default to identity.
        if with_objective:
            self.assertIs(acqf._org_objective, objective)
        else:
            self.assertIsInstance(acqf._org_objective, IdentityMCMultiOutputObjective)
        # Acqf objective should be the Chebyshev scalarization compounded
        # with the original objective.
        test_samples = torch.rand(32, 5, num_outputs, **tkwargs)
        expected_objective = acqf.chebyshev_scalarization(
            acqf._org_objective(test_samples)
        )
        self.assertEqual(expected_objective.shape, torch.Size([32, 5]))
        self.assertAllClose(acqf.objective(test_samples), expected_objective)
        # Evaluate the acquisition function.
        self.assertEqual(acqf(torch.rand(5, 2, **tkwargs)).shape, torch.Size([1]))
        test_X = torch.rand(32, 5, 2, **tkwargs)
        acqf_val = acqf(test_X)
        self.assertEqual(acqf_val.shape, torch.Size([32]))
        # Check that we're indeed using qLogNEI.
        self.assertIs(
            acqf.forward.__code__, qLogNoisyExpectedImprovement.forward.__code__
        )
        self.assertAllClose(
            acqf_val, qLogNoisyExpectedImprovement.forward(acqf, X=test_X)
        )

    def test_parego_simple(self) -> None:
        self.base_test_parego()

    def test_parego_with_constraints_objective_weights(self) -> None:
        self.base_test_parego(
            with_constraints=True, with_objective=True, with_scalarization_weights=True
        )

    def test_parego_with_ensemble_model(self) -> None:
        tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
        models = []
        for _ in range(2):
            model = SaasFullyBayesianSingleTaskGP(
                train_X=torch.rand(5, 2, **tkwargs),
                train_Y=torch.randn(5, 1, **tkwargs),
                train_Yvar=torch.rand(5, 1, **tkwargs) * 0.05,
            )
            mcmc_samples = {
                "lengthscale": torch.rand(4, 1, 2, **tkwargs),
                "outputscale": torch.rand(4, **tkwargs),
                "mean": torch.randn(4, **tkwargs),
            }
            model.load_mcmc_samples(mcmc_samples)
            models.append(model)
        self.base_test_parego(model=ModelListGP(*models))
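
End to end, the new acquisition function should plug into candidate generation like any other MC acquisition function; below is a sketch under that assumption, with illustrative bounds and batch sizes. In practice one would fit the model first (e.g. via `fit_gpytorch_mll`) and typically draw fresh `scalarization_weights` each ParEGO iteration.

# Sketch: generate a q = 2 candidate batch with qLogNParEGO via optimize_acqf.
import torch
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # two objectives
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
acqf = qLogNParEGO(model=model, X_baseline=train_X, prune_baseline=True)
candidates, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=2,
    num_restarts=10,
    raw_samples=128,
)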