Skip to content

Commit 60d9635

Browse files
esantorellafacebook-github-bot
authored andcommitted
Reap base_samples argument of GPyTorchPosterior.rsample (#2254)
Summary: X-link: facebookexternal/botorch_fb#17 The `base_samples` argument to GPyTorchPosterior.rsample was deprecated in BoTorch 8.0.0 - reaped the deprecated code - removed the corresponding unit tests - Removed `base_samples` argument from `MockPosterior.rsample` used for testing, since the base Posterior class also does not permit `base_samples` in `rsample`. Moved an exception in `MockPosterior` so that invalid `base_samples` will still be checked. - fixed a docstring in PosteriorList Reviewed By: saitcakmak Differential Revision: D55038429
1 parent 1a8b4ea commit 60d9635

File tree

5 files changed

+24
-77
lines changed

5 files changed

+24
-77
lines changed

botorch/posteriors/gpytorch.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,7 @@ def rsample_from_base_samples(
146146
samples = samples.unsqueeze(-1)
147147
return samples
148148

149-
def rsample(
150-
self,
151-
sample_shape: Optional[torch.Size] = None,
152-
base_samples: Optional[Tensor] = None,
153-
) -> Tensor:
149+
def rsample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
154150
r"""Sample from the posterior (with gradients).
155151
156152
Args:
@@ -167,30 +163,12 @@ def rsample(
167163
"""
168164
if sample_shape is None:
169165
sample_shape = torch.Size([1])
170-
if base_samples is not None:
171-
warnings.warn(
172-
"Use of `base_samples` with `rsample` is deprecated. Use "
173-
"`rsample_from_base_samples` instead.",
174-
DeprecationWarning,
175-
)
176-
if base_samples.shape[: len(sample_shape)] != sample_shape:
177-
raise RuntimeError(
178-
"`sample_shape` disagrees with shape of `base_samples`. "
179-
f"Got {sample_shape=} and {base_samples.shape=}."
180-
)
181-
# get base_samples to the correct shape
182-
base_samples = base_samples.expand(self._extended_shape(sample_shape))
183-
if not self._is_mt:
184-
# Remove output dimension in single output case.
185-
base_samples = base_samples.squeeze(-1)
186-
return self.rsample_from_base_samples(
187-
sample_shape=sample_shape, base_samples=base_samples
188-
)
166+
189167
with ExitStack() as es:
190168
if linop_settings._fast_covar_root_decomposition.is_default():
191169
es.enter_context(linop_settings._fast_covar_root_decomposition(False))
192170
samples = self.distribution.rsample(
193-
sample_shape=sample_shape, base_samples=base_samples
171+
sample_shape=sample_shape, base_samples=None
194172
)
195173
# make sure there always is an output dimension
196174
if not self._is_mt:

botorch/posteriors/posterior_list.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,13 @@ def variance(self) -> Tensor:
154154
"""
155155
return self._reshape_and_cat(tensors=[p.variance for p in self.posteriors])
156156

157-
def rsample(
158-
self,
159-
sample_shape: Optional[torch.Size] = None,
160-
) -> Tensor:
157+
def rsample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
161158
r"""Sample from the posterior (with gradients).
162159
163160
Args:
164161
sample_shape: A `torch.Size` object specifying the sample shape. To
165162
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
166163
of `n` samples each, set to `torch.Size([b, n])`.
167-
base_samples: An (optional) Tensor of `N(0, I)` base samples of
168-
appropriate dimension, typically obtained from a `Sampler`.
169-
This is used for deterministic optimization. Deprecated.
170164
171165
Returns:
172166
Samples from the posterior, a tensor of shape

botorch/utils/testing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,24 +298,24 @@ def variance(self):
298298
def rsample(
299299
self,
300300
sample_shape: Optional[torch.Size] = None,
301-
base_samples: Optional[Tensor] = None,
302301
) -> Tensor:
303302
"""Mock sample by repeating self._samples. If base_samples is provided,
304303
do a shape check but return the same mock samples."""
305304
if sample_shape is None:
306305
sample_shape = torch.Size()
307-
if sample_shape is not None and base_samples is not None:
308-
# check the base_samples shape is consistent with the sample_shape
309-
if base_samples.shape[: len(sample_shape)] != sample_shape:
310-
raise RuntimeError("sample_shape disagrees with base_samples.")
311306
return self._samples.expand(sample_shape + self._samples.shape)
312307

313308
def rsample_from_base_samples(
314309
self,
315310
sample_shape: torch.Size,
316311
base_samples: Tensor,
317312
) -> Tensor:
318-
return self.rsample(sample_shape, base_samples)
313+
if base_samples.shape[: len(sample_shape)] != sample_shape:
314+
raise RuntimeError(
315+
"`sample_shape` disagrees with shape of `base_samples`. "
316+
f"Got {sample_shape=} and {base_samples.shape=}."
317+
)
318+
return self.rsample(sample_shape)
319319

320320

321321
@GetSampler.register(MockPosterior)

test/posteriors/test_gpytorch.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,6 @@ def test_GPyTorchPosterior(self):
8080
posterior.rsample_from_base_samples(
8181
sample_shape=torch.Size([3]), base_samples=base_samples
8282
)
83-
with self.assertRaisesRegex(RuntimeError, "sample_shape"):
84-
posterior.rsample(
85-
sample_shape=torch.Size([3]), base_samples=base_samples
86-
)
8783
# ensure consistent result
8884
for sample_shape in ([4], [4, 2]):
8985
base_samples = torch.randn(
@@ -109,20 +105,6 @@ def test_GPyTorchPosterior(self):
109105
self.assertEqual(density.shape, posterior._extended_shape(torch.Size([2])))
110106
expected = torch.stack([marginal.log_prob(q).exp() for q in q_val], dim=0)
111107
self.assertAllClose(density, expected)
112-
# collapse_batch_dims
113-
b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
114-
b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=self.device)
115-
b_covar = torch.diag_embed(b_variance)
116-
b_mvn = MultivariateNormal(b_mean, to_linear_operator(b_covar))
117-
b_posterior = GPyTorchPosterior(distribution=b_mvn)
118-
b_base_samples = torch.randn(4, 1, 3, 1, device=self.device, dtype=dtype)
119-
with self.assertWarnsRegex(
120-
DeprecationWarning, "`base_samples` with `rsample`"
121-
):
122-
b_samples = b_posterior.rsample(
123-
sample_shape=torch.Size([4]), base_samples=b_base_samples
124-
)
125-
self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
126108

127109
def test_GPyTorchPosterior_Multitask(self):
128110
for dtype in (torch.float, torch.double):
@@ -159,20 +141,6 @@ def test_GPyTorchPosterior_Multitask(self):
159141
sample_shape=torch.Size([4, 2]), base_samples=base_samples2
160142
)
161143
self.assertAllClose(samples2_b1, samples2_b2)
162-
# collapse_batch_dims
163-
b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
164-
b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
165-
b_covar = torch.diag_embed(b_variance.view(2, 6))
166-
b_mvn = MultitaskMultivariateNormal(b_mean, to_linear_operator(b_covar))
167-
b_posterior = GPyTorchPosterior(distribution=b_mvn)
168-
b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
169-
with self.assertWarnsRegex(
170-
DeprecationWarning, "`base_samples` with `rsample`"
171-
):
172-
b_samples = b_posterior.rsample(
173-
sample_shape=torch.Size([4]), base_samples=b_base_samples
174-
)
175-
self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))
176144

