1414from  botorch .acquisition .objective  import  PosteriorTransform 
1515from  botorch .models .fully_bayesian  import  (
1616    matern52_kernel ,
17+     MCMC_DIM ,
1718    MIN_INFERRED_NOISE_LEVEL ,
1819    reshape_and_detach ,
1920    SaasPyroModel ,
2223from  botorch .models .multitask  import  MultiTaskGP 
2324from  botorch .models .transforms .input  import  InputTransform 
2425from  botorch .models .transforms .outcome  import  OutcomeTransform 
25- from  botorch .posteriors .fully_bayesian  import  GaussianMixturePosterior ,  MCMC_DIM 
26+ from  botorch .posteriors .fully_bayesian  import  GaussianMixturePosterior 
2627from  gpytorch .distributions  import  MultivariateNormal 
2728from  gpytorch .kernels  import  MaternKernel 
2829from  gpytorch .kernels .index_kernel  import  IndexKernel 
@@ -66,6 +67,10 @@ def set_inputs(
6667            task_rank: The num of learned task embeddings to be used in the task kernel. 
6768                If omitted, use a full rank (i.e. number of tasks) kernel. 
6869        """ 
70+         # NOTE PyTorch does not support negative indexing for tensors in index_select, 
71+         # (https://github.com/pytorch/pytorch/issues/76347), so we have to make sure 
72+         # that the task feature is positive. 
73+         task_feature  =  task_feature  %  train_X .shape [- 1 ]
6974        super ().set_inputs (train_X , train_Y , train_Yvar )
7075        # obtain a list of task indicies 
7176        all_tasks  =  train_X [:, task_feature ].unique ().to (dtype = torch .long ).tolist ()
@@ -140,15 +145,19 @@ def load_mcmc_samples(
140145        num_mcmc_samples  =  len (mcmc_samples ["mean" ])
141146        batch_shape  =  torch .Size ([num_mcmc_samples ])
142147
143-         mean_module , covar_module , likelihood , _  =  super ().load_mcmc_samples (
148+         mean_module , data_covar_module , likelihood , _  =  super ().load_mcmc_samples (
144149            mcmc_samples = mcmc_samples 
145150        )
151+         data_indices  =  torch .arange (self .train_X .shape [- 1 ] -  1 )
152+         data_indices [self .task_feature  :] +=  1   # exclude task feature 
146153
154+         data_covar_module .active_dims  =  data_indices   # .to(tkwargs["device"]) 
147155        latent_covar_module  =  MaternKernel (
148156            nu = 2.5 ,
149157            ard_num_dims = self .task_rank ,
150158            batch_shape = batch_shape ,
151159        ).to (** tkwargs )
160+ 
152161        latent_covar_module .lengthscale  =  reshape_and_detach (
153162            target = latent_covar_module .lengthscale ,
154163            new_value = mcmc_samples ["task_lengthscale" ],
@@ -159,22 +168,27 @@ def load_mcmc_samples(
159168            num_tasks = self .num_tasks ,
160169            rank = self .task_rank ,
161170            batch_shape = latent_features .shape [:- 2 ],
162-         ).to (** tkwargs )
171+             active_dims = torch .tensor ([self .task_feature ]).to (tkwargs ["device" ]),
172+         )
163173        task_covar_module .covar_factor  =  Parameter (
164174            task_covar .cholesky ().to_dense ().detach ()
165175        )
166176
167-         # NOTE: 'var' is implicitly assumed to be zero from the sampling procedure in 
168-         # the FBMTGP model but not in the regular MTGP. I dont how if the var parameter 
169-         # affects predictions in practice, but setting it to zero is consistent with the 
170-         # previous implementation. 
177+         # NOTE: The IndexKernel has a learnable 'var' parameter in addition to the 
178+         # task covariances, corresponding do task-specific variances along the diagonal 
179+         # of the task covariance matrix. As this parameter is not sampled in `sample()` 
180+         # we implicitly assume it to be zero. This is consistent with the previous 
181+         # SAASFBMTGP implementation, but means that the non-fully Bayesian and fully 
182+         # Bayesian models run on slightly different task covar modules. 
183+ 
184+         # We set the aforementioned task covar module var parameter to zero here. 
171185        task_covar_module .var  =  torch .zeros_like (task_covar_module .var )
172-         return  mean_module , covar_module , likelihood , task_covar_module 
186+         covar_module  =  data_covar_module  *  task_covar_module 
187+         return  mean_module , covar_module , likelihood , None 
173188
174189
175190class  SaasFullyBayesianMultiTaskGP (MultiTaskGP ):
176191    r"""A fully Bayesian multi-task GP model with the SAAS prior. 
177- 
178192    This model assumes that the inputs have been normalized to [0, 1]^d and that the 
179193    output has been stratified standardized to have zero mean and unit variance for 
180194    each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data 
@@ -286,8 +300,6 @@ def __init__(
286300        self .mean_module  =  None 
287301        self .covar_module  =  None 
288302        self .likelihood  =  None 
289-         self .task_covar_module  =  None 
290-         self .register_buffer ("latent_features" , None )
291303        if  pyro_model  is  None :
292304            pyro_model  =  MultitaskSaasPyroModel ()
293305        pyro_model .set_inputs (
@@ -321,21 +333,20 @@ def train(
321333            self .mean_module  =  None 
322334            self .covar_module  =  None 
323335            self .likelihood  =  None 
324-             self .task_covar_module  =  None 
325336        return  self 
326337
327338    @property  
328339    def  median_lengthscale (self ) ->  Tensor :
329340        r"""Median lengthscales across the MCMC samples.""" 
330341        self ._check_if_fitted ()
331-         lengthscale  =  self .covar_module .base_kernel .lengthscale .clone ()
342+         lengthscale  =  self .covar_module .kernels [ 0 ]. base_kernel .lengthscale .clone ()
332343        return  lengthscale .median (0 ).values .squeeze (0 )
333344
334345    @property  
335346    def  num_mcmc_samples (self ) ->  int :
336347        r"""Number of MCMC samples in the model.""" 
337348        self ._check_if_fitted ()
338-         return  len ( self .covar_module .outputscale ) 
349+         return  self .covar_module .kernels [ 0 ]. batch_shape [ 0 ] 
339350
340351    @property  
341352    def  batch_shape (self ) ->  torch .Size :
@@ -367,7 +378,7 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
367378            self .mean_module ,
368379            self .covar_module ,
369380            self .likelihood ,
370-             self . task_covar_module ,
381+             _ ,
371382        ) =  self .pyro_model .load_mcmc_samples (mcmc_samples = mcmc_samples )
372383
373384    def  posterior (
@@ -438,7 +449,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
438449            self .mean_module ,
439450            self .covar_module ,
440451            self .likelihood ,
441-             self . task_covar_module , 
452+             _ ,   # Possibly space for input transform 
442453        ) =  self .pyro_model .load_mcmc_samples (mcmc_samples = mcmc_samples )
443454        # Load the actual samples from the state dict 
444455        super ().load_state_dict (state_dict = state_dict , strict = strict )
0 commit comments