1414from botorch .acquisition .objective import PosteriorTransform
1515from botorch .models .fully_bayesian import (
1616 matern52_kernel ,
17+ MCMC_DIM ,
1718 MIN_INFERRED_NOISE_LEVEL ,
1819 reshape_and_detach ,
1920 SaasPyroModel ,
2223from botorch .models .multitask import MultiTaskGP
2324from botorch .models .transforms .input import InputTransform
2425from botorch .models .transforms .outcome import OutcomeTransform
25- from botorch .posteriors .fully_bayesian import GaussianMixturePosterior , MCMC_DIM
26+ from botorch .posteriors .fully_bayesian import GaussianMixturePosterior
2627from gpytorch .distributions import MultivariateNormal
2728from gpytorch .kernels import MaternKernel
2829from gpytorch .kernels .index_kernel import IndexKernel
@@ -66,6 +67,10 @@ def set_inputs(
6667 task_rank: The num of learned task embeddings to be used in the task kernel.
6768 If omitted, use a full rank (i.e. number of tasks) kernel.
6869 """
70+ # NOTE PyTorch does not support negative indexing for tensors in index_select,
71+ # (https://github.com/pytorch/pytorch/issues/76347), so we have to make sure
72+ # that the task feature is non-negative.
73+ task_feature = task_feature % train_X .shape [- 1 ]
6974 super ().set_inputs (train_X , train_Y , train_Yvar )
7075 # obtain a list of task indices
7176 all_tasks = train_X [:, task_feature ].unique ().to (dtype = torch .long ).tolist ()
@@ -140,15 +145,19 @@ def load_mcmc_samples(
140145 num_mcmc_samples = len (mcmc_samples ["mean" ])
141146 batch_shape = torch .Size ([num_mcmc_samples ])
142147
143- mean_module , covar_module , likelihood , _ = super ().load_mcmc_samples (
148+ mean_module , data_covar_module , likelihood , _ = super ().load_mcmc_samples (
144149 mcmc_samples = mcmc_samples
145150 )
151+ data_indices = torch .arange (self .train_X .shape [- 1 ] - 1 )
152+ data_indices [self .task_feature :] += 1 # exclude task feature
146153
154+ data_covar_module .active_dims = data_indices
147155 latent_covar_module = MaternKernel (
148156 nu = 2.5 ,
149157 ard_num_dims = self .task_rank ,
150158 batch_shape = batch_shape ,
151159 ).to (** tkwargs )
160+
152161 latent_covar_module .lengthscale = reshape_and_detach (
153162 target = latent_covar_module .lengthscale ,
154163 new_value = mcmc_samples ["task_lengthscale" ],
@@ -159,22 +168,27 @@ def load_mcmc_samples(
159168 num_tasks = self .num_tasks ,
160169 rank = self .task_rank ,
161170 batch_shape = latent_features .shape [:- 2 ],
162- ).to (** tkwargs )
171+ active_dims = torch .tensor ([self .task_feature ]),
172+ )
163173 task_covar_module .covar_factor = Parameter (
164174 task_covar .cholesky ().to_dense ().detach ()
165175 )
166-
167- # NOTE: 'var' is implicitly assumed to be zero from the sampling procedure in
168- # the FBMTGP model but not in the regular MTGP. I dont how if the var parameter
169- # affects predictions in practice, but setting it to zero is consistent with the
170- # previous implementation.
176+ task_covar_module = task_covar_module .to (** tkwargs )
177+ # NOTE: The IndexKernel has a learnable 'var' parameter in addition to the
178 # task covariances, corresponding to task-specific variances along the diagonal
179+ # of the task covariance matrix. As this parameter is not sampled in `sample()`
180+ # we implicitly assume it to be zero. This is consistent with the previous
181+ # SAASFBMTGP implementation, but means that the non-fully Bayesian and fully
182+ # Bayesian models run on slightly different task covar modules.
183+
184+ # We set the aforementioned task covar module var parameter to zero here.
171185 task_covar_module .var = torch .zeros_like (task_covar_module .var )
172- return mean_module , covar_module , likelihood , task_covar_module
186+ covar_module = data_covar_module * task_covar_module
187+ return mean_module , covar_module , likelihood , None
173188
174189
175190class SaasFullyBayesianMultiTaskGP (MultiTaskGP ):
176191 r"""A fully Bayesian multi-task GP model with the SAAS prior.
177-
178192 This model assumes that the inputs have been normalized to [0, 1]^d and that the
179193 output has been stratified standardized to have zero mean and unit variance for
180194 each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
@@ -286,8 +300,6 @@ def __init__(
286300 self .mean_module = None
287301 self .covar_module = None
288302 self .likelihood = None
289- self .task_covar_module = None
290- self .register_buffer ("latent_features" , None )
291303 if pyro_model is None :
292304 pyro_model = MultitaskSaasPyroModel ()
293305 pyro_model .set_inputs (
@@ -321,21 +333,20 @@ def train(
321333 self .mean_module = None
322334 self .covar_module = None
323335 self .likelihood = None
324- self .task_covar_module = None
325336 return self
326337
327338 @property
328339 def median_lengthscale (self ) -> Tensor :
329340 r"""Median lengthscales across the MCMC samples."""
330341 self ._check_if_fitted ()
331- lengthscale = self .covar_module .base_kernel .lengthscale .clone ()
342+ lengthscale = self .covar_module .kernels [ 0 ]. base_kernel .lengthscale .clone ()
332343 return lengthscale .median (0 ).values .squeeze (0 )
333344
334345 @property
335346 def num_mcmc_samples (self ) -> int :
336347 r"""Number of MCMC samples in the model."""
337348 self ._check_if_fitted ()
338- return len ( self .covar_module .outputscale )
349+ return self .covar_module .kernels [ 0 ]. batch_shape [ 0 ]
339350
340351 @property
341352 def batch_shape (self ) -> torch .Size :
@@ -367,7 +378,7 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
367378 self .mean_module ,
368379 self .covar_module ,
369380 self .likelihood ,
370- self . task_covar_module ,
381+ _ ,
371382 ) = self .pyro_model .load_mcmc_samples (mcmc_samples = mcmc_samples )
372383
373384 def posterior (
@@ -438,7 +449,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
438449 self .mean_module ,
439450 self .covar_module ,
440451 self .likelihood ,
441- self . task_covar_module ,
452+ _ , # Possibly space for input transform
442453 ) = self .pyro_model .load_mcmc_samples (mcmc_samples = mcmc_samples )
443454 # Load the actual samples from the state dict
444455 super ().load_state_dict (state_dict = state_dict , strict = strict )
0 commit comments