@@ -125,13 +125,15 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
        return train_X, train_Y, train_Yvar, test_X

    def _get_unnormalized_condition_data(
-        self, num_models: int, infer_noise: bool, **tkwargs
+        self, num_models: int, num_cond: int, infer_noise: bool, **tkwargs
    ):
        with torch.random.fork_rng():
            torch.manual_seed(0)
-            cond_X = 5 + 5 * torch.rand(num_models, 2, 4, **tkwargs)
+            cond_X = 5 + 5 * torch.rand(num_models, num_cond, 4, **tkwargs)
            cond_Y = 10 + torch.sin(cond_X[..., :1])
-            cond_Yvar = None if infer_noise else 0.1 * torch.ones(cond_Y.shape)
+            cond_Yvar = (
+                None if infer_noise else 0.1 * torch.ones(cond_Y.shape, **tkwargs)
+            )
        return cond_X, cond_Y, cond_Yvar

    def _get_mcmc_samples(
@@ -667,11 +669,15 @@ def test_custom_pyro_model(self):
        )

    def test_condition_on_observation(self):
-
+        # The following conditioned data shapes should work (output described below):
+        # training data shape after cond (batch shape in output is req. in gpytorch)
+        # X: num_models x n x d, Y: num_models x n x d --> num_models x n x d
+        # X: n x d, Y: n x d --> num_models x n x d
+        # X: n x d, Y: num_models x n x d --> num_models x n x d
        num_models = 3
        num_cond = 2
        for infer_noise, dtype in itertools.product(
-            (True,), (torch.float, torch.double)
+            (True, False), (torch.float, torch.double)
        ):
            tkwargs = {"device": self.device, "dtype": dtype}
            train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
@@ -681,7 +687,10 @@ def test_condition_on_observation(self):
            # condition on different observations per model to obtain num_models sets
            # of training data
            cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
-                num_models=num_models, infer_noise=infer_noise, **tkwargs
+                num_models=num_models,
+                num_cond=num_cond,
+                infer_noise=infer_noise,
+                **tkwargs
            )
            model = SaasFullyBayesianSingleTaskGP(
                train_X=train_X,
@@ -712,8 +721,12 @@ def test_condition_on_observation(self):
                cond_model.train_inputs[0].shape,
                torch.Size([num_models, num_train + num_cond, num_dims]),
            )
+
+            # the batch shape of the conditioned model is added during conditioning
+            self.assertEqual(cond_model.batch_shape, torch.Size([num_models]))
+
            # condition on identical sets of data (i.e. one set) for all models
-            # i.e., with no batch shape. This should not work.
+            # i.e., with no batch shape. This infers the batch shape.
            cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
            model = SaasFullyBayesianSingleTaskGP(
                train_X=train_X,
@@ -728,14 +741,36 @@ def test_condition_on_observation(self):
            )
            model.load_mcmc_samples(mcmc_samples)

-            # This should __NOT__ work - conditioning must have a batch size for the
-            # conditioned point and is not supported (the training data by default
-            # does not have a batch size)
+            # conditioning without a batch size - the resulting conditioned model
+            # will still have a batch size
            model.posterior(train_X)
-            with self.assertRaises(ValueError):
-                model.condition_on_observations(
-                    cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
-                )
+            cond_model = model.condition_on_observations(
+                cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+
+            # test repeated conditioning
+            repeat_cond_X = cond_X + 5
+            repeat_cond_model = cond_model.condition_on_observations(
+                repeat_cond_X, cond_Y, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 2 * num_cond, num_dims]),
+            )
+
+            # test repeated conditioning without a batch size
+            repeat_cond_X_nobatch = cond_X_nobatch + 10
+            repeat_cond_model2 = repeat_cond_model.condition_on_observations(
+                repeat_cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+            )
+            self.assertEqual(
+                repeat_cond_model2.train_inputs[0].shape,
+                torch.Size([num_models, num_train + 3 * num_cond, num_dims]),
+            )

    def test_bisect(self):
        def f(x):
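
For context, here is a minimal self-contained sketch of the behavior the updated test asserts: after this change, conditioning a SaasFullyBayesianSingleTaskGP on unbatched observations broadcasts them across the MCMC batch dimension instead of raising a ValueError (which is what the deleted assertRaises block used to check). This sketch is not part of the diff; the hand-made MCMC sample shapes mirror the test's _get_mcmc_samples helper, and it assumes a BoTorch version that includes this change.

import torch
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

num_models, num_train, num_cond, num_dims = 3, 10, 2, 4
train_X = torch.rand(num_train, num_dims, dtype=torch.double)
train_Y = torch.sin(train_X[..., :1])
model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)

# Hand-made MCMC samples (shapes mirror the test helper) instead of running NUTS.
mcmc_samples = {
    "lengthscale": torch.rand(num_models, 1, num_dims, dtype=torch.double),
    "outputscale": torch.rand(num_models, dtype=torch.double),
    "mean": torch.randn(num_models, dtype=torch.double),
    "noise": torch.rand(num_models, 1, dtype=torch.double),
}
model.load_mcmc_samples(mcmc_samples)
model.posterior(train_X)  # evaluate once before conditioning, as the test does

# Unbatched n x d / n x 1 conditioning data is broadcast over the num_models
# MCMC samples, so the conditioned model picks up the batch dimension.
cond_X = torch.rand(num_cond, num_dims, dtype=torch.double)
cond_Y = torch.sin(cond_X[..., :1])
cond_model = model.condition_on_observations(cond_X, cond_Y)
assert cond_model.train_inputs[0].shape == torch.Size(
    [num_models, num_train + num_cond, num_dims]
)
assert cond_model.batch_shape == torch.Size([num_models])

# Repeated conditioning keeps appending observations along the data dimension.
recond_model = cond_model.condition_on_observations(cond_X + 5, cond_Y)
assert recond_model.train_inputs[0].shape == torch.Size(
    [num_models, num_train + 2 * num_cond, num_dims]
)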