1313from botorch .models import SingleTaskGP
1414from botorch .models .fully_bayesian import SaasFullyBayesianSingleTaskGP
1515from botorch .models .transforms .outcome import Standardize
16+ from botorch .sampling .normal import IIDNormalSampler
1617from botorch .utils .testing import BotorchTestCase
1718
1819
20+ def get_model (
21+ train_X ,
22+ train_Y ,
23+ standardize_model ,
24+ ** tkwargs ,
25+ ):
26+ num_objectives = train_Y .shape [- 1 ]
27+
28+ if standardize_model :
29+ outcome_transform = Standardize (m = num_objectives )
30+ else :
31+ outcome_transform = None
32+
33+ model = SingleTaskGP (
34+ train_X = train_X ,
35+ train_Y = train_Y ,
36+ outcome_transform = outcome_transform ,
37+ )
38+
39+ return model
40+
41+
1942def _get_mcmc_samples (num_samples : int , dim : int , infer_noise : bool , ** tkwargs ):
2043
2144 mcmc_samples = {
@@ -28,7 +51,7 @@ def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
2851 return mcmc_samples
2952
3053
31- def get_model (
54+ def get_fully_bayesian_model (
3255 train_X ,
3356 train_Y ,
3457 num_models ,
@@ -72,21 +95,26 @@ def test_q_bayesian_active_learning_by_disagreement(self):
7295 tkwargs = {"device" : self .device }
7396 num_objectives = 1
7497 num_models = 3
98+ input_dim = 2
99+
100+ X_pending_list = [None , torch .rand (2 , input_dim )]
75101 for (
76102 dtype ,
77103 standardize_model ,
78104 infer_noise ,
105+ X_pending ,
79106 ) in product (
80107 (torch .float , torch .double ),
81108 (False , True ), # standardize_model
82109 (True ,), # infer_noise - only one option avail in PyroModels
110+ X_pending_list ,
83111 ):
112+ X_pending = X_pending .to (** tkwargs ) if X_pending is not None else None
84113 tkwargs ["dtype" ] = dtype
85- input_dim = 2
86114 train_X = torch .rand (4 , input_dim , ** tkwargs )
87115 train_Y = torch .rand (4 , num_objectives , ** tkwargs )
88116
89- model = get_model (
117+ model = get_fully_bayesian_model (
90118 train_X ,
91119 train_Y ,
92120 num_models ,
@@ -96,32 +124,40 @@ def test_q_bayesian_active_learning_by_disagreement(self):
96124 )
97125
98126 # test acquisition
99- X_pending_list = [None , torch .rand (2 , input_dim , ** tkwargs )]
100- for i in range (len (X_pending_list )):
101- X_pending = X_pending_list [i ]
102-
103- acq = qBayesianActiveLearningByDisagreement (
104- model = model ,
105- X_pending = X_pending ,
106- )
107-
108- test_Xs = [
109- torch .rand (4 , 1 , input_dim , ** tkwargs ),
110- torch .rand (4 , 3 , input_dim , ** tkwargs ),
111- torch .rand (4 , 5 , 1 , input_dim , ** tkwargs ),
112- torch .rand (4 , 5 , 3 , input_dim , ** tkwargs ),
113- ]
114-
115- for j in range (len (test_Xs )):
116- acq_X = acq .forward (test_Xs [j ])
117- acq_X = acq (test_Xs [j ])
118- # assess shape
119- self .assertTrue (acq_X .shape == test_Xs [j ].shape [:- 2 ])
127+ acq = qBayesianActiveLearningByDisagreement (
128+ model = model ,
129+ X_pending = X_pending ,
130+ )
131+
132+ acq2 = qBayesianActiveLearningByDisagreement (
133+ model = model , sampler = IIDNormalSampler (torch .Size ([9 ]))
134+ )
135+ self .assertIsInstance (acq2 .sampler , IIDNormalSampler )
136+
137+ test_Xs = [
138+ torch .rand (4 , 1 , input_dim , ** tkwargs ),
139+ torch .rand (4 , 3 , input_dim , ** tkwargs ),
140+ torch .rand (4 , 5 , 1 , input_dim , ** tkwargs ),
141+ torch .rand (4 , 5 , 3 , input_dim , ** tkwargs ),
142+ torch .rand (5 , 13 , input_dim , ** tkwargs ),
143+ ]
144+
145+ for j in range (len (test_Xs )):
146+ acq_X = acq .forward (test_Xs [j ])
147+ acq_X = acq (test_Xs [j ])
148+ # assess shape
149+ self .assertTrue (acq_X .shape == test_Xs [j ].shape [:- 2 ])
150+
151+ self .assertTrue (torch .all (acq_X > 0 ))
120152
121153 # Support with non-fully bayesian models is not possible. Thus, we
122154 # throw an error.
123- non_fully_bayesian_model = SingleTaskGP (train_X , train_Y )
124- with self .assertRaises (ValueError ):
155+ non_fully_bayesian_model = get_model (train_X , train_Y , False )
156+ with self .assertRaisesRegex (
157+ ValueError ,
158+ "Fully Bayesian acquisition functions require a "
159+ "SaasFullyBayesianSingleTaskGP to run." ,
160+ ):
125161 acq = qBayesianActiveLearningByDisagreement (
126162 model = non_fully_bayesian_model ,
127163 )
0 commit comments