
Commit ff4afa1

saitcakmak authored and facebook-github-bot committed
Deprecate model conversion code
Summary: These utilities are scarcely used and poorly maintained, and they are not fully compatible with the different GPyTorch priors that we plan to use by default in the near future. They are marked for deprecation in v0.13. This change also adds an explicit warning when the converters are used with priors that do not record any parameters in their state dict, since such priors cannot be verified to be compatible across submodels.

Differential Revision: D59813960
1 parent 6892be9 commit ff4afa1

2 files changed: +56, -6 lines
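
What this means for callers: after this change, model_list_to_batched, batched_to_model_list, and batched_multi_output_to_single_output all emit a DeprecationWarning on entry (see the diffs below). A minimal sketch of the user-facing effect, using a docstring-style setup from converter.py; the toy training data here is illustrative only, not part of the commit:

import warnings

import torch
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.converter import model_list_to_batched

# Illustrative toy data; any valid SingleTaskGP inputs would do.
train_X = torch.rand(10, 2, dtype=torch.double)
gp1 = SingleTaskGP(train_X, train_X.sum(dim=-1, keepdim=True))
gp2 = SingleTaskGP(train_X, train_X.max(dim=-1, keepdim=True).values)

with warnings.catch_warnings(record=True) as recorded:
    warnings.simplefilter("always")
    batch_gp = model_list_to_batched(ModelListGP(gp1, gp2))

# One of the recorded warnings is the new DeprecationWarning.
assert any(issubclass(w.category, DeprecationWarning) for w in recorded)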

botorch/models/converter.py

Lines changed: 24 additions & 3 deletions
@@ -10,11 +10,13 @@
 
 from __future__ import annotations
 
+import warnings
 from copy import deepcopy
 from typing import Dict, Optional, Set, Tuple
 
 import torch
 from botorch.exceptions import UnsupportedError
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
 from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
 from botorch.models.gp_regression_mixed import MixedSingleTaskGP
@@ -24,7 +26,13 @@
 from botorch.models.transforms.outcome import OutcomeTransform
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module, ModuleList
+
+DEPRECATION_MESSAGE = (
+    "Model converter code is deprecated and will be removed in v0.13 release. "
+    "Its correct behavior is dependent on some assumptions about model priors "
+    "that do not always hold. Use it at your own risk!"
+)
 
 
 def _get_module(module: Module, name: str) -> Module:
@@ -49,15 +57,25 @@ def _get_module(module: Module, name: str) -> Module:
     return current
 
 
-def _check_compatibility(models: ModelListGP) -> None:
-    """Check if a ModelListGP can be converted."""
+def _check_compatibility(models: ModuleList) -> None:
+    """Check if the submodels of a ModelListGP are compatible with the converter."""
     # Check that all submodules are of the same type.
     for modn, mod in models[0].named_modules():
         mcls = mod.__class__
         if not all(isinstance(_get_module(m, modn), mcls) for m in models[1:]):
             raise UnsupportedError(
                 "Sub-modules must be of the same type across models."
             )
+        if "prior" in modn and len(mod.state_dict()) == 0:
+            warnings.warn(
+                "Model converter cannot verify compatibility of GPyTorch priors "
+                "that do not register their parameters as buffers. If the prior "
+                "is different than the default prior set by the model constructor "
+                "this may not work correctly. Use it at your own risk! See "
+                "https://github.com/cornellius-gp/gpytorch/issues/2550.",
+                BotorchWarning,
+                stacklevel=3,
+            )
 
     # Check that each model is a BatchedMultiOutputGPyTorchModel.
     if not all(isinstance(m, BatchedMultiOutputGPyTorchModel) for m in models):
@@ -128,6 +146,7 @@ def model_list_to_batched(model_list: ModelListGP) -> BatchedMultiOutputGPyTorch
         >>> list_gp = ModelListGP(gp1, gp2)
         >>> batch_gp = model_list_to_batched(list_gp)
     """
+    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
     was_training = model_list.training
     model_list.train()
     models = model_list.models
@@ -260,6 +279,7 @@ def batched_to_model_list(batch_model: BatchedMultiOutputGPyTorchModel) -> Model
         >>> batch_gp = SingleTaskGP(train_X, train_Y)
         >>> list_gp = batched_to_model_list(batch_gp)
     """
+    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
     was_training = batch_model.training
     batch_model.train()
     # TODO: Add support for HeteroskedasticSingleTaskGP.
@@ -363,6 +383,7 @@ def batched_multi_output_to_single_output(
         >>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
         >>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
     """
+    warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
     was_training = batch_mo_model.training
     batch_mo_model.train()
     # TODO: Add support for HeteroskedasticSingleTaskGP.

test/models/test_converter.py

Lines changed: 32 additions & 3 deletions
@@ -7,6 +7,7 @@
 
 import torch
 from botorch.exceptions import UnsupportedError
+from botorch.exceptions.warnings import BotorchWarning
 from botorch.models import (
     HeteroskedasticSingleTaskGP,
     ModelListGP,
@@ -16,6 +17,7 @@
 from botorch.models.converter import (
     batched_multi_output_to_single_output,
     batched_to_model_list,
+    DEPRECATION_MESSAGE,
     model_list_to_batched,
 )
 from botorch.models.transforms.input import AppendFeatures, Normalize
@@ -25,6 +27,7 @@
 from gpytorch.kernels import RBFKernel
 from gpytorch.likelihoods import GaussianLikelihood
 from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
+from gpytorch.priors import LogNormalPrior
 
 
 class TestConverters(BotorchTestCase):
@@ -41,7 +44,8 @@ def test_batched_to_model_list(self):
         self.assertIsInstance(list_gp.models[0].likelihood, GaussianLikelihood)
         # test observed noise
         batch_gp = SingleTaskGP(train_X, train_Y, torch.rand_like(train_Y))
-        list_gp = batched_to_model_list(batch_gp)
+        with self.assertWarnsRegex(DeprecationWarning, DEPRECATION_MESSAGE):
+            list_gp = batched_to_model_list(batch_gp)
         self.assertIsInstance(list_gp, ModelListGP)
         self.assertIsInstance(
             list_gp.models[0].likelihood, FixedNoiseGaussianLikelihood
@@ -108,7 +112,8 @@ def test_model_list_to_batched(self):
         self.assertIsInstance(batch_gp, SingleTaskGP)
         self.assertIsInstance(batch_gp.likelihood, GaussianLikelihood)
         # test degenerate (single model)
-        batch_gp = model_list_to_batched(ModelListGP(gp1))
+        with self.assertWarnsRegex(DeprecationWarning, DEPRECATION_MESSAGE):
+            batch_gp = model_list_to_batched(ModelListGP(gp1))
         self.assertEqual(batch_gp._num_outputs, 1)
         # test mixing different likelihoods
         gp2 = SingleTaskGP(train_X, train_Y1, torch.ones_like(train_Y1))
@@ -240,6 +245,27 @@ def test_model_list_to_batched(self):
         with self.assertRaises(UnsupportedError):
             model_list_to_batched(list_gp)
 
+    def test_model_list_to_batched_with_different_prior(self) -> None:
+        # The goal is to test priors that don't have their parameters
+        # recorded in the state dict.
+        train_X = torch.rand(10, 2, device=self.device, dtype=torch.double)
+        gp1 = SingleTaskGP(
+            train_X=train_X,
+            train_Y=train_X.sum(dim=-1, keepdim=True),
+            covar_module=RBFKernel(
+                ard_num_dims=2, lengthscale_prior=LogNormalPrior(3.0, 6.0)
+            ),
+        )
+        gp2 = SingleTaskGP(
+            train_X=train_X,
+            train_Y=train_X.max(dim=-1, keepdim=True).values,
+            covar_module=RBFKernel(
+                ard_num_dims=2, lengthscale_prior=LogNormalPrior(2.0, 4.0)
+            ),
+        )
+        with self.assertWarnsRegex(BotorchWarning, "Model converter cannot verify"):
+            model_list_to_batched(ModelListGP(gp1, gp2))
+
     def test_roundtrip(self):
         for dtype in (torch.float, torch.double):
             train_X = torch.rand(10, 2, device=self.device, dtype=dtype)
@@ -288,7 +314,10 @@ def test_batched_multi_output_to_single_output(self):
             dim=1,
         )
         batched_mo_model = SingleTaskGP(train_X, train_Y)
-        batched_so_model = batched_multi_output_to_single_output(batched_mo_model)
+        with self.assertWarnsRegex(DeprecationWarning, DEPRECATION_MESSAGE):
+            batched_so_model = batched_multi_output_to_single_output(
+                batched_mo_model
+            )
         self.assertIsInstance(batched_so_model, SingleTaskGP)
         self.assertEqual(batched_so_model.num_outputs, 1)
         # test non-batched models
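
The new compatibility warning keys on whether a prior exposes its parameters through the state dict: _check_compatibility warns when a module whose name contains "prior" has an empty state_dict(). Per the new test above and the linked gpytorch issue, LogNormalPrior is assumed to be one such prior (its parameters are not registered as buffers), which this hedged sketch relies on:

from gpytorch.priors import LogNormalPrior

# Assumption (per the new test and gpytorch issue #2550): LogNormalPrior does
# not register its parameters as buffers, so its state dict is empty and the
# converter cannot compare such priors across submodels.
prior = LogNormalPrior(3.0, 6.0)
if len(prior.state_dict()) == 0:
    print("prior parameters are not in the state dict; compatibility is unverified")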
