Skip to content

Commit 967535f

Browse files
sdaulton authored and facebook-github-bot committed
set HVKG sampler to reflect the number of model outputs (#2160)
Summary: Pull Request resolved: #2160. This handles the case where the number of objectives is not the same as the number of model outputs (see #2159). Reviewed By: Balandat. Differential Revision: D52418477. fbshipit-source-id: a987d4be973eefad58a79ce08a6fcd370cc03bfe
1 parent b2af19c commit 967535f

File tree

2 files changed

+96
-58
lines changed

2 files changed

+96
-58
lines changed

botorch/acquisition/multi_objective/hypervolume_knowledge_gradient.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -131,12 +131,8 @@ def __init__(
131131
if sampler is None:
132132
# base samples should be fixed for joint optimization over X, X_fantasies
133133
samplers = [
134-
SobolQMCNormalSampler(
135-
sample_shape=torch.Size([num_fantasies]),
136-
resample=False,
137-
collapse_batch_dims=True,
138-
)
139-
for _ in range(ref_point.shape[0])
134+
SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
135+
for _ in range(model.num_outputs)
140136
]
141137
sampler = ListSampler(*samplers)
142138
else:
@@ -148,9 +144,7 @@ def __init__(
148144
super().__init__(model=model, X_evaluation_mask=X_evaluation_mask)
149145

150146
if inner_sampler is None:
151-
inner_sampler = SobolQMCNormalSampler(
152-
sample_shape=torch.Size([32]), resample=False, collapse_batch_dims=True
153-
)
147+
inner_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([32]))
154148
if current_value is None and cost_aware_utility is not None:
155149
raise UnsupportedError(
156150
"Cost-aware HVKG requires current_value to be specified."

test/acquisition/multi_objective/test_hypervolume_knowledge_gradient.py

Lines changed: 93 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
from itertools import product
88
from unittest import mock
99

10+
import numpy as np
11+
1012
import torch
1113
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
1214
from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
@@ -15,7 +17,10 @@
1517
qHypervolumeKnowledgeGradient,
1618
qMultiFidelityHypervolumeKnowledgeGradient,
1719
)
18-
from botorch.acquisition.multi_objective.objective import GenericMCMultiOutputObjective
20+
from botorch.acquisition.multi_objective.objective import (
21+
GenericMCMultiOutputObjective,
22+
IdentityMCMultiOutputObjective,
23+
)
1924
from botorch.exceptions.errors import UnsupportedError
2025
from botorch.models.deterministic import GenericDeterministicModel
2126
from botorch.models.gp_regression import SingleTaskGP
@@ -91,7 +96,7 @@ def test_initialization(self):
9196
self.assertEqual(acqf.inner_sampler.sample_shape, torch.Size([32]))
9297
self.assertIsNone(acqf._cost_sampler)
9398
# test objective
94-
mc_objective = GenericMCMultiOutputObjective(lambda Y: 2 * Y)
99+
mc_objective = GenericMCMultiOutputObjective(lambda Y, X: 2 * Y)
95100
acqf = acqf_class(
96101
model=model, ref_point=ref_point, objective=mc_objective, **mf_kwargs
97102
)
@@ -171,8 +176,8 @@ def test_evaluate_q_hvkg(self):
171176
tkwargs["dtype"] = dtype
172177
# basic test
173178
n_f = 4
174-
mean = torch.rand(n_f, num_pareto, 2, **tkwargs)
175-
variance = torch.rand(n_f, num_pareto, 2, **tkwargs)
179+
mean = torch.rand(n_f, 1, num_pareto, 2, **tkwargs)
180+
variance = torch.rand(n_f, 1, num_pareto, 2, **tkwargs)
176181
mfm = MockModel(MockPosterior(mean=mean, variance=variance))
177182
ref_point = torch.zeros(2, **tkwargs)
178183
models = [
@@ -204,11 +209,11 @@ def test_evaluate_q_hvkg(self):
204209
cargs, ckwargs = patch_f.call_args
205210
self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1]))
206211
expected_hv = (
207-
DominatedPartitioning(Y=mean, ref_point=ref_point)
212+
DominatedPartitioning(Y=mean.squeeze(1), ref_point=ref_point)
208213
.compute_hypervolume()
209214
.mean()
210215
)
211-
self.assertAllClose(val, expected_hv, atol=1e-4)
216+
self.assertAllClose(val.item(), expected_hv.item(), atol=1e-4)
212217
self.assertTrue(
213218
torch.equal(qHVKG.extract_candidates(X), X[..., : -n_f * num_pareto, :])
214219
)
@@ -253,8 +258,8 @@ def test_evaluate_q_hvkg(self):
253258
X_evaluation_mask = torch.tensor(
254259
[[False, True]], dtype=torch.bool, device=self.device
255260
)
256-
mean = torch.rand(n_f, num_pareto, 2, **tkwargs)
257-
variance = torch.rand(n_f, num_pareto, 2, **tkwargs)
261+
mean = torch.rand(n_f, 1, num_pareto, 2, **tkwargs)
262+
variance = torch.rand(n_f, 1, num_pareto, 2, **tkwargs)
258263
mfm = MockModel(MockPosterior(mean=mean, variance=variance))
259264
current_value = torch.tensor(0.0, **tkwargs)
260265
X = torch.rand(n_f * num_pareto + 1, 1, **tkwargs)
@@ -289,7 +294,7 @@ def test_evaluate_q_hvkg(self):
289294
torch.equal(ckwargs["evaluation_mask"], expected_eval_mask)
290295
)
291296
expected_hv = (
292-
DominatedPartitioning(Y=mean, ref_point=ref_point)
297+
DominatedPartitioning(Y=mean.squeeze(1), ref_point=ref_point)
293298
.compute_hypervolume()
294299
.mean(dim=0)
295300
)
@@ -320,46 +325,10 @@ def test_evaluate_q_hvkg(self):
320325
val = qHVKG(X)
321326
self.assertEqual(val.item(), 0.0)
322327

