 import torch
 from botorch.acquisition.objective import ScalarizedPosteriorTransform
 from botorch.exceptions.errors import BotorchTensorDimensionError
-from botorch.exceptions.warnings import OptimizationWarning
+from botorch.exceptions.warnings import InputDataWarning, OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.models.latent_kronecker_gp import LatentKroneckerGP
 from botorch.models.transforms import Normalize, Standardize
+from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import BotorchTestCase, get_random_data
 from botorch.utils.types import DEFAULT
 from gpytorch.kernels import MaternKernel, RBFKernel, ScaleKernel
@@ -38,7 +39,7 @@ def _get_data_with_missing_entries(
     mask[torch.randperm(n_train * t)[: n_train * t // 2]] = False
     train_Y[..., ~mask.reshape(n_train, t)] = torch.nan
 
-    return train_X, train_T, train_Y
+    return train_X, train_T, train_Y, mask
 
 
 class TestLatentKroneckerGP(BotorchTestCase):
@@ -71,7 +72,7 @@ def test_default_init(self):
                 intf = None
                 octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, mask = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
@@ -85,8 +86,7 @@ def test_default_init(self):
             model.to(**tkwargs)
 
             # test init
-            mask_valid = torch.isfinite(train_Y.reshape(-1, n_train, t)[0]).flatten()
-            train_Y_flat = train_Y.reshape(*batch_shape, -1)[..., mask_valid]
+            train_Y_flat = train_Y.reshape(*batch_shape, -1)[..., mask]
             if use_transforms:
                 self.assertIsInstance(model.input_transform, Normalize)
                 self.assertIsInstance(model.outcome_transform, Standardize)
@@ -124,7 +124,7 @@ def test_custom_init(self):
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
@@ -230,7 +230,7 @@ def test_gp_train(self):
                 intf = None
                 octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
@@ -271,7 +271,7 @@ def _test_gp_eval_shapes(
                 intf = None
                 octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
@@ -441,7 +441,7 @@ def test_gp_eval_values(self):
                 intf = None
                 octf = None
 
-            train_X, train_T, train_Y = _get_data_with_missing_entries(
+            train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
                 n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
             )
 
@@ -507,7 +507,7 @@ def test_iterative_methods(self):
         batch_shape = torch.Size([])
         tkwargs = {"device": self.device, "dtype": torch.double}
 
-        train_X, train_T, train_Y = _get_data_with_missing_entries(
+        train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
             n_train=10, d=1, t=1, batch_shape=batch_shape, tkwargs=tkwargs
         )
 
@@ -525,7 +525,7 @@ def test_not_implemented(self):
         batch_shape = torch.Size([])
         tkwargs = {"device": self.device, "dtype": torch.double}
 
-        train_X, train_T, train_Y = _get_data_with_missing_entries(
+        train_X, train_T, train_Y, _ = _get_data_with_missing_entries(
            n_train=10, d=1, t=1, batch_shape=batch_shape, tkwargs=tkwargs
         )
 
@@ -558,3 +558,63 @@ def test_not_implemented(self):
         err_msg = f"Only GaussianLikelihood currently supported for {cls_name}"
         with self.assertRaisesRegex(NotImplementedError, err_msg):
             model.posterior(train_X)
+
+    def test_construct_inputs(self) -> None:
+        # This test relies on the fact that the random (missing) data generation
+        # does not remove all occurrences of a particular X or T value. Therefore,
+        # we fix the random seed and set n_train and t to slightly larger values.
+
+        torch.manual_seed(12345)
+        for batch_shape, n_train, d, t, dtype in itertools.product(
+            (  # batch_shape
+                torch.Size([]),
+                torch.Size([1]),
+                torch.Size([2]),
+                torch.Size([2, 3]),
+            ),
+            (15,),  # n_train
+            (1, 2),  # d
+            (10,),  # t
+            (torch.float, torch.double),  # dtype
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+
+            train_X, train_T, train_Y, mask = _get_data_with_missing_entries(
+                n_train=n_train, d=d, t=t, batch_shape=batch_shape, tkwargs=tkwargs
+            )
+
+            train_X_supervised = torch.cat(
+                [
+                    train_X.repeat_interleave(t, dim=-2),
+                    train_T.repeat(*([1] * len(batch_shape)), n_train, 1),
+                ],
+                dim=-1,
+            )
+            train_Y_supervised = train_Y.reshape(*batch_shape, n_train * t, 1)
+
+            # randomly permute data to test robustness to non-contiguous data
+            idx = torch.randperm(n_train * t, device=self.device)
+            train_X_supervised = train_X_supervised[..., idx, :][..., mask[idx], :]
+            train_Y_supervised = train_Y_supervised[..., idx, :][..., mask[idx], :]
+
+            dataset = SupervisedDataset(
+                X=train_X_supervised,
+                Y=train_Y_supervised,
+                Yvar=train_Y_supervised,  # just to check warning
+                feature_names=[f"x_{i}" for i in range(d)] + ["step"],
+                outcome_names=["y"],
+            )
+
+            w_msg = "Ignoring Yvar values in provided training data, because "
+            w_msg += "they are currently not supported by LatentKroneckerGP."
+            with self.assertWarnsRegex(InputDataWarning, w_msg):
+                model_inputs = LatentKroneckerGP.construct_inputs(dataset)
+
+            # this test generates train_X and train_T in sorted order
+            # the data is randomly permuted before passing to construct_inputs
+            # construct_inputs sorts the data, so we expect the results to be equal
+            self.assertAllClose(model_inputs["train_X"], train_X, atol=0.0)
+            self.assertAllClose(model_inputs["train_T"], train_T, atol=0.0)
+            self.assertAllClose(
+                model_inputs["train_Y"], train_Y, atol=0.0, equal_nan=True
+            )
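
For context, a minimal usage sketch of what the new test exercises: building a SupervisedDataset whose last feature column is the step index and passing it to LatentKroneckerGP.construct_inputs. The data layout and the names X_base/steps are illustrative, and the final line assumes the usual BoTorch convention that construct_inputs returns keyword arguments accepted by the model constructor.

import torch

from botorch.models.latent_kronecker_gp import LatentKroneckerGP
from botorch.utils.datasets import SupervisedDataset

n_train, d, t = 15, 2, 10
X_base = torch.rand(n_train, d)                     # n_train distinct inputs
steps = torch.arange(1, t + 1, dtype=X_base.dtype)  # t shared step values

# Flatten the (n_train x t) grid into long format: each input repeated per
# step, with the step value appended as the last feature column ("step").
X = torch.cat(
    [X_base.repeat_interleave(t, dim=0), steps.repeat(n_train).unsqueeze(-1)],
    dim=-1,
)
Y = torch.randn(n_train * t, 1)

dataset = SupervisedDataset(
    X=X,
    Y=Y,
    feature_names=[f"x_{i}" for i in range(d)] + ["step"],
    outcome_names=["y"],
)

# construct_inputs recovers the sorted train_X / train_T / train_Y layout
# (with NaNs marking missing grid entries), as asserted in the test above.
model_inputs = LatentKroneckerGP.construct_inputs(dataset)
model = LatentKroneckerGP(**model_inputs)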