177145
def test_degenerate_GPyTorchPosterior(self):
178146
for dtype, multi_task in (

test/utils/test_testing.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
99

1010

11-
class TestMock(BotorchTestCase):
12-
def test_MockPosterior(self):
13-
# test basic logic
11+
class TestMockPosterior(BotorchTestCase):
12+
def test_basic_logic(self) -> None:
1413
mp = MockPosterior()
1514
self.assertEqual(mp.device.type, "cpu")
1615
self.assertEqual(mp.dtype, torch.float32)
1716
self.assertEqual(mp._extended_shape(), torch.Size())
1817
self.assertEqual(
1918
MockPosterior(variance=torch.rand(2))._extended_shape(), torch.Size([2])
2019
)
21-
# test passing in tensors
20+
21+
def test_passing_tensors(self) -> None:
2222
mean = torch.rand(2)
2323
variance = torch.eye(2)
2424
samples = torch.rand(1, 2)
@@ -31,10 +31,17 @@ def test_MockPosterior(self):
3131
self.assertTrue(
3232
torch.all(mp.rsample(torch.Size([2])) == samples.repeat(2, 1, 1))
3333
)
34-
with self.assertRaises(RuntimeError):
35-
mp.rsample(sample_shape=torch.Size([2]), base_samples=torch.rand(3))
3634

37-
def test_MockModel(self):
35+
def test_rsample_from_base_samples(self) -> None:
36+
mp = MockPosterior()
37+
with self.assertRaisesRegex(
38+
RuntimeError, "`sample_shape` disagrees with shape of `base_samples`."
39+
):
40+
mp.rsample_from_base_samples(torch.zeros(2, 2), torch.zeros(3))
41+
42+
43+
class TestMockModel(BotorchTestCase):
44+
def test_basic(self) -> None:
3845
mp = MockPosterior()
3946
mm = MockModel(mp)
4047
X = torch.empty(0)

0 commit comments

Comments
 (0)