323-
# test objective (inner MC sampling)
324-
objective = GenericMCMultiOutputObjective(lambda Y, X: 2 * Y)
325-
samples = torch.randn(n_f, 1, num_pareto, 2, **tkwargs)
326-
mfm = MockModel(MockPosterior(samples=samples))
327-
X = torch.rand(n_f * num_pareto + 1, 1, **tkwargs)
328-
with mock.patch.object(
329-
ModelListGP, "fantasize", return_value=mfm
330-
) as patch_f:
331-
with mock.patch(NO, new_callable=mock.PropertyMock) as mock_num_outputs:
332-
mock_num_outputs.return_value = 2
333-
qHVKG = acqf_class(
334-
model=model,
335-
num_fantasies=n_f,
336-
objective=objective,
337-
ref_point=ref_point,
338-
num_pareto=num_pareto,
339-
use_posterior_mean=False,
340-
**mf_kwargs,
341-
)
342-
val = qHVKG(X)
343-
patch_f.assert_called_once()
344-
cargs, ckwargs = patch_f.call_args
345-
self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1]))
346-
expected_hv = (
347-
DominatedPartitioning(
348-
Y=objective(samples).view(-1, num_pareto, 2), ref_point=ref_point
349-
)
350-
.compute_hypervolume()
351-
.view(n_f, 1)
352-
.mean(dim=0)
353-
)
354-
self.assertAllClose(val, expected_hv, atol=1e-4)
355-
self.assertTrue(
356-
torch.equal(qHVKG.extract_candidates(X), X[..., : -n_f * num_pareto, :])
357-
)
358-
359328
# test mfkg
360329
if acqf_class == qMultiFidelityHypervolumeKnowledgeGradient:
361-
mean = torch.rand(n_f, num_pareto, 2, **tkwargs)
362-
variance = torch.rand(n_f, num_pareto, 2, **tkwargs)
330+
mean = torch.rand(n_f, 1, num_pareto, 2, **tkwargs)
331+
variance = torch.rand(n_f, 1, num_pareto, 2, **tkwargs)
363332
mfm = MockModel(MockPosterior(mean=mean, variance=variance))
364333
current_value = torch.rand(1, **tkwargs)
365334
X = torch.rand(n_f * num_pareto + 1, 1, **tkwargs)
@@ -388,6 +357,81 @@ def test_evaluate_q_hvkg(self):
388357
mock_get_value_func.call_args_list[0][1]["project"]
389358
)
390359

