16
16
17
17
"""Tests for gp_bandit."""
18
18
19
+ from typing import Callable
19
20
from unittest import mock
20
21
21
22
import jax
@@ -47,6 +48,69 @@ def _build_mock_continuous_array_specs(n):
47
48
return [continuous_spec ] * n
48
49
49
50
51
+ def _setup_lambda_search (
52
+ f : Callable [[float ], float ], num_trials : int = 100
53
+ ) -> tuple [gp_bandit .VizierGPBandit , list [vz .Trial ], vz .ProblemStatement ]:
54
+ """Sets up a GP designer and outputs completed studies for `f`.
55
+
56
+ Args:
57
+ f: 1D objective to be optimized, i.e. f(x), where x is a scalar in [-5., 5.)
58
+ num_trials: Number of mock "evaluated" trials to return.
59
+
60
+ Returns:
61
+ A GP designer set up for the problem of optimizing the objective, without any
62
+ data updated.
63
+ Evaluated trials against `f`.
64
+ """
65
+ assert (
66
+ num_trials > 0
67
+ ), f'Must provide a positive number of trials. Got { num_trials } .'
68
+
69
+ search_space = vz .SearchSpace ()
70
+ search_space .root .add_float_param ('x0' , - 5.0 , 5.0 )
71
+ problem = vz .ProblemStatement (
72
+ search_space = search_space ,
73
+ metric_information = vz .MetricsConfig (
74
+ metrics = [
75
+ vz .MetricInformation ('obj' , goal = vz .ObjectiveMetricGoal .MAXIMIZE ),
76
+ ]
77
+ ),
78
+ )
79
+
80
+ suggestions = quasi_random .QuasiRandomDesigner (
81
+ problem .search_space , seed = 1
82
+ ).suggest (num_trials )
83
+
84
+ obs_trials = []
85
+ for idx , suggestion in enumerate (suggestions ):
86
+ trial = suggestion .to_trial (idx )
87
+ x = suggestions [idx ].parameters ['x0' ].value
88
+ trial .complete (vz .Measurement (metrics = {'obj' : f (x )}))
89
+ obs_trials .append (trial )
90
+
91
+ gp_designer = gp_bandit .VizierGPBandit (problem , ard_optimizer = ard_optimizer )
92
+ return gp_designer , obs_trials , problem
93
+
94
+
95
+ def _compute_mse (
96
+ designer : gp_bandit .VizierGPBandit ,
97
+ test_trials : list [vz .Trial ],
98
+ y_test : list [float ],
99
+ ) -> float :
100
+ """Evaluate the designer's accuracy on the test set.
101
+
102
+ Args:
103
+ designer: The GP bandit designer to predict from.
104
+ test_trials: The trials of the test set
105
+ y_test: The results of the test set
106
+
107
+ Returns:
108
+ The MSE of `designer` on `test_trials` and `y_test`
109
+ """
110
+ preds = designer .predict (test_trials )
111
+ return np .sum (np .square (preds .mean - y_test ))
112
+
113
+
50
114
class GoogleGpBanditTest (parameterized .TestCase ):
51
115
52
116
@parameterized .parameters (
@@ -216,32 +280,8 @@ def test_on_flat_mixed_space(
216
280
self .assertFalse (np .isnan (prediction .stddev ).any ())
217
281
218
282
def test_prediction_accuracy (self ):
219
- search_space = vz .SearchSpace ()
220
- search_space .root .add_float_param ('x0' , - 5.0 , 5.0 )
221
- problem = vz .ProblemStatement (
222
- search_space = search_space ,
223
- metric_information = vz .MetricsConfig (
224
- metrics = [
225
- vz .MetricInformation (
226
- 'obj' , goal = vz .ObjectiveMetricGoal .MAXIMIZE
227
- ),
228
- ]
229
- ),
230
- )
231
283
f = lambda x : - ((x - 0.5 ) ** 2 )
232
-
233
- suggestions = quasi_random .QuasiRandomDesigner (
234
- problem .search_space , seed = 1
235
- ).suggest (100 )
236
-
237
- obs_trials = []
238
- for idx , suggestion in enumerate (suggestions ):
239
- trial = suggestion .to_trial (idx )
240
- x = suggestions [idx ].parameters ['x0' ].value
241
- trial .complete (vz .Measurement (metrics = {'obj' : f (x )}))
242
- obs_trials .append (trial )
243
-
244
- gp_designer = gp_bandit .VizierGPBandit (problem , ard_optimizer = ard_optimizer )
284
+ gp_designer , obs_trials , _ = _setup_lambda_search (f )
245
285
gp_designer .update (vza .CompletedTrials (obs_trials ), vza .ActiveTrials ())
246
286
pred_trial = vz .Trial ({'x0' : 0.0 })
247
287
pred = gp_designer .predict ([pred_trial ])
@@ -261,6 +301,7 @@ def test_jit_once(self, *args):
261
301
name = 'metric' , goal = vz .ObjectiveMetricGoal .MAXIMIZE
262
302
)
263
303
)
304
+
264
305
def create_designer (problem ):
265
306
return gp_bandit .VizierGPBandit (
266
307
problem = problem ,
@@ -299,6 +340,83 @@ def create_runner(problem):
299
340
create_runner (problem ).run_designer (designer2 )
300
341
301
342
343
+ class GPBanditPriorsTest (parameterized .TestCase ):
344
+
345
+ def test_prior_warping (self ):
346
+ """Tests linear transform of objective has no impact on transfer learning."""
347
+ f = lambda x : - ((x - 0.5 ) ** 2 )
348
+ transform_f = lambda x : - 3 * ((x - 0.5 ) ** 2 ) + 10
349
+
350
+ # X is in range of what is defined in `_setup_lambda_search`, [-5.0, 5.0)
351
+ x_test = np .random .default_rng (1 ).uniform (- 5.0 , 5.0 , 100 )
352
+ y_test = [transform_f (x ) for x in x_test ]
353
+ test_trials = [vz .Trial ({'x0' : x }) for x in x_test ]
354
+
355
+ # Create the designer with a prior and the trials to train the prior.
356
+ gp_designer_with_prior , obs_trials_for_prior , _ = _setup_lambda_search (
357
+ f = f , num_trials = 100
358
+ )
359
+
360
+ # Set priors to above trials.
361
+ gp_designer_with_prior .set_priors (
362
+ [vza .CompletedTrials (obs_trials_for_prior )]
363
+ )
364
+
365
+ # Create a no prior designer on the transformed function `transform_f`.
366
+ # Also use the generated trials to update both the designer with prior and
367
+ # the designer without. This tests that the prior designer is resilient
368
+ # to linear transforms between the prior and the top level study.
369
+ gp_designer_no_prior , obs_trials , _ = _setup_lambda_search (
370
+ f = transform_f , num_trials = 20
371
+ )
372
+
373
+ # Update both designers with the actual study.
374
+ gp_designer_no_prior .update (
375
+ vza .CompletedTrials (obs_trials ), vza .ActiveTrials ()
376
+ )
377
+ gp_designer_with_prior .update (
378
+ vza .CompletedTrials (obs_trials ), vza .ActiveTrials ()
379
+ )
380
+
381
+ # Evaluate the no prior designer's accuracy on the test set.
382
+ mse_no_prior = _compute_mse (gp_designer_no_prior , test_trials , y_test )
383
+
384
+ # Evaluate the designer with prior's accuracy on the test set.
385
+ mse_with_prior = _compute_mse (gp_designer_with_prior , test_trials , y_test )
386
+
387
+ # The designer with a prior should predict better.
388
+ self .assertLess (mse_with_prior , mse_no_prior )
389
+
390
+ @parameterized .parameters (
391
+ dict (iters = 3 , batch_size = 5 ),
392
+ dict (iters = 5 , batch_size = 1 ),
393
+ )
394
+ def test_run_with_priors (self , * , iters , batch_size ):
395
+ f = lambda x : - ((x - 0.5 ) ** 2 )
396
+
397
+ # Create the designer with a prior and the trials to train the prior.
398
+ gp_designer_with_prior , obs_trials_for_prior , problem = (
399
+ _setup_lambda_search (f = f , num_trials = 100 )
400
+ )
401
+
402
+ # Set priors to the above trials.
403
+ gp_designer_with_prior .set_priors (
404
+ [vza .CompletedTrials (obs_trials_for_prior )]
405
+ )
406
+
407
+ self .assertLen (
408
+ test_runners .RandomMetricsRunner (
409
+ problem ,
410
+ iters = iters ,
411
+ batch_size = batch_size ,
412
+ verbose = 1 ,
413
+ validate_parameters = True ,
414
+ seed = 1 ,
415
+ ).run_designer (gp_designer_with_prior ),
416
+ iters * batch_size ,
417
+ )
418
+
419
+
302
420
if __name__ == '__main__' :
303
421
# Jax disables float64 computations by default and will silently convert
304
422
# float64s to float32s. We must explicitly enable float64.
0 commit comments