
Commit 8c93744

saitcakmak authored and facebook-github-bot committed
Implement qLogNParEGO (#2364)
Summary: Adds an implementation of qLogNParEGO that is compatible with Ax MBM. It constructs the Chebyshev scalarization and then defers to qLogNEI for the remaining computations. The construction of the Chebyshev objective mirrors what was done in `_get_acquisition_func` for the legacy Ax model.

Differential Revision: D58122015
1 parent 5fbbf0e commit 8c93744
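
For orientation, a minimal usage sketch of the new acquisition function (not part of this commit; it assumes a fitted two-objective `SingleTaskGP` and the standard BoTorch fitting and optimization utilities):

import torch
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.rand(10, 2, dtype=torch.double)  # two objectives, both maximized
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(model.likelihood, model))

# Scalarization weights are sampled from the unit simplex when not provided.
acqf = qLogNParEGO(model=model, X_baseline=train_X, prune_baseline=True)
candidate, _ = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,
    num_restarts=10,
    raw_samples=256,
)

Because the weights are resampled on each instantiation, re-creating the acquisition function at every iteration explores different trade-offs across the Pareto front, as in the original ParEGO scheme.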

File tree

4 files changed: +355 −0 lines changed


botorch/acquisition/input_constructors.py

Lines changed: 79 additions & 0 deletions
@@ -76,6 +76,7 @@
     qLogNoisyExpectedHypervolumeImprovement,
 )
 from botorch.acquisition.multi_objective.objective import IdentityMCMultiOutputObjective
+from botorch.acquisition.multi_objective.parego import qLogNParEGO
 from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
@@ -1115,6 +1116,84 @@ def construct_inputs_qLogNEHVI(
     }
 
 
+@acqf_input_constructor(qLogNParEGO)
+def construct_inputs_qLogNParEGO(
+    model: Model,
+    training_data: MaybeDict[SupervisedDataset],
+    scalarization_weights: Optional[Tensor] = None,
+    objective: Optional[MCMultiOutputObjective] = None,
+    X_pending: Optional[Tensor] = None,
+    sampler: Optional[MCSampler] = None,
+    X_baseline: Optional[Tensor] = None,
+    prune_baseline: Optional[bool] = True,
+    cache_root: Optional[bool] = True,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+    eta: Union[Tensor, float] = 1e-3,
+    fat: bool = True,
+    tau_max: float = TAU_MAX,
+    tau_relu: float = TAU_RELU,
+):
+    r"""Construct kwargs for the `qLogNParEGO` constructor.
+
+    Args:
+        model: The model to be used in the acquisition function.
+        training_data: Dataset(s) used to train the model.
+        scalarization_weights: An `m`-dim Tensor of weights to be used in the
+            Chebyshev scalarization. If omitted, samples from the unit simplex.
+        objective: The MCMultiOutputObjective under which the samples are
+            evaluated before applying the Chebyshev scalarization.
+            Defaults to `IdentityMCMultiOutputObjective()`.
+        X_pending: An `m x d`-dim Tensor of `m` design points that have been
+            submitted for function evaluation but have not yet been evaluated.
+            Concatenated into `X` upon forward call.
+        sampler: The sampler used to draw base samples. If omitted, uses
+            the acquisition function's default sampler.
+        X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
+            that have already been observed. These points are considered as
+            the potential best design point. If omitted, checks that all
+            training_data have the same input features and takes the first `X`.
+        prune_baseline: If True, remove points in `X_baseline` that are
+            highly unlikely to be the best point. This can significantly
+            improve performance and is generally recommended.
+        cache_root: A boolean indicating whether to cache the root
+            decomposition over `X_baseline` and use low-rank updates.
+        constraints: A list of constraint callables which map a Tensor of posterior
+            samples of dimension `sample_shape x batch-shape x q x m` to a
+            `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+            are considered satisfied if the output is less than zero.
+        eta: Temperature parameter(s) governing the smoothness of the sigmoid
+            approximation to the constraint indicators. For more details on this
+            parameter, see the docs of `compute_smoothed_feasibility_indicator`.
+        fat: Toggles the use of the fat-tailed non-linearities to smoothly approximate
+            the constraint indicator function.
+        tau_max: Temperature parameter controlling the sharpness of the smooth
+            approximations to max.
+        tau_relu: Temperature parameter controlling the sharpness of the smooth
+            approximations to ReLU.
+
+    Returns:
+        A dict mapping kwarg names of the constructor to values.
+    """
+    base_inputs = construct_inputs_qLogNEI(
+        model=model,
+        training_data=training_data,
+        objective=objective,
+        X_pending=X_pending,
+        sampler=sampler,
+        X_baseline=X_baseline,
+        prune_baseline=prune_baseline,
+        cache_root=cache_root,
+        constraints=constraints,
+        eta=eta,
+        fat=fat,
+        tau_max=tau_max,
+        tau_relu=tau_relu,
+    )
+    # qLogNParEGO does not accept a posterior transform; drop it if present.
+    base_inputs.pop("posterior_transform", None)
+    return {
+        **base_inputs,
+        "scalarization_weights": scalarization_weights,
+    }
+
+
 @acqf_input_constructor(qMaxValueEntropy)
 def construct_inputs_qMES(
     model: Model,
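
This constructor is what makes qLogNParEGO usable from Ax MBM: the acquisition function can be built from a model and training data alone. A hedged sketch of the dispatch path, mirroring the input-constructor test added below (the feature and outcome names are illustrative):

import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.rand(8, 2, dtype=torch.double)
dataset = SupervisedDataset(
    X=train_X,
    Y=train_Y,
    feature_names=["x1", "x2"],
    outcome_names=["y1", "y2"],
)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
# Look up the registered constructor, build kwargs, and instantiate.
input_constructor = get_acqf_input_constructor(qLogNParEGO)
kwargs = input_constructor(model=model, training_data=dataset)
assert "posterior_transform" not in kwargs  # dropped by the constructor
acqf = qLogNParEGO(**kwargs)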
botorch/acquisition/multi_objective/parego.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, List, Optional, Union
+
+import torch
+from botorch.acquisition.logei import qLogNoisyExpectedImprovement, TAU_MAX, TAU_RELU
+from botorch.acquisition.multi_objective.monte_carlo import (
+    MultiObjectiveMCAcquisitionFunction,
+)
+from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
+from botorch.acquisition.objective import GenericMCObjective
+from botorch.models.model import Model
+from botorch.posteriors.fully_bayesian import MCMC_DIM
+from botorch.sampling.base import MCSampler
+from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
+from botorch.utils.sampling import sample_simplex
+from botorch.utils.transforms import is_ensemble
+from torch import Tensor
+
+
+class qLogNParEGO(qLogNoisyExpectedImprovement, MultiObjectiveMCAcquisitionFunction):
+    def __init__(
+        self,
+        model: Model,
+        X_baseline: Tensor,
+        scalarization_weights: Optional[Tensor] = None,
+        sampler: Optional[MCSampler] = None,
+        objective: Optional[MCMultiOutputObjective] = None,
+        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
+        X_pending: Optional[Tensor] = None,
+        eta: Union[Tensor, float] = 1e-3,
+        fat: bool = True,
+        prune_baseline: bool = False,
+        cache_root: bool = True,
+        tau_relu: float = TAU_RELU,
+        tau_max: float = TAU_MAX,
+    ) -> None:
+        r"""q-LogNParEGO supporting m >= 2 outcomes. This acquisition function
+        utilizes qLogNEI to compute the expected improvement over the Chebyshev
+        scalarization of the objectives.
+
+        This is adapted from qNParEGO proposed in [Daulton2020qehvi]_ to utilize
+        log-improvement acquisition functions of [Ament2023logei]_. See [Knowles2005]_
+        for the original ParEGO algorithm.
+
+        This implementation assumes maximization of all objectives. If any of the model
+        outputs are to be minimized, either an `objective` should be used to negate the
+        model outputs, or the `scalarization_weights` should be provided with negative
+        weights for the outputs to be minimized.
+
+        Args:
+            model: A fitted multi-output model, producing outputs for `m` objectives
+                and any number of outcome constraints.
+                NOTE: The model posterior must have a `mean` attribute.
+            X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
+                that have already been observed. These points are considered as
+                the potential best design point.
+            scalarization_weights: An `m`-dim Tensor of weights to be used in the
+                Chebyshev scalarization. If omitted, samples from the unit simplex.
+            sampler: The sampler used to draw base samples. See `MCAcquisitionFunction`
+                for more details.
+            objective: The MCMultiOutputObjective under which the samples are
+                evaluated before applying the Chebyshev scalarization.
+                Defaults to `IdentityMCMultiOutputObjective()`.
+            constraints: A list of constraint callables which map a Tensor of posterior
+                samples of dimension `sample_shape x batch-shape x q x m'` to a
+                `sample_shape x batch-shape x q`-dim Tensor. The associated constraints
+                are satisfied if `constraint(samples) < 0`.
+            X_pending: A `batch_shape x q' x d`-dim Tensor of `q'` design points
+                that have been submitted for function evaluation but have not yet
+                been evaluated. Concatenated into `X` upon forward call. Copied and
+                set to have no gradient.
+            eta: Temperature parameter(s) governing the smoothness of the sigmoid
+                approximation to the constraint indicators. See the docs of
+                `compute_(log_)smoothed_constraint_indicator` for details.
+            fat: Toggles the logarithmic / linear asymptotic behavior of the smooth
+                approximation to the ReLU.
+            prune_baseline: If True, remove points in `X_baseline` that are
+                highly unlikely to be the best point. This can significantly
+                improve performance and is generally recommended. In order to
+                customize pruning parameters, instead manually call
+                `botorch.acquisition.utils.prune_inferior_points` on `X_baseline`
+                before instantiating the acquisition function.
+            cache_root: A boolean indicating whether to cache the root
+                decomposition over `X_baseline` and use low-rank updates.
+            tau_max: Temperature parameter controlling the sharpness of the smooth
+                approximations to max.
+            tau_relu: Temperature parameter controlling the sharpness of the smooth
+                approximations to ReLU.
+        """
+        MultiObjectiveMCAcquisitionFunction.__init__(
+            self,
+            model=model,
+            sampler=sampler,
+            objective=objective,
+            constraints=constraints,
+            eta=eta,
+        )
+        org_objective = self.objective
+        # Create the composite objective.
+        with torch.no_grad():
+            Y_baseline = org_objective(model.posterior(X_baseline).mean)
+        if is_ensemble(model):
+            # Take the mean over the MCMC samples.
+            Y_baseline = torch.mean(Y_baseline, dim=MCMC_DIM)
+        scalarization_weights = (
+            scalarization_weights
+            if scalarization_weights is not None
+            else sample_simplex(d=Y_baseline.shape[-1]).view(-1)
+        )
+        chebyshev_scalarization = get_chebyshev_scalarization(
+            weights=scalarization_weights,
+            Y=Y_baseline,
+        )
+        composite_objective = GenericMCObjective(
+            objective=lambda samples, X=None: chebyshev_scalarization(
+                org_objective(samples=samples, X=X), X=X
+            ),
+        )
+        qLogNoisyExpectedImprovement.__init__(
+            self,
+            model=model,
+            X_baseline=X_baseline,
+            sampler=sampler,
+            objective=composite_objective,
+            X_pending=X_pending,
+            constraints=constraints,
+            eta=eta,
+            fat=fat,
+            prune_baseline=prune_baseline,
+            cache_root=cache_root,
+            tau_max=tau_max,
+            tau_relu=tau_relu,
+        )
+        # Set these after __init__ calls so that they're not overwritten / deleted.
+        # These are intended mainly for easier debugging & transparency.
+        self._org_objective: MCMultiOutputObjective = org_objective
+        self.chebyshev_scalarization: Callable[[Tensor, Optional[Tensor]], Tensor] = (
+            chebyshev_scalarization
+        )
+        self.scalarization_weights: Tensor = scalarization_weights
+        self.Y_baseline: Tensor = Y_baseline
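
The core of the class is the composite objective assembled in `__init__`: posterior samples are mapped through the user-supplied objective and then through a Chebyshev scalarization fit to `Y_baseline`. A standalone sketch of that composition, using the same utilities as the commit (shapes are illustrative):

import torch
from botorch.acquisition.objective import GenericMCObjective
from botorch.utils.multi_objective.scalarization import get_chebyshev_scalarization
from botorch.utils.sampling import sample_simplex

Y_baseline = torch.rand(10, 2)  # stand-in for objective(model.posterior(X_baseline).mean)
weights = sample_simplex(d=2).view(-1)  # non-negative, sums to 1
scalarize = get_chebyshev_scalarization(weights=weights, Y=Y_baseline)
# Composite objective: scalarize the (here, identity-transformed) samples.
objective = GenericMCObjective(lambda samples, X=None: scalarize(samples, X=X))
samples = torch.rand(32, 5, 2)  # sample_shape x q x m posterior samples
print(objective(samples).shape)  # torch.Size([32, 5])

The helper normalizes outcomes based on the range of `Y_baseline` before scalarizing, so the reduction over the outcome dimension yields a single-output objective that qLogNEI can consume unchanged.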
test/acquisition/multi_objective/test_parego.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+import torch
+from botorch.acquisition.multi_objective.objective import (
+    IdentityMCMultiOutputObjective,
+    WeightedMCMultiOutputObjective,
+)
+from botorch.acquisition.multi_objective.parego import qLogNParEGO
+from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
+from botorch.models.gp_regression import SingleTaskGP
+from botorch.models.model import Model
+from botorch.models.model_list_gp_regression import ModelListGP
+from botorch.utils.testing import BotorchTestCase
+
+
+class TestqLogNParEGO(BotorchTestCase):
+    def base_test_parego(
+        self,
+        with_constraints: bool = False,
+        with_scalarization_weights: bool = False,
+        with_objective: bool = False,
+        model: Optional[Model] = None,
+    ) -> None:
+        if with_constraints:
+            assert with_objective, "Objective must be specified if constraints are."
+        tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
+        num_objectives = 2
+        num_constraints = 1 if with_constraints else 0
+        num_outputs = num_objectives + num_constraints
+        model = model or SingleTaskGP(
+            train_X=torch.rand(5, 2, **tkwargs),
+            train_Y=torch.rand(5, num_outputs, **tkwargs),
+        )
+        scalarization_weights = (
+            torch.rand(num_objectives, **tkwargs)
+            if with_scalarization_weights
+            else None
+        )
+        objective = (
+            WeightedMCMultiOutputObjective(
+                weights=torch.tensor([2.0, -0.5], **tkwargs), outcomes=[0, 1]
+            )
+            if with_objective
+            else None
+        )
+        constraints = [lambda samples: samples[..., -1]] if with_constraints else None
+        acqf = qLogNParEGO(
+            model=model,
+            X_baseline=torch.rand(3, 2, **tkwargs),
+            scalarization_weights=scalarization_weights,
+            objective=objective,
+            constraints=constraints,
+            prune_baseline=True,
+        )
+        self.assertEqual(acqf.Y_baseline.shape, torch.Size([3, 2]))
+        # Scalarization weights should be set if given and sampled otherwise.
+        if scalarization_weights is not None:
+            self.assertIs(acqf.scalarization_weights, scalarization_weights)
+        else:
+            self.assertEqual(
+                acqf.scalarization_weights.shape, torch.Size([num_objectives])
+            )
+            # Should sum to 1 since they're sampled from the simplex.
+            self.assertAlmostEqual(acqf.scalarization_weights.sum().item(), 1.0)
+        # Original objective should default to identity.
+        if with_objective:
+            self.assertIs(acqf._org_objective, objective)
+        else:
+            self.assertIsInstance(acqf._org_objective, IdentityMCMultiOutputObjective)
+        # Acqf objective should be the Chebyshev scalarization compounded with
+        # the objective.
+        test_samples = torch.rand(32, 5, num_outputs, **tkwargs)
+        expected_objective = acqf.chebyshev_scalarization(
+            acqf._org_objective(test_samples)
+        )
+        self.assertEqual(expected_objective.shape, torch.Size([32, 5]))
+        self.assertAllClose(acqf.objective(test_samples), expected_objective)
+        # Evaluate the acquisition function.
+        self.assertEqual(acqf(torch.rand(5, 2, **tkwargs)).shape, torch.Size([1]))
+        self.assertEqual(acqf(torch.rand(32, 5, 2, **tkwargs)).shape, torch.Size([32]))
+
+    def test_parego_simple(self) -> None:
+        self.base_test_parego()
+
+    def test_parego_with_constraints_objective_weights(self) -> None:
+        self.base_test_parego(
+            with_constraints=True, with_objective=True, with_scalarization_weights=True
+        )
+
+    def test_parego_with_ensemble_model(self) -> None:
+        tkwargs: Dict[str, Any] = {"device": self.device, "dtype": torch.double}
+        models = []
+        for _ in range(2):
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=torch.rand(5, 2, **tkwargs),
+                train_Y=torch.randn(5, 1, **tkwargs),
+                train_Yvar=torch.rand(5, 1, **tkwargs) * 0.05,
+            )
+            mcmc_samples = {
+                "lengthscale": torch.rand(4, 1, 2, **tkwargs),
+                "outputscale": torch.rand(4, **tkwargs),
+                "mean": torch.randn(4, **tkwargs),
+            }
+            model.load_mcmc_samples(mcmc_samples)
+            models.append(model)
+        self.base_test_parego(model=ModelListGP(*models))
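
For context on the ensemble test above: with a fully Bayesian model, `Y_baseline` is averaged over the MCMC dimension before the scalarization is fit (the `is_ensemble` branch in `__init__`). A hedged sketch of the same path outside the test harness, fitting with NUTS (requires pyro) instead of loading canned MCMC samples:

import torch
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.models.model_list_gp_regression import ModelListGP

train_X = torch.rand(5, 2, dtype=torch.double)
models = []
for _ in range(2):  # one single-output SAAS GP per objective
    m = SaasFullyBayesianSingleTaskGP(
        train_X=train_X, train_Y=torch.randn(5, 1, dtype=torch.double)
    )
    fit_fully_bayesian_model_nuts(
        m, warmup_steps=32, num_samples=16, thinning=4, disable_progbar=True
    )
    models.append(m)
# Y_baseline is averaged over MCMC_DIM inside __init__ before fitting the
# Chebyshev scalarization.
acqf = qLogNParEGO(model=ModelListGP(*models), X_baseline=train_X)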

test/acquisition/test_input_constructors.py

Lines changed: 22 additions & 0 deletions
@@ -72,6 +72,7 @@
     IdentityMCMultiOutputObjective,
     WeightedMCMultiOutputObjective,
 )
+from botorch.acquisition.multi_objective.parego import qLogNParEGO
 from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
 from botorch.acquisition.objective import (
     ConstrainedMCObjective,
@@ -1161,6 +1162,26 @@ def _test_construct_inputs_qNEHVI(self, acqf_class: Type[AcquisitionFunction]):
         )
         self.assertEqual(kwargs["alpha"], 0.0)
 
+    def test_construct_inputs_qLogNParEGO(self) -> None:
+        # Focusing on the unique attributes since the rest are the same as qLogNEI.
+        c = get_acqf_input_constructor(qLogNParEGO)
+        kwargs = c(model=mock.Mock(), training_data=self.blockX_blockY)
+        self.assertTrue(torch.equal(kwargs["X_baseline"], self.blockX_blockY[0].X))
+        self.assertIsNone(kwargs["scalarization_weights"])
+        self.assertIsNone(kwargs["objective"])
+        self.assertNotIn("posterior_transform", kwargs)
+        # With custom objective & weights.
+        kwargs = c(
+            model=mock.Mock(),
+            training_data=self.blockX_blockY,
+            scalarization_weights=torch.zeros(2),
+            objective=IdentityMCMultiOutputObjective(outcomes=[0, 1]),
+        )
+        self.assertAllClose(kwargs["scalarization_weights"], torch.zeros(2))
+        self.assertIsInstance(kwargs["objective"], IdentityMCMultiOutputObjective)
+
+
 class TestKGandESAcquisitionFunctionInputConstructors(InputConstructorBaseTestCase):
     def test_construct_inputs_kg(self) -> None:
         current_value = torch.tensor(1.23)
         with mock.patch(
@@ -1386,6 +1407,7 @@ def test_constructors_like_qNEHVI(self) -> None:
             ExpectedHypervolumeImprovement,
             qExpectedHypervolumeImprovement,
             qLogExpectedHypervolumeImprovement,
+            qLogNParEGO,
         ]
         self._test_constructor_base(
             classes=classes,
