@@ -124,6 +124,16 @@ def _get_unnormalized_data(self, infer_noise: bool, **tkwargs):
        )
        return train_X, train_Y, train_Yvar, test_X

+    def _get_unnormalized_condition_data(
+        self, num_models: int, infer_noise: bool, **tkwargs
+    ):
+        with torch.random.fork_rng():
+            torch.manual_seed(0)
+            cond_X = 5 + 5 * torch.rand(num_models, 2, 4, **tkwargs)
+            cond_Y = 10 + torch.sin(cond_X[..., :1])
+            cond_Yvar = None if infer_noise else 0.1 * torch.ones_like(cond_Y)
+        return cond_X, cond_Y, cond_Yvar
+
    def _get_mcmc_samples(
        self, num_samples: int, dim: int, infer_noise: bool, **tkwargs
    ):
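For reference, a minimal standalone sketch of the shapes the new helper produces (assuming num_models=3; the dtype is illustrative):

    import torch

    # Mirror the helper: fork the RNG so global random state is untouched.
    with torch.random.fork_rng():
        torch.manual_seed(0)
        # One set of 2 conditioning points in 4 input dimensions per model.
        cond_X = 5 + 5 * torch.rand(3, 2, 4, dtype=torch.double)
        cond_Y = 10 + torch.sin(cond_X[..., :1])

    print(cond_X.shape)  # torch.Size([3, 2, 4]) -> (num_models, num_cond, d)
    print(cond_Y.shape)  # torch.Size([3, 2, 1]) -> (num_models, num_cond, 1)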
@@ -656,6 +666,77 @@ def test_custom_pyro_model(self):
                atol=5e-4,
            )

+    def test_condition_on_observation(self):
+
+        num_models = 3
+        num_cond = 2
+        for infer_noise, dtype in itertools.product(
+            (True,), (torch.float, torch.double)
+        ):
+            tkwargs = {"device": self.device, "dtype": dtype}
+            train_X, train_Y, train_Yvar, test_X = self._get_unnormalized_data(
+                infer_noise=infer_noise, **tkwargs
+            )
+            num_train, num_dims = train_X.shape
+            # condition on different observations per model to obtain
+            # num_models sets of training data
+            cond_X, cond_Y, cond_Yvar = self._get_unnormalized_condition_data(
+                num_models=num_models, infer_noise=infer_noise, **tkwargs
+            )
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            # a forward pass is needed before conditioning
+            model.posterior(train_X)
+            cond_model = model.condition_on_observations(
+                cond_X, cond_Y, noise=cond_Yvar
+            )
+            posterior = cond_model.posterior(test_X)
+            self.assertEqual(
+                posterior.mean.shape, torch.Size([num_models, len(test_X), 1])
+            )
+
+            # since the conditioning data differs between models, a batch
+            # dimension is added to the training data
+            self.assertEqual(
+                cond_model.train_inputs[0].shape,
+                torch.Size([num_models, num_train + num_cond, num_dims]),
+            )
+            # condition on an identical set of data (i.e. one set) for all
+            # models, i.e. with no batch shape. This should not work.
+            cond_X_nobatch, cond_Y_nobatch = cond_X[0], cond_Y[0]
+            model = SaasFullyBayesianSingleTaskGP(
+                train_X=train_X,
+                train_Y=train_Y,
+                train_Yvar=train_Yvar,
+            )
+            mcmc_samples = self._get_mcmc_samples(
+                num_samples=num_models,
+                dim=train_X.shape[-1],
+                infer_noise=infer_noise,
+                **tkwargs,
+            )
+            model.load_mcmc_samples(mcmc_samples)
+
+            # this should __NOT__ work: the conditioning data must carry a
+            # batch dimension matching the model (the training data by
+            # default has none)
+            model.posterior(train_X)
+            with self.assertRaises(ValueError):
+                model.condition_on_observations(
+                    cond_X_nobatch, cond_Y_nobatch, noise=cond_Yvar
+                )
+
    def test_bisect(self):
        def f(x):
            return 1 + x
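End-to-end, the flow this test exercises looks roughly like the sketch below. This is a hedged illustration, not part of the diff: it assumes a BoTorch version that includes the fully Bayesian conditioning support added here, and the NUTS settings are chosen only to keep the example fast.

    import torch
    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

    # Toy training data: 10 points in 4 dimensions, single outcome.
    train_X = torch.rand(10, 4, dtype=torch.double)
    train_Y = torch.sin(train_X[:, :1])

    model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
    fit_fully_bayesian_model_nuts(
        model, warmup_steps=32, num_samples=16, thinning=8, disable_progbar=True
    )
    # batch_shape reflects the number of retained MCMC samples.
    num_models = model.batch_shape[0]

    # A forward pass is required before conditioning (as in the test).
    model.posterior(train_X)

    # Conditioning data must carry the per-model batch dimension.
    cond_X = torch.rand(num_models, 2, 4, dtype=torch.double)
    cond_Y = torch.sin(cond_X[..., :1])
    cond_model = model.condition_on_observations(cond_X, cond_Y)

    # Training data now has shape (num_models, 10 + 2, 4).
    print(cond_model.train_inputs[0].shape)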