Skip to content

Commit fdaea96

Browse files
Carl Hvarfner authored and facebook-github-bot committed
ThompsonSampling acquisition function (#2443)
Summary: Pull Request resolved: #2443. Adds Thompson sampling (approximated with RFF and pathwise sampling) as an acquisition function, so that it fits into general BO loops (and MBM, although that is secondary at the moment). Amend: Removed the fully Bayesian variant, since it did not make sense in its current format. Differential Revision: D59961584
1 parent 3a5ec0f commit fdaea96

File tree

3 files changed

+228
-0
lines changed

3 files changed

+228
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Optional
7+
8+
import torch
9+
from botorch.acquisition.analytic import AcquisitionFunction
10+
from botorch.acquisition.objective import PosteriorTransform
11+
from botorch.models.model import Model
12+
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
13+
from botorch.utils.transforms import t_batch_mode_transform
14+
from torch import Tensor
15+
16+
17+
# Error message template; the two `{}` slots are filled (previous batch size,
# new batch size) via `str.format` where the error is raised.
BATCH_SIZE_CHANGE_ERROR = (
    "The batch size of PathwiseThompsonSampling should "
    "not change during a forward pass - was {}, now {}. "
    "Please re-initialize the acquisition if you want to change the batch size."
)
20+
21+
22+
class PathwiseThompsonSampling(AcquisitionFunction):
    r"""Single-outcome Thompson Sampling packaged as an (analytic)
    acquisition function. Querying the acquisition function gives the summed
    values of one or more draws from a pathwise drawn posterior sample, and thus
    its maximization yields one (or multiple) Thompson sample(s).

    The sample paths are drawn lazily on the first call to `forward`, which also
    fixes the batch size `q`. Call `redraw` to replace them with fresh draws.

    Example:
        >>> model = SingleTaskGP(train_X, train_Y)
        >>> TS = PathwiseThompsonSampling(model)
    """

    def __init__(
        self,
        model: Model,
        posterior_transform: Optional[PosteriorTransform] = None,
    ) -> None:
        r"""Single-outcome TS.

        Args:
            model: A fitted GP model.
            posterior_transform: A PosteriorTransform. If using a multi-output model,
                a PosteriorTransform that transforms the multi-output posterior into a
                single-output posterior is required.
                NOTE(review): this argument is accepted for interface
                compatibility but is not applied anywhere in this class yet.

        Raises:
            NotImplementedError: If `model` is a fully Bayesian model.
        """
        if model._is_fully_bayesian:
            raise NotImplementedError(
                "PathwiseThompsonSampling is not supported for fully Bayesian models",
            )

        super().__init__(model=model)
        # Number of candidates `q`; set by the first forward pass.
        self.batch_size: Optional[int] = None
        # Sample-path model; stays None until paths are drawn in `redraw`.
        self.samples = None

    def redraw(self) -> None:
        r"""Draw a fresh set of `batch_size` posterior sample paths.

        This is a no-op when the batch size is not known yet (i.e. before the
        first forward pass): there is nothing to replace, and the first forward
        pass draws the paths itself.
        """
        if self.batch_size is None:
            # `torch.Size([None])` would fail with an opaque TypeError.
            return
        self.samples = get_matheron_path_model(
            model=self.model, sample_shape=torch.Size([self.batch_size])
        )

    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the pathwise posterior sample draws on the candidate set X.

        Args:
            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design
                points. `q` is fixed by the first call and must stay the same
                on subsequent calls.

        Returns:
            A `batch_shape`-dim tensor of sample-path evaluations, summed over
            the `q` candidates of each t-batch.

        Raises:
            ValueError: If `q` differs from the batch size seen on an earlier
                call.
        """
        batch_size = X.shape[-2]
        q_dim = -2

        # batch_shape x q x 1 x d
        # Presumably this aligns the q candidates with the `q` drawn sample
        # paths, one path per candidate -- confirm against
        # `get_matheron_path_model`.
        X = X.unsqueeze(-2)
        if self.batch_size is None:
            self.batch_size = batch_size
        elif self.batch_size != batch_size:
            raise ValueError(
                BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
            )
        if self.samples is None:
            # First evaluation: lazily draw the paths.
            self.redraw()

        # posterior_values.shape post-squeeze:
        # batch_shape x q x m
        posterior_values = self.samples(X).squeeze(-2)
        # sum over the q dim and squeeze the num_objectives dim (-1)
        return posterior_values.sum(q_dim).squeeze(-1)

sphinx/source/acquisition.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ Risk Measures
172172
.. automodule:: botorch.acquisition.risk_measures
173173
:members:
174174

175+
Thompson Sampling
176+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177+
.. automodule:: botorch.acquisition.thompson_sampling
178+
:members:
179+
175180
Multi-Output Risk Measures
176181
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
177182
.. automodule:: botorch.acquisition.multi_objective.multi_output_risk_measures
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from itertools import product
8+
9+
import torch
10+
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
11+
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
12+
13+
from botorch.models.gp_regression import SingleTaskGP
14+
from botorch.models.model import Model
15+
from botorch.models.transforms.outcome import Standardize
16+
from botorch.utils.testing import BotorchTestCase
17+
18+
19+
def get_model(train_X, train_Y, standardize_model):
    """Build a `SingleTaskGP` on the given data.

    When `standardize_model` is truthy the model gets a single-output
    `Standardize` outcome transform; otherwise no outcome transform is passed.
    """
    return SingleTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        outcome_transform=Standardize(m=1) if standardize_model else None,
    )
31+
32+
33+
def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
    """Create a dict of random tensors standing in for MCMC samples.

    Args:
        num_samples: Size of the leading dimension of every tensor.
        dim: Size of the last dimension of the "lengthscale" tensor.
        infer_noise: If True, a "noise" entry is included as well.
        **tkwargs: Forwarded to the torch factory calls (e.g. dtype, device).

    Returns:
        A dict with keys "lengthscale", "outputscale", "mean" and, when
        `infer_noise` is set, "noise".
    """
    # Draw order is kept fixed so results are reproducible under a given seed.
    lengthscale = torch.rand(num_samples, 1, dim, **tkwargs)
    outputscale = torch.rand(num_samples, **tkwargs)
    mean = torch.randn(num_samples, **tkwargs)
    draws = {
        "lengthscale": lengthscale,
        "outputscale": outputscale,
        "mean": mean,
    }
    if not infer_noise:
        return draws
    draws["noise"] = torch.rand(num_samples, 1, **tkwargs)
    return draws
43+
44+
45+
def get_fully_bayesian_model(
    train_X,
    train_Y,
    num_models,
    **tkwargs,
):
    """Build a `SaasFullyBayesianSingleTaskGP` loaded with random MCMC samples.

    The model is constructed from the training data, then `num_models` random
    samples (including noise) are generated with `_get_mcmc_samples` and loaded
    into it; `tkwargs` is forwarded to the sample generator.
    """
    fb_model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
    fb_model.load_mcmc_samples(
        _get_mcmc_samples(
            num_samples=num_models,
            dim=train_X.shape[-1],
            infer_noise=True,
            **tkwargs,
        )
    )
    return fb_model
64+
65+
66+
class TestPathwiseThompsonSampling(BotorchTestCase):
    """Tests for the `PathwiseThompsonSampling` acquisition function."""

    def _test_thompson_sampling_base(self, model: Model):
        """Check output shape, repeatability, and the effect of `redraw`."""
        train_inputs = model.train_inputs[0]
        d = train_inputs.shape[-1]
        acq = PathwiseThompsonSampling(model=model)
        test_X = torch.rand(4, 1, d).to(train_inputs)

        # The output drops the q and d dimensions of the input.
        first_eval = acq(test_X)
        self.assertEqual(first_eval.shape, test_X.shape[:-2])

        # Without a redraw, repeated evaluations agree.
        before_redraw = acq(test_X)
        self.assertAllClose(before_redraw, acq(test_X))

        # re-draw samples and expect other output
        acq.redraw()
        after_redraw = acq(test_X)
        self.assertFalse(torch.allclose(before_redraw, after_redraw))

    def _test_thompson_sampling_batch(self, model: Model):
        """Check batch-size bookkeeping and output shapes for q > 1."""
        train_inputs = model.train_inputs[0]
        d = train_inputs.shape[-1]

        # The batch size is unset until the first forward pass fixes it.
        acq_q5 = PathwiseThompsonSampling(model=model)
        self.assertIsNone(acq_q5.batch_size)
        test_X = torch.rand(4, 5, d).to(train_inputs)
        acq_q5(test_X)
        self.assertEqual(acq_q5.batch_size, 5)

        # A later call with a different q must be rejected.
        test_X = torch.rand(4, 7, d).to(train_inputs)
        expected_msg = (
            "The batch size of PathwiseThompsonSampling should not "
            "change during a forward pass - was 5, now 7. Please re-initialize "
            "the acquisition if you want to change the batch size."
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            acq_q5(test_X)

        # Multiple t-batch dimensions, for q = 1 and q = 3.
        for q in (1, 3):
            fresh_acq = PathwiseThompsonSampling(model)
            test_X = torch.rand(4, 7, q, d).to(train_inputs)
            self.assertEqual(fresh_acq(test_X).shape, test_X.shape[:-2])

    def test_thompson_sampling_single_task(self):
        """Run the base and batch checks over dtypes and standardization."""
        input_dim = 2
        num_objectives = 1
        dtypes = (torch.float32, torch.float64)
        standardize_options = (True, False)
        for dtype, standardize_model in product(dtypes, standardize_options):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X = torch.rand(4, input_dim, **tkwargs)
            train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)
            model = get_model(train_X, train_Y, standardize_model=standardize_model)
            self._test_thompson_sampling_base(model)
            self._test_thompson_sampling_batch(model)

    def test_thompson_sampling_fully_bayesian(self):
        """Fully Bayesian models must be rejected at construction time."""
        input_dim = 2
        num_objectives = 1
        tkwargs = {"device": self.device, "dtype": torch.float64}
        train_X = torch.rand(4, input_dim, **tkwargs)
        train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)

        fb_model = get_fully_bayesian_model(train_X, train_Y, num_models=3, **tkwargs)
        expected_msg = (
            "PathwiseThompsonSampling is not supported for fully Bayesian models"
        )
        with self.assertRaisesRegex(NotImplementedError, expected_msg):
            PathwiseThompsonSampling(model=fb_model)

0 commit comments

Comments
 (0)