
Commit 35ba80e

vizier-team authored and copybara-github committed
[OSS Vizier] Leverage sequential prior transfer learning in gp_bandit.py.
PiperOrigin-RevId: 553912121
1 parent 7465af6 commit 35ba80e

3 files changed: +232 −44 lines changed


vizier/_src/algorithms/designers/gp/gp_models.py

+2 −2
@@ -213,7 +213,7 @@ def _pred_mean(
   return pred.predict_with_aux(features)[0].mean()
 
 
-def _train_stacked_residual_gp(
+def train_stacked_residual_gp(
     base_gp: GPState,
     spec: GPTrainingSpec,
     data: types.ModelData,
@@ -322,7 +322,7 @@ def train_gp(
   else:
     # Otherwise, we have a base GP to use - the GP trained on the last
     # iteration.
-    curr_gp = _train_stacked_residual_gp(
+    curr_gp = train_stacked_residual_gp(
        base_gp=curr_gp,
        spec=curr_spec,
        data=curr_data,
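
Note on the rename above: `train_stacked_residual_gp` is now module-public so that gp_bandit.py can stack the current study's GP on top of a previously trained prior GP; the stacked model predicts the base GP's mean plus a residual GP fit on the difference between the observed labels and that mean. Below is a minimal caller-side sketch of how the two entry points compose. It assumes `prior_data` and `current_data` are `types.ModelData` built elsewhere and that `ard_optimizer` is supplied by the caller; the `make_spec` helper and the spec values are illustrative, not part of this commit.

    import jax
    from vizier._src.algorithms.designers.gp import gp_models

    rng = jax.random.PRNGKey(0)
    prior_rng, curr_rng = jax.random.split(rng)

    def make_spec(data, ard_rng):
      # Hypothetical helper mirroring the fields `_create_gp_spec` passes in
      # gp_bandit.py; `linear_coef` is omitted here, assuming its default.
      return gp_models.GPTrainingSpec(
          ard_optimizer=ard_optimizer,  # assumed: an optimizer prepared by the caller
          ard_rng=ard_rng,
          coroutine=gp_models.get_vizier_gp_coroutine(features=data.features),
          ensemble_size=1,  # illustrative value
          ard_random_restarts=4,  # illustrative value
      )

    # Train the prior (base) GP first.
    prior_gp = gp_models.train_gp(
        spec=make_spec(prior_data, prior_rng), data=prior_data
    )

    # Fit the current study's GP on the residuals of the prior GP.
    curr_gp = gp_models.train_stacked_residual_gp(
        base_gp=prior_gp,
        spec=make_spec(current_data, curr_rng),
        data=current_data,
    )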

vizier/_src/algorithms/designers/gp_bandit.py

+87 −17
@@ -129,6 +129,10 @@ class VizierGPBandit(vza.Designer, vza.Predictor):
 
   _last_computed_gp: gp_models.GPState = attr.field(init=False)
 
+  # The prior GP used in transfer learning. `_last_computed_gp` is trained
+  # on the residuals of `_prior_gp`, if one is trained.
+  _prior_gp: Optional[gp_models.GPState] = attr.field(init=False, default=None)
+
   default_acquisition_optimizer_factory = vb.VectorizedOptimizerFactory(
       strategy_factory=es.VectorizedEagleStrategyFactory()
   )
@@ -205,6 +209,37 @@ def update(
     del all_active
     self._trials.extend(copy.deepcopy(completed.trials))
 
+  def set_priors(self, prior_studies: Sequence[vza.CompletedTrials]) -> None:
+    """Updates the list of prior studies for transfer learning.
+
+    Each element is treated as a new prior study, and priors are stacked in
+    the order received - i.e. the first entry trains the first GP, the second
+    entry trains a GP on the residuals of the first GP, etc.
+
+    See section 3.3 of https://dl.acm.org/doi/10.1145/3097983.3098043 for more
+    information, or see `gp/gp_models.py` and `gp/transfer_learning.py`.
+
+    Transfer learning is resilient to bad priors.
+
+    Multiple calls are permitted, but not advised. Each call will trigger
+    retraining of the prior GPs - on only the state provided to `set_priors`.
+    State is not incrementally updated.
+
+    TODO: Decide on whether this method should become part of an
+    interface.
+
+    Args:
+      prior_studies: A list of completed-trial collections, one per prior
+        study. The designer will train a prior GP for each list of prior
+        trials (for each `CompletedTrials` entry), in the order received.
+    """
+    self._rng, ard_rng = jax.random.split(self._rng)
+    prior_data = [
+        self._trials_to_data(prior_study.trials)
+        for prior_study in prior_studies
+    ]
+    self._prior_gp = self._train_prior_gp(priors=prior_data, ard_rng=ard_rng)
+
   @property
   def _metric_info(self) -> vz.MetricInformation:
     return self._problem.metric_information.item()
@@ -286,23 +321,49 @@ def _trials_to_data(self, trials: Sequence[vz.Trial]) -> types.ModelData:
     return types.ModelData(model_data.features, labels)
 
   @_experimental_override_allowed
-  def _train_gp(
+  def _create_gp_spec(
       self, data: types.ModelData, ard_rng: jax.random.KeyArray
-  ) -> gp_models.GPState:
-    """Overrideable training of a pre-computed ensemble GP."""
-    trained_gp = gp_models.train_gp(
-        spec=gp_models.GPTrainingSpec(
-            ard_optimizer=self._ard_optimizer,
-            ard_rng=ard_rng,
-            coroutine=gp_models.get_vizier_gp_coroutine(
-                features=data.features, linear_coef=self._linear_coef
-            ),
-            ensemble_size=self._ensemble_size,
-            ard_random_restarts=self._ard_random_restarts,
+  ) -> gp_models.GPTrainingSpec:
+    """Overrideable creation of a training spec for a GP model."""
+    return gp_models.GPTrainingSpec(
+        ard_optimizer=self._ard_optimizer,
+        ard_rng=ard_rng,
+        coroutine=gp_models.get_vizier_gp_coroutine(
+            features=data.features, linear_coef=self._linear_coef
         ),
-        data=data,
+        ensemble_size=self._ensemble_size,
+        ard_random_restarts=self._ard_random_restarts,
     )
-    return trained_gp
+
+  @_experimental_override_allowed
+  def _train_prior_gp(
+      self,
+      priors: Sequence[types.ModelData],
+      ard_rng: jax.random.KeyArray,
+  ):
+    """Trains a transfer-learning-enabled GP with prior studies.
+
+    Args:
+      priors: Data for each sequential prior to train for transfer learning.
+        Assumed to be in training order, i.e. priors[0] trains the
+        first GP, priors[1] trains a GP on the residuals of the GP
+        trained on priors[0], and so on.
+      ard_rng: RNG used by ARD to optimize GP parameters.
+
+    Returns:
+      A trained pre-computed ensemble GP.
+    """
+    ard_rngs = jax.random.split(ard_rng, len(priors))
+
+    # Order `specs` in training order, i.e. `specs[0]` is trained first.
+    specs = [
+        self._create_gp_spec(prior_data, ard_rngs[i])
+        for i, prior_data in enumerate(priors)
+    ]
+
+    # `train_gp` expects `specs` and `data` in training order, which is how
+    # they were prepared above.
+    return gp_models.train_gp(spec=specs, data=priors)
 
   @profiler.record_runtime
   def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
@@ -312,7 +373,7 @@ def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
       data: Data to go into GP.
 
     Returns:
-      GPBanditState object containing the designer's state.
+      `GPState` object containing the trained GP.
 
     1. Convert trials to features and labels.
     2. Trains a pre-computed ensemble GP.
@@ -324,8 +385,16 @@ def _update_gp(self, data: types.ModelData) -> gp_models.GPState:
       # state. The assumption is that trials can't be removed.
       return self._last_computed_gp
     self._incorporated_trials_count = len(self._trials)
+
     self._rng, ard_rng = jax.random.split(self._rng, 2)
-    self._last_computed_gp = self._train_gp(data=data, ard_rng=ard_rng)
+    spec = self._create_gp_spec(data, ard_rng)
+    if self._prior_gp:
+      self._last_computed_gp = gp_models.train_stacked_residual_gp(
+          base_gp=self._prior_gp, spec=spec, data=data
+      )
+    else:
+      self._last_computed_gp = gp_models.train_gp(spec=spec, data=data)
+
     return self._last_computed_gp
 
   @_experimental_override_allowed
@@ -437,7 +506,8 @@ def sample(
     if not trials:
      return np.zeros((num_samples, 0))
 
-    gp = self._update_gp(self._trials_to_data(self._trials))
+    data = self._trials_to_data(self._trials)
+    gp = self._update_gp(data)
     xs = self._converter.to_features(trials)
     xs = types.ModelInput(
         continuous=xs.continuous.replace_fill_value(0.0),
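
With the changes above, transfer learning is enabled on the designer by one extra call before `update`: `set_priors` trains one prior GP per prior study, and `_update_gp` then stacks the current study's GP on the residuals of that prior. A minimal usage sketch, assuming `problem` is a `vz.ProblemStatement` and `prior_trials` / `current_trials` are lists of completed `vz.Trial`s prepared elsewhere (the variable names are hypothetical; the call pattern mirrors the tests below):

    from vizier import algorithms as vza
    from vizier._src.algorithms.designers import gp_bandit

    designer = gp_bandit.VizierGPBandit(problem)

    # One prior GP is trained per entry; priors are stacked in the order given.
    designer.set_priors([vza.CompletedTrials(prior_trials)])

    # The current study's GP is fit on the prior's residuals the next time the
    # designer needs the model (e.g. on suggest/predict).
    designer.update(vza.CompletedTrials(current_trials), vza.ActiveTrials())
    suggestions = designer.suggest(count=5)

Because each call to `set_priors` retrains the prior GPs from scratch on only the data it is given, it is intended to be called once, before the first `update`.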

vizier/_src/algorithms/designers/gp_bandit_test.py

+143 −25
@@ -16,6 +16,7 @@
 
 """Tests for gp_bandit."""
 
+from typing import Callable
 from unittest import mock
 
 import jax
@@ -47,6 +48,69 @@ def _build_mock_continuous_array_specs(n):
   return [continuous_spec] * n
 
 
+def _setup_lambda_search(
+    f: Callable[[float], float], num_trials: int = 100
+) -> tuple[gp_bandit.VizierGPBandit, list[vz.Trial], vz.ProblemStatement]:
+  """Sets up a GP designer and outputs completed studies for `f`.
+
+  Args:
+    f: 1D objective to be optimized, i.e. f(x), where x is a scalar in [-5., 5.).
+    num_trials: Number of mock "evaluated" trials to return.
+
+  Returns:
+    A GP designer set up for the problem of optimizing the objective, without any
+    data updated.
+    Evaluated trials against `f`, and the problem statement.
+  """
+  assert (
+      num_trials > 0
+  ), f'Must provide a positive number of trials. Got {num_trials}.'
+
+  search_space = vz.SearchSpace()
+  search_space.root.add_float_param('x0', -5.0, 5.0)
+  problem = vz.ProblemStatement(
+      search_space=search_space,
+      metric_information=vz.MetricsConfig(
+          metrics=[
+              vz.MetricInformation('obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE),
+          ]
+      ),
+  )
+
+  suggestions = quasi_random.QuasiRandomDesigner(
+      problem.search_space, seed=1
+  ).suggest(num_trials)
+
+  obs_trials = []
+  for idx, suggestion in enumerate(suggestions):
+    trial = suggestion.to_trial(idx)
+    x = suggestions[idx].parameters['x0'].value
+    trial.complete(vz.Measurement(metrics={'obj': f(x)}))
+    obs_trials.append(trial)
+
+  gp_designer = gp_bandit.VizierGPBandit(problem, ard_optimizer=ard_optimizer)
+  return gp_designer, obs_trials, problem
+
+
+def _compute_mse(
+    designer: gp_bandit.VizierGPBandit,
+    test_trials: list[vz.Trial],
+    y_test: list[float],
+) -> float:
+  """Evaluates the designer's accuracy on the test set.
+
+  Args:
+    designer: The GP bandit designer to predict from.
+    test_trials: The trials of the test set.
+    y_test: The results of the test set.
+
+  Returns:
+    The MSE of `designer` on `test_trials` and `y_test`.
+  """
+  preds = designer.predict(test_trials)
+  return np.sum(np.square(preds.mean - y_test))
+
+
 class GoogleGpBanditTest(parameterized.TestCase):
 
   @parameterized.parameters(
@@ -216,32 +280,8 @@ def test_on_flat_mixed_space(
     self.assertFalse(np.isnan(prediction.stddev).any())
 
   def test_prediction_accuracy(self):
-    search_space = vz.SearchSpace()
-    search_space.root.add_float_param('x0', -5.0, 5.0)
-    problem = vz.ProblemStatement(
-        search_space=search_space,
-        metric_information=vz.MetricsConfig(
-            metrics=[
-                vz.MetricInformation(
-                    'obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE
-                ),
-            ]
-        ),
-    )
     f = lambda x: -((x - 0.5) ** 2)
-
-    suggestions = quasi_random.QuasiRandomDesigner(
-        problem.search_space, seed=1
-    ).suggest(100)
-
-    obs_trials = []
-    for idx, suggestion in enumerate(suggestions):
-      trial = suggestion.to_trial(idx)
-      x = suggestions[idx].parameters['x0'].value
-      trial.complete(vz.Measurement(metrics={'obj': f(x)}))
-      obs_trials.append(trial)
-
-    gp_designer = gp_bandit.VizierGPBandit(problem, ard_optimizer=ard_optimizer)
+    gp_designer, obs_trials, _ = _setup_lambda_search(f)
     gp_designer.update(vza.CompletedTrials(obs_trials), vza.ActiveTrials())
     pred_trial = vz.Trial({'x0': 0.0})
     pred = gp_designer.predict([pred_trial])
@@ -261,6 +301,7 @@ def test_jit_once(self, *args):
             name='metric', goal=vz.ObjectiveMetricGoal.MAXIMIZE
         )
     )
+
    def create_designer(problem):
      return gp_bandit.VizierGPBandit(
          problem=problem,
@@ -299,6 +340,83 @@ def create_runner(problem):
     create_runner(problem).run_designer(designer2)
 
 
+class GPBanditPriorsTest(parameterized.TestCase):
+
+  def test_prior_warping(self):
+    """Tests that a linear transform of the objective does not hurt transfer learning."""
+    f = lambda x: -((x - 0.5) ** 2)
+    transform_f = lambda x: -3 * ((x - 0.5) ** 2) + 10
+
+    # x is in the range defined in `_setup_lambda_search`: [-5.0, 5.0).
+    x_test = np.random.default_rng(1).uniform(-5.0, 5.0, 100)
+    y_test = [transform_f(x) for x in x_test]
+    test_trials = [vz.Trial({'x0': x}) for x in x_test]
+
+    # Create the designer with a prior and the trials to train the prior.
+    gp_designer_with_prior, obs_trials_for_prior, _ = _setup_lambda_search(
+        f=f, num_trials=100
+    )
+
+    # Set priors to the above trials.
+    gp_designer_with_prior.set_priors(
+        [vza.CompletedTrials(obs_trials_for_prior)]
+    )
+
+    # Create a no-prior designer on the transformed function `transform_f`.
+    # Also use the generated trials to update both the designer with a prior
+    # and the designer without. This tests that the prior designer is
+    # resilient to linear transforms between the prior and the top-level study.
+    gp_designer_no_prior, obs_trials, _ = _setup_lambda_search(
+        f=transform_f, num_trials=20
+    )
+
+    # Update both designers with the actual study.
+    gp_designer_no_prior.update(
+        vza.CompletedTrials(obs_trials), vza.ActiveTrials()
+    )
+    gp_designer_with_prior.update(
+        vza.CompletedTrials(obs_trials), vza.ActiveTrials()
+    )
+
+    # Evaluate the no-prior designer's accuracy on the test set.
+    mse_no_prior = _compute_mse(gp_designer_no_prior, test_trials, y_test)
+
+    # Evaluate the with-prior designer's accuracy on the test set.
+    mse_with_prior = _compute_mse(gp_designer_with_prior, test_trials, y_test)
+
+    # The designer with a prior should predict better.
+    self.assertLess(mse_with_prior, mse_no_prior)
+
+  @parameterized.parameters(
+      dict(iters=3, batch_size=5),
+      dict(iters=5, batch_size=1),
+  )
+  def test_run_with_priors(self, *, iters, batch_size):
+    f = lambda x: -((x - 0.5) ** 2)
+
+    # Create the designer with a prior and the trials to train the prior.
+    gp_designer_with_prior, obs_trials_for_prior, problem = (
+        _setup_lambda_search(f=f, num_trials=100)
+    )
+
+    # Set priors to the above trials.
+    gp_designer_with_prior.set_priors(
+        [vza.CompletedTrials(obs_trials_for_prior)]
+    )
+
+    self.assertLen(
+        test_runners.RandomMetricsRunner(
+            problem,
+            iters=iters,
+            batch_size=batch_size,
+            verbose=1,
+            validate_parameters=True,
+            seed=1,
+        ).run_designer(gp_designer_with_prior),
+        iters * batch_size,
+    )
+
+
 if __name__ == '__main__':
   # Jax disables float64 computations by default and will silently convert
   # float64s to float32s. We must explicitly enable float64.