360+
# test objective (inner MC sampling)
361+
mean = torch.rand(n_f, 1, num_pareto, 3, **tkwargs)
362+
samples = mean + 1
363+
variance = torch.rand(n_f, 1, num_pareto, 3, **tkwargs)
364+
mfm = MockModel(
365+
MockPosterior(mean=mean, variance=variance, samples=samples)
366+
)
367+
models = [
368+
SingleTaskGP(torch.rand(2, 1, **tkwargs), torch.rand(2, 1, **tkwargs)),
369+
SingleTaskGP(torch.rand(4, 1, **tkwargs), torch.rand(4, 1, **tkwargs)),
370+
SingleTaskGP(torch.rand(5, 1, **tkwargs), torch.rand(5, 1, **tkwargs)),
371+
]
372+
model = ModelListGP(*models)
373+
for num_objectives in (2, 3):
374+
# test using 1) a botorch objective that only uses 2 out of
375+
# 3 outcomes as objectives, 2) a botorch objective that uses
376+
# all 3 outcomes as objectives
377+
objective = (
378+
IdentityMCMultiOutputObjective(outcomes=[0, 1])
379+
if num_objectives == 2
380+
else GenericMCMultiOutputObjective(lambda Y, X: 2 * Y)
381+
)
382+
383+
ref_point = torch.zeros(num_objectives, **tkwargs)
384+
X = torch.rand(n_f * num_pareto + 1, 1, **tkwargs)
385+
386+
for use_posterior_mean in (True, False):
387+
with mock.patch.object(
388+
ModelListGP, "fantasize", return_value=mfm
389+
) as patch_f:
390+
with mock.patch(
391+
NO, new_callable=mock.PropertyMock
392+
) as mock_num_outputs:
393+
mock_num_outputs.return_value = 3
394+
qHVKG = acqf_class(
395+
model=model,
396+
num_fantasies=n_f,
397+
objective=objective,
398+
ref_point=ref_point,
399+
num_pareto=num_pareto,
400+
use_posterior_mean=use_posterior_mean,
401+
**mf_kwargs,
402+
)
403+
val = qHVKG(X)
404+
patch_f.assert_called_once()
405+
cargs, ckwargs = patch_f.call_args
406+
self.assertEqual(ckwargs["X"].shape, torch.Size([1, 1, 1]))
407+
Ys = mean if use_posterior_mean else samples
408+
objs = objective(Ys.squeeze(1)).view(-1, num_pareto, num_objectives)
409+
if num_objectives == 2:
410+
expected_hv = (
411+
DominatedPartitioning(Y=objs, ref_point=ref_point)
412+
.compute_hypervolume()
413+
.mean()
414+
.item()
415+
)
416+
else:
417+
# batch box decomposition doesn't support > 2 objectives
418+
objs = objective(Ys).view(-1, num_pareto, num_objectives)
419+
expected_hv = np.mean(
420+
[
421+
DominatedPartitioning(Y=obj, ref_point=ref_point)
422+
.compute_hypervolume()
423+
.mean()
424+
.item()
425+
for obj in objs
426+
]
427+
)
428+
self.assertAllClose(val.item(), expected_hv, atol=1e-4)
429+
self.assertTrue(
430+
torch.equal(
431+
qHVKG.extract_candidates(X), X[..., : -n_f * num_pareto, :]
432+
)
433+
)
434+
391435
def test_split_hvkg_fantasy_points(self):
392436
d = 4
393437
for dtype, batch_shape, n_f, num_pareto, q in product(
@@ -410,8 +454,8 @@ def test_split_hvkg_fantasy_points(self):
410454
n_f = 100
411455
num_pareto = 3
412456
msg = (
413-
f"`n_f\*num_pareto` \({n_f*num_pareto}\) must be less than" # noqa: W605
414-
f" the `q`-batch dimension of `X` \({X.size(-2)}\)\." # noqa: W605
457+
rf".*\({n_f*num_pareto}\) must be less than"
458+
rf" the `q`-batch dimension of `X` \({X.size(-2)}\)\."
415459
)
416460
with self.assertRaisesRegex(ValueError, msg):
417461
_split_hvkg_fantasy_points(X=X, n_f=n_f, num_pareto=num_pareto)

0 commit comments

Comments
 (0)