Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 4 additions & 47 deletions botorch/posteriors/gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from __future__ import annotations

import warnings

from contextlib import ExitStack
from typing import Optional, Tuple, TYPE_CHECKING, Union

Expand Down Expand Up @@ -39,32 +37,13 @@ class GPyTorchPosterior(TorchPosterior):

distribution: MultivariateNormal

def __init__(
self,
distribution: Optional[MultivariateNormal] = None,
mvn: Optional[MultivariateNormal] = None,
) -> None:
def __init__(self, distribution: MultivariateNormal) -> None:
r"""A posterior based on GPyTorch's multi-variate Normal distributions.

Args:
distribution: A GPyTorch MultivariateNormal (single-output case) or
MultitaskMultivariateNormal (multi-output case).
mvn: Deprecated.
"""
if mvn is not None:
if distribution is not None:
raise RuntimeError(
"Got both a `distribution` and an `mvn` argument. "
"Use the `distribution` only."
)
warnings.warn(
"The `mvn` argument of `GPyTorchPosterior`s has been renamed to "
"`distribution` and will be removed in a future version.",
DeprecationWarning,
)
distribution = mvn
if distribution is None:
raise RuntimeError("GPyTorchPosterior must have a distribution specified.")
super().__init__(distribution=distribution)
self._is_mt = isinstance(distribution, MultitaskMultivariateNormal)

Expand Down Expand Up @@ -146,11 +125,7 @@ def rsample_from_base_samples(
samples = samples.unsqueeze(-1)
return samples

def rsample(
self,
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[Tensor] = None,
) -> Tensor:
def rsample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
r"""Sample from the posterior (with gradients).

Args:
Expand All @@ -167,30 +142,12 @@ def rsample(
"""
if sample_shape is None:
sample_shape = torch.Size([1])
if base_samples is not None:
warnings.warn(
"Use of `base_samples` with `rsample` is deprecated. Use "
"`rsample_from_base_samples` instead.",
DeprecationWarning,
)
if base_samples.shape[: len(sample_shape)] != sample_shape:
raise RuntimeError(
"`sample_shape` disagrees with shape of `base_samples`. "
f"Got {sample_shape=} and {base_samples.shape=}."
)
# get base_samples to the correct shape
base_samples = base_samples.expand(self._extended_shape(sample_shape))
if not self._is_mt:
# Remove output dimension in single output case.
base_samples = base_samples.squeeze(-1)
return self.rsample_from_base_samples(
sample_shape=sample_shape, base_samples=base_samples
)

with ExitStack() as es:
if linop_settings._fast_covar_root_decomposition.is_default():
es.enter_context(linop_settings._fast_covar_root_decomposition(False))
samples = self.distribution.rsample(
sample_shape=sample_shape, base_samples=base_samples
sample_shape=sample_shape, base_samples=None
)
# make sure there always is an output dimension
if not self._is_mt:
Expand Down
8 changes: 1 addition & 7 deletions botorch/posteriors/posterior_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,13 @@ def variance(self) -> Tensor:
"""
return self._reshape_and_cat(tensors=[p.variance for p in self.posteriors])

def rsample(
self,
sample_shape: Optional[torch.Size] = None,
) -> Tensor:
def rsample(self, sample_shape: Optional[torch.Size] = None) -> Tensor:
r"""Sample from the posterior (with gradients).

Args:
sample_shape: A `torch.Size` object specifying the sample shape. To
draw `n` samples, set to `torch.Size([n])`. To draw `b` batches
of `n` samples each, set to `torch.Size([b, n])`.
base_samples: An (optional) Tensor of `N(0, I)` base samples of
appropriate dimension, typically obtained from a `Sampler`.
This is used for deterministic optimization. Deprecated.

Returns:
Samples from the posterior, a tensor of shape
Expand Down
12 changes: 6 additions & 6 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,24 +298,24 @@ def variance(self):
def rsample(
self,
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[Tensor] = None,
) -> Tensor:
"""Mock sample by repeating self._samples. If base_samples is provided,
do a shape check but return the same mock samples."""
if sample_shape is None:
sample_shape = torch.Size()
if sample_shape is not None and base_samples is not None:
# check the base_samples shape is consistent with the sample_shape
if base_samples.shape[: len(sample_shape)] != sample_shape:
raise RuntimeError("sample_shape disagrees with base_samples.")
return self._samples.expand(sample_shape + self._samples.shape)

def rsample_from_base_samples(
self,
sample_shape: torch.Size,
base_samples: Tensor,
) -> Tensor:
return self.rsample(sample_shape, base_samples)
if base_samples.shape[: len(sample_shape)] != sample_shape:
raise RuntimeError(
"`sample_shape` disagrees with shape of `base_samples`. "
f"Got {sample_shape=} and {base_samples.shape=}."
)
return self.rsample(sample_shape)


@GetSampler.register(MockPosterior)
Expand Down
47 changes: 2 additions & 45 deletions test/posteriors/test_gpytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from botorch.exceptions import BotorchTensorDimensionError
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
from botorch.utils.testing import _get_test_posterior, BotorchTestCase, MockPosterior
from botorch.utils.testing import _get_test_posterior, BotorchTestCase
from gpytorch import settings as gpt_settings
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from linear_operator.operators import to_linear_operator
Expand All @@ -25,18 +25,7 @@


class TestGPyTorchPosterior(BotorchTestCase):
def test_GPyTorchPosterior(self):
# Test init & mvn property.
mock_mvn = MockPosterior()
with self.assertWarnsRegex(DeprecationWarning, "The `mvn` argument of"):
posterior = GPyTorchPosterior(mvn=mock_mvn)
self.assertIs(posterior.mvn, mock_mvn)
self.assertIs(posterior.distribution, mock_mvn)
with self.assertRaisesRegex(RuntimeError, "Got both a `distribution`"):
GPyTorchPosterior(mvn=mock_mvn, distribution=mock_mvn)
with self.assertRaisesRegex(RuntimeError, "GPyTorchPosterior must have"):
GPyTorchPosterior()

def test_GPyTorchPosterior(self) -> None:
for dtype in (torch.float, torch.double):
n = 3
mean = torch.rand(n, dtype=dtype, device=self.device)
Expand Down Expand Up @@ -80,10 +69,6 @@ def test_GPyTorchPosterior(self):
posterior.rsample_from_base_samples(
sample_shape=torch.Size([3]), base_samples=base_samples
)
with self.assertRaisesRegex(RuntimeError, "sample_shape"):
posterior.rsample(
sample_shape=torch.Size([3]), base_samples=base_samples
)
# ensure consistent result
for sample_shape in ([4], [4, 2]):
base_samples = torch.randn(
Expand All @@ -109,20 +94,6 @@ def test_GPyTorchPosterior(self):
self.assertEqual(density.shape, posterior._extended_shape(torch.Size([2])))
expected = torch.stack([marginal.log_prob(q).exp() for q in q_val], dim=0)
self.assertAllClose(density, expected)
# collapse_batch_dims
b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=self.device)
b_covar = torch.diag_embed(b_variance)
b_mvn = MultivariateNormal(b_mean, to_linear_operator(b_covar))
b_posterior = GPyTorchPosterior(distribution=b_mvn)
b_base_samples = torch.randn(4, 1, 3, 1, device=self.device, dtype=dtype)
with self.assertWarnsRegex(
DeprecationWarning, "`base_samples` with `rsample`"
):
b_samples = b_posterior.rsample(
sample_shape=torch.Size([4]), base_samples=b_base_samples
)
self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))

def test_GPyTorchPosterior_Multitask(self):
for dtype in (torch.float, torch.double):
Expand Down Expand Up @@ -159,20 +130,6 @@ def test_GPyTorchPosterior_Multitask(self):
sample_shape=torch.Size([4, 2]), base_samples=base_samples2
)
self.assertAllClose(samples2_b1, samples2_b2)
# collapse_batch_dims
b_mean = torch.rand(2, 3, 2, dtype=dtype, device=self.device)
b_variance = 1 + torch.rand(2, 3, 2, dtype=dtype, device=self.device)
b_covar = torch.diag_embed(b_variance.view(2, 6))
b_mvn = MultitaskMultivariateNormal(b_mean, to_linear_operator(b_covar))
b_posterior = GPyTorchPosterior(distribution=b_mvn)
b_base_samples = torch.randn(4, 1, 3, 2, device=self.device, dtype=dtype)
with self.assertWarnsRegex(
DeprecationWarning, "`base_samples` with `rsample`"
):
b_samples = b_posterior.rsample(
sample_shape=torch.Size([4]), base_samples=b_base_samples
)
self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 2]))

def test_degenerate_GPyTorchPosterior(self):
for dtype, multi_task in (
Expand Down
21 changes: 14 additions & 7 deletions test/utils/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


class TestMock(BotorchTestCase):
def test_MockPosterior(self):
# test basic logic
class TestMockPosterior(BotorchTestCase):
def test_basic_logic(self) -> None:
mp = MockPosterior()
self.assertEqual(mp.device.type, "cpu")
self.assertEqual(mp.dtype, torch.float32)
self.assertEqual(mp._extended_shape(), torch.Size())
self.assertEqual(
MockPosterior(variance=torch.rand(2))._extended_shape(), torch.Size([2])
)
# test passing in tensors

def test_passing_tensors(self) -> None:
mean = torch.rand(2)
variance = torch.eye(2)
samples = torch.rand(1, 2)
Expand All @@ -31,10 +31,17 @@ def test_MockPosterior(self):
self.assertTrue(
torch.all(mp.rsample(torch.Size([2])) == samples.repeat(2, 1, 1))
)
with self.assertRaises(RuntimeError):
mp.rsample(sample_shape=torch.Size([2]), base_samples=torch.rand(3))

def test_MockModel(self):
def test_rsample_from_base_samples(self) -> None:
mp = MockPosterior()
with self.assertRaisesRegex(
RuntimeError, "`sample_shape` disagrees with shape of `base_samples`."
):
mp.rsample_from_base_samples(torch.zeros(2, 2), torch.zeros(3))


class TestMockModel(BotorchTestCase):
def test_basic(self) -> None:
mp = MockPosterior()
mm = MockModel(mp)
X = torch.empty(0)
Expand Down