Skip to content

Commit 5c9af85

Browse files
esantorellafacebook-github-bot
authored andcommitted
Reap deprecated mvn argument to GPyTorchPosterior (#2255)
Summary: All internal calls to this have been updated to use `distribution` instead, including in subclasses (`HigherOrderGPPosterior`, `GaussianMixturePosterior`, `MultiTaskGPPosterior`, `FullyBayesianPosterior`) Differential Revision: D55092855
1 parent 60d9635 commit 5c9af85

File tree

2 files changed

+3
-35
lines changed

2 files changed

+3
-35
lines changed

botorch/posteriors/gpytorch.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
from __future__ import annotations
1212

13-
import warnings
14-
1513
from contextlib import ExitStack
1614
from typing import Optional, Tuple, TYPE_CHECKING, Union
1715

@@ -39,32 +37,13 @@ class GPyTorchPosterior(TorchPosterior):
3937

4038
distribution: MultivariateNormal
4139

42-
def __init__(
43-
self,
44-
distribution: Optional[MultivariateNormal] = None,
45-
mvn: Optional[MultivariateNormal] = None,
46-
) -> None:
40+
def __init__(self, distribution: MultivariateNormal) -> None:
4741
r"""A posterior based on GPyTorch's multi-variate Normal distributions.
4842
4943
Args:
5044
distribution: A GPyTorch MultivariateNormal (single-output case) or
5145
MultitaskMultivariateNormal (multi-output case).
52-
mvn: Deprecated.
5346
"""
54-
if mvn is not None:
55-
if distribution is not None:
56-
raise RuntimeError(
57-
"Got both a `distribution` and an `mvn` argument. "
58-
"Use the `distribution` only."
59-
)
60-
warnings.warn(
61-
"The `mvn` argument of `GPyTorchPosterior`s has been renamed to "
62-
"`distribution` and will be removed in a future version.",
63-
DeprecationWarning,
64-
)
65-
distribution = mvn
66-
if distribution is None:
67-
raise RuntimeError("GPyTorchPosterior must have a distribution specified.")
6847
super().__init__(distribution=distribution)
6948
self._is_mt = isinstance(distribution, MultitaskMultivariateNormal)
7049

test/posteriors/test_gpytorch.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
from botorch.exceptions import BotorchTensorDimensionError
1414
from botorch.posteriors.gpytorch import GPyTorchPosterior, scalarize_posterior
15-
from botorch.utils.testing import _get_test_posterior, BotorchTestCase, MockPosterior
15+
from botorch.utils.testing import _get_test_posterior, BotorchTestCase
1616
from gpytorch import settings as gpt_settings
1717
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
1818
from linear_operator.operators import to_linear_operator
@@ -25,18 +25,7 @@
2525

2626

2727
class TestGPyTorchPosterior(BotorchTestCase):
28-
def test_GPyTorchPosterior(self):
29-
# Test init & mvn property.
30-
mock_mvn = MockPosterior()
31-
with self.assertWarnsRegex(DeprecationWarning, "The `mvn` argument of"):
32-
posterior = GPyTorchPosterior(mvn=mock_mvn)
33-
self.assertIs(posterior.mvn, mock_mvn)
34-
self.assertIs(posterior.distribution, mock_mvn)
35-
with self.assertRaisesRegex(RuntimeError, "Got both a `distribution`"):
36-
GPyTorchPosterior(mvn=mock_mvn, distribution=mock_mvn)
37-
with self.assertRaisesRegex(RuntimeError, "GPyTorchPosterior must have"):
38-
GPyTorchPosterior()
39-
28+
def test_GPyTorchPosterior(self) -> None:
4029
for dtype in (torch.float, torch.double):
4130
n = 3
4231
mean = torch.rand(n, dtype=dtype, device=self.device)

0 commit comments

Comments
 (0)