77
88import torch
99from botorch .exceptions import UnsupportedError
10+ from botorch .exceptions .warnings import BotorchWarning
1011from botorch .models import (
1112 HeteroskedasticSingleTaskGP ,
1213 ModelListGP ,
1617from botorch .models .converter import (
1718 batched_multi_output_to_single_output ,
1819 batched_to_model_list ,
20+ DEPRECATION_MESSAGE ,
1921 model_list_to_batched ,
2022)
2123from botorch .models .transforms .input import AppendFeatures , Normalize
2527from gpytorch .kernels import RBFKernel
2628from gpytorch .likelihoods import GaussianLikelihood
2729from gpytorch .likelihoods .gaussian_likelihood import FixedNoiseGaussianLikelihood
30+ from gpytorch .priors import LogNormalPrior
2831
2932
3033class TestConverters (BotorchTestCase ):
@@ -41,7 +44,8 @@ def test_batched_to_model_list(self):
4144 self .assertIsInstance (list_gp .models [0 ].likelihood , GaussianLikelihood )
4245 # test observed noise
4346 batch_gp = SingleTaskGP (train_X , train_Y , torch .rand_like (train_Y ))
44- list_gp = batched_to_model_list (batch_gp )
47+ with self .assertWarnsRegex (DeprecationWarning , DEPRECATION_MESSAGE ):
48+ list_gp = batched_to_model_list (batch_gp )
4549 self .assertIsInstance (list_gp , ModelListGP )
4650 self .assertIsInstance (
4751 list_gp .models [0 ].likelihood , FixedNoiseGaussianLikelihood
@@ -108,7 +112,8 @@ def test_model_list_to_batched(self):
108112 self .assertIsInstance (batch_gp , SingleTaskGP )
109113 self .assertIsInstance (batch_gp .likelihood , GaussianLikelihood )
110114 # test degenerate (single model)
111- batch_gp = model_list_to_batched (ModelListGP (gp1 ))
115+ with self .assertWarnsRegex (DeprecationWarning , DEPRECATION_MESSAGE ):
116+ batch_gp = model_list_to_batched (ModelListGP (gp1 ))
112117 self .assertEqual (batch_gp ._num_outputs , 1 )
113118 # test mixing different likelihoods
114119 gp2 = SingleTaskGP (train_X , train_Y1 , torch .ones_like (train_Y1 ))
@@ -240,6 +245,27 @@ def test_model_list_to_batched(self):
240245 with self .assertRaises (UnsupportedError ):
241246 model_list_to_batched (list_gp )
242247
248+ def test_model_list_to_batched_with_different_prior (self ) -> None :
249+ # The goal is to test priors that don't have their parameters
250+ # recorded in the state dict.
251+ train_X = torch .rand (10 , 2 , device = self .device , dtype = torch .double )
252+ gp1 = SingleTaskGP (
253+ train_X = train_X ,
254+ train_Y = train_X .sum (dim = - 1 , keepdim = True ),
255+ covar_module = RBFKernel (
256+ ard_num_dims = 2 , lengthscale_prior = LogNormalPrior (3.0 , 6.0 )
257+ ),
258+ )
259+ gp2 = SingleTaskGP (
260+ train_X = train_X ,
261+ train_Y = train_X .max (dim = - 1 , keepdim = True ).values ,
262+ covar_module = RBFKernel (
263+ ard_num_dims = 2 , lengthscale_prior = LogNormalPrior (2.0 , 4.0 )
264+ ),
265+ )
266+ with self .assertWarnsRegex (BotorchWarning , "Model converter cannot verify" ):
267+ model_list_to_batched (ModelListGP (gp1 , gp2 ))
268+
243269 def test_roundtrip (self ):
244270 for dtype in (torch .float , torch .double ):
245271 train_X = torch .rand (10 , 2 , device = self .device , dtype = dtype )
@@ -288,7 +314,10 @@ def test_batched_multi_output_to_single_output(self):
288314 dim = 1 ,
289315 )
290316 batched_mo_model = SingleTaskGP (train_X , train_Y )
291- batched_so_model = batched_multi_output_to_single_output (batched_mo_model )
317+ with self .assertWarnsRegex (DeprecationWarning , DEPRECATION_MESSAGE ):
318+ batched_so_model = batched_multi_output_to_single_output (
319+ batched_mo_model
320+ )
292321 self .assertIsInstance (batched_so_model , SingleTaskGP )
293322 self .assertEqual (batched_so_model .num_outputs , 1 )
294323 # test non-batched models
0 commit comments