7
7
from itertools import product
8
8
from unittest import mock
9
9
10
+ import numpy as np
11
+
10
12
import torch
11
13
from botorch .acquisition .cost_aware import InverseCostWeightedUtility
12
14
from botorch .acquisition .multi_objective .hypervolume_knowledge_gradient import (
15
17
qHypervolumeKnowledgeGradient ,
16
18
qMultiFidelityHypervolumeKnowledgeGradient ,
17
19
)
18
- from botorch .acquisition .multi_objective .objective import GenericMCMultiOutputObjective
20
+ from botorch .acquisition .multi_objective .objective import (
21
+ GenericMCMultiOutputObjective ,
22
+ IdentityMCMultiOutputObjective ,
23
+ )
19
24
from botorch .exceptions .errors import UnsupportedError
20
25
from botorch .models .deterministic import GenericDeterministicModel
21
26
from botorch .models .gp_regression import SingleTaskGP
@@ -91,7 +96,7 @@ def test_initialization(self):
91
96
self .assertEqual (acqf .inner_sampler .sample_shape , torch .Size ([32 ]))
92
97
self .assertIsNone (acqf ._cost_sampler )
93
98
# test objective
94
- mc_objective = GenericMCMultiOutputObjective (lambda Y : 2 * Y )
99
+ mc_objective = GenericMCMultiOutputObjective (lambda Y , X : 2 * Y )
95
100
acqf = acqf_class (
96
101
model = model , ref_point = ref_point , objective = mc_objective , ** mf_kwargs
97
102
)
@@ -171,8 +176,8 @@ def test_evaluate_q_hvkg(self):
171
176
tkwargs ["dtype" ] = dtype
172
177
# basic test
173
178
n_f = 4
174
- mean = torch .rand (n_f , num_pareto , 2 , ** tkwargs )
175
- variance = torch .rand (n_f , num_pareto , 2 , ** tkwargs )
179
+ mean = torch .rand (n_f , 1 , num_pareto , 2 , ** tkwargs )
180
+ variance = torch .rand (n_f , 1 , num_pareto , 2 , ** tkwargs )
176
181
mfm = MockModel (MockPosterior (mean = mean , variance = variance ))
177
182
ref_point = torch .zeros (2 , ** tkwargs )
178
183
models = [
@@ -204,11 +209,11 @@ def test_evaluate_q_hvkg(self):
204
209
cargs , ckwargs = patch_f .call_args
205
210
self .assertEqual (ckwargs ["X" ].shape , torch .Size ([1 , 1 , 1 ]))
206
211
expected_hv = (
207
- DominatedPartitioning (Y = mean , ref_point = ref_point )
212
+ DominatedPartitioning (Y = mean . squeeze ( 1 ) , ref_point = ref_point )
208
213
.compute_hypervolume ()
209
214
.mean ()
210
215
)
211
- self .assertAllClose (val , expected_hv , atol = 1e-4 )
216
+ self .assertAllClose (val . item () , expected_hv . item () , atol = 1e-4 )
212
217
self .assertTrue (
213
218
torch .equal (qHVKG .extract_candidates (X ), X [..., : - n_f * num_pareto , :])
214
219
)
@@ -253,8 +258,8 @@ def test_evaluate_q_hvkg(self):
253
258
X_evaluation_mask = torch .tensor (
254
259
[[False , True ]], dtype = torch .bool , device = self .device
255
260
)
256
- mean = torch .rand (n_f , num_pareto , 2 , ** tkwargs )
257
- variance = torch .rand (n_f , num_pareto , 2 , ** tkwargs )
261
+ mean = torch .rand (n_f , 1 , num_pareto , 2 , ** tkwargs )
262
+ variance = torch .rand (n_f , 1 , num_pareto , 2 , ** tkwargs )
258
263
mfm = MockModel (MockPosterior (mean = mean , variance = variance ))
259
264
current_value = torch .tensor (0.0 , ** tkwargs )
260
265
X = torch .rand (n_f * num_pareto + 1 , 1 , ** tkwargs )
@@ -289,7 +294,7 @@ def test_evaluate_q_hvkg(self):
289
294
torch .equal (ckwargs ["evaluation_mask" ], expected_eval_mask )
290
295
)
291
296
expected_hv = (
292
- DominatedPartitioning (Y = mean , ref_point = ref_point )
297
+ DominatedPartitioning (Y = mean . squeeze ( 1 ) , ref_point = ref_point )
293
298
.compute_hypervolume ()
294
299
.mean (dim = 0 )
295
300
)
@@ -320,46 +325,10 @@ def test_evaluate_q_hvkg(self):
320
325
val = qHVKG (X )
321
326
self .assertEqual (val .item (), 0.0 )
322
327
323
- # test objective (inner MC sampling)
324
- objective = GenericMCMultiOutputObjective (lambda Y , X : 2 * Y )
325
- samples = torch .randn (n_f , 1 , num_pareto , 2 , ** tkwargs )
326
- mfm = MockModel (MockPosterior (samples = samples ))
327
- X = torch .rand (n_f * num_pareto + 1 , 1 , ** tkwargs )
328
- with mock .patch .object (
329
- ModelListGP , "fantasize" , return_value = mfm
330
- ) as patch_f :
331
- with mock .patch (NO , new_callable = mock .PropertyMock ) as mock_num_outputs :
332
- mock_num_outputs .return_value = 2
333
- qHVKG = acqf_class (
334
- model = model ,
335
- num_fantasies = n_f ,
336
- objective = objective ,
337
- ref_point = ref_point ,
338
- num_pareto = num_pareto ,
339
- use_posterior_mean = False ,
340
- ** mf_kwargs ,
341
- )
342
- val = qHVKG (X )
343
- patch_f .assert_called_once ()
344
- cargs , ckwargs = patch_f .call_args
345
- self .assertEqual (ckwargs ["X" ].shape , torch .Size ([1 , 1 , 1 ]))
346
- expected_hv = (
347
- DominatedPartitioning (
348
- Y = objective (samples ).view (- 1 , num_pareto , 2 ), ref_point = ref_point
349
- )
350
- .compute_hypervolume ()
351
- .view (n_f , 1 )
352
- .mean (dim = 0 )
353
- )
354
- self .assertAllClose (val , expected_hv , atol = 1e-4 )
355
- self .assertTrue (
356
- torch .equal (qHVKG .extract_candidates (X ), X [..., : - n_f * num_pareto , :])
357
- )
358
-
359
328
# test mfkg
360
329
if acqf_class == qMultiFidelityHypervolumeKnowledgeGradient :
361
- mean = torch .rand (n_f , num_pareto , 2 , ** tkwargs )
362
- variance = torch .rand (n_f , num_pareto , 2 , ** tkwargs )
330
+ mean = torch .rand (n_f , 1 , num_pareto , 2 , ** tkwargs )
331
+ variance = torch .rand (n_f , 1 , num_pareto , 2 , ** tkwargs )
363
332
mfm = MockModel (MockPosterior (mean = mean , variance = variance ))
364
333
current_value = torch .rand (1 , ** tkwargs )
365
334
X = torch .rand (n_f * num_pareto + 1 , 1 , ** tkwargs )
@@ -388,6 +357,81 @@ def test_evaluate_q_hvkg(self):
388
357
mock_get_value_func .call_args_list [0 ][1 ]["project" ]
389
358
)
390
359
360
+ # test objective (inner MC sampling)
361
+ mean = torch .rand (n_f , 1 , num_pareto , 3 , ** tkwargs )
362
+ samples = mean + 1
363
+ variance = torch .rand (n_f , 1 , num_pareto , 3 , ** tkwargs )
364
+ mfm = MockModel (
365
+ MockPosterior (mean = mean , variance = variance , samples = samples )
366
+ )
367
+ models = [
368
+ SingleTaskGP (torch .rand (2 , 1 , ** tkwargs ), torch .rand (2 , 1 , ** tkwargs )),
369
+ SingleTaskGP (torch .rand (4 , 1 , ** tkwargs ), torch .rand (4 , 1 , ** tkwargs )),
370
+ SingleTaskGP (torch .rand (5 , 1 , ** tkwargs ), torch .rand (5 , 1 , ** tkwargs )),
371
+ ]
372
+ model = ModelListGP (* models )
373
+ for num_objectives in (2 , 3 ):
374
+ # test using 1) a botorch objective that only uses 2 out of
375
+ # 3 outcomes as objectives, 2) a botorch objective that uses
376
+ # all 3 outcomes as objectives
377
+ objective = (
378
+ IdentityMCMultiOutputObjective (outcomes = [0 , 1 ])
379
+ if num_objectives == 2
380
+ else GenericMCMultiOutputObjective (lambda Y , X : 2 * Y )
381
+ )
382
+
383
+ ref_point = torch .zeros (num_objectives , ** tkwargs )
384
+ X = torch .rand (n_f * num_pareto + 1 , 1 , ** tkwargs )
385
+
386
+ for use_posterior_mean in (True , False ):
387
+ with mock .patch .object (
388
+ ModelListGP , "fantasize" , return_value = mfm
389
+ ) as patch_f :
390
+ with mock .patch (
391
+ NO , new_callable = mock .PropertyMock
392
+ ) as mock_num_outputs :
393
+ mock_num_outputs .return_value = 3
394
+ qHVKG = acqf_class (
395
+ model = model ,
396
+ num_fantasies = n_f ,
397
+ objective = objective ,
398
+ ref_point = ref_point ,
399
+ num_pareto = num_pareto ,
400
+ use_posterior_mean = use_posterior_mean ,
401
+ ** mf_kwargs ,
402
+ )
403
+ val = qHVKG (X )
404
+ patch_f .assert_called_once ()
405
+ cargs , ckwargs = patch_f .call_args
406
+ self .assertEqual (ckwargs ["X" ].shape , torch .Size ([1 , 1 , 1 ]))
407
+ Ys = mean if use_posterior_mean else samples
408
+ objs = objective (Ys .squeeze (1 )).view (- 1 , num_pareto , num_objectives )
409
+ if num_objectives == 2 :
410
+ expected_hv = (
411
+ DominatedPartitioning (Y = objs , ref_point = ref_point )
412
+ .compute_hypervolume ()
413
+ .mean ()
414
+ .item ()
415
+ )
416
+ else :
417
+ # batch box decomposition don't support > 2 objectives
418
+ objs = objective (Ys ).view (- 1 , num_pareto , num_objectives )
419
+ expected_hv = np .mean (
420
+ [
421
+ DominatedPartitioning (Y = obj , ref_point = ref_point )
422
+ .compute_hypervolume ()
423
+ .mean ()
424
+ .item ()
425
+ for obj in objs
426
+ ]
427
+ )
428
+ self .assertAllClose (val .item (), expected_hv , atol = 1e-4 )
429
+ self .assertTrue (
430
+ torch .equal (
431
+ qHVKG .extract_candidates (X ), X [..., : - n_f * num_pareto , :]
432
+ )
433
+ )
434
+
391
435
def test_split_hvkg_fantasy_points (self ):
392
436
d = 4
393
437
for dtype , batch_shape , n_f , num_pareto , q in product (
@@ -410,8 +454,8 @@ def test_split_hvkg_fantasy_points(self):
410
454
n_f = 100
411
455
num_pareto = 3
412
456
msg = (
413
- f"`n_f\*num_pareto` \({ n_f * num_pareto } \) must be less than" # noqa: W605
414
- f " the `q`-batch dimension of `X` \({ X .size (- 2 )} \)\." # noqa: W605
457
+ rf".* \({ n_f * num_pareto } \) must be less than"
458
+ rf " the `q`-batch dimension of `X` \({ X .size (- 2 )} \)\."
415
459
)
416
460
with self .assertRaisesRegex (ValueError , msg ):
417
461
_split_hvkg_fantasy_points (X = X , n_f = n_f , num_pareto = num_pareto )
0 commit comments