88"""
99
1010
11- from typing import Any , Dict , List , Mapping , NoReturn , Optional , Tuple , Union
11+ from typing import Any , Dict , List , Mapping , NoReturn , Optional , Tuple
1212
1313import pyro
1414import torch
1515from botorch .acquisition .objective import PosteriorTransform
1616from botorch .models .fully_bayesian import (
1717 matern52_kernel ,
1818 MIN_INFERRED_NOISE_LEVEL ,
19- PyroModel ,
2019 reshape_and_detach ,
2120 SaasPyroModel ,
2221)
2322from botorch .models .multitask import MultiTaskGP
2423from botorch .models .transforms .input import InputTransform
2524from botorch .models .transforms .outcome import OutcomeTransform
2625from botorch .posteriors .fully_bayesian import GaussianMixturePosterior , MCMC_DIM
27- from botorch .utils .datasets import MultiTaskDataset , SupervisedDataset
2826from gpytorch .distributions .multivariate_normal import MultivariateNormal
2927from gpytorch .kernels import MaternKernel
3028from gpytorch .kernels .kernel import Kernel
@@ -200,9 +198,10 @@ def __init__(
         train_Yvar: Optional[Tensor] = None,
         output_tasks: Optional[List[int]] = None,
         rank: Optional[int] = None,
+        all_tasks: Optional[List[int]] = None,
         outcome_transform: Optional[OutcomeTransform] = None,
         input_transform: Optional[InputTransform] = None,
-        pyro_model: Optional[PyroModel] = None,
+        pyro_model: Optional[MultitaskSaasPyroModel] = None,
     ) -> None:
         r"""Initialize the fully Bayesian multi-task GP model.
 
@@ -216,13 +215,15 @@ def __init__(
                 outputs for. If omitted, return outputs for all task indices.
             rank: The num of learned task embeddings to be used in the task kernel.
                 If omitted, use a full rank (i.e. number of tasks) kernel.
+            all_tasks: NOT SUPPORTED! Must match the tasks seen in the training data.
             outcome_transform: An outcome transform that is applied to the
                 training data during instantiation and to the posterior during
                 inference (that is, the `Posterior` obtained by calling
                 `.posterior` on the model will be on the original scale).
             input_transform: An input transform that is applied to the inputs `X`
                 in the model's forward pass.
-            pyro_model: Optional `PyroModel`, defaults to `MultitaskSaasPyroModel`.
+            pyro_model: Optional `PyroModel` that has the same signature as
+                `MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
         """
         if not (
             train_X.ndim == train_Y.ndim == 2
@@ -253,6 +254,12 @@ def __init__(
             output_tasks=output_tasks,
             rank=rank,
         )
+        if all_tasks is not None and self._expected_task_values != set(all_tasks):
+            raise NotImplementedError(
+                "The `all_tasks` argument is not supported by SAAS MTGP. "
+                f"The training data includes tasks {self._expected_task_values}, "
+                f"got {all_tasks=}."
+            )
         self.to(train_X)
 
         self.mean_module = None
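To illustrate the effect of the new `all_tasks` check, here is a minimal sketch (the toy tensors, the two-task setup, and `task_feature=2` are illustrative and not part of this diff): construction succeeds when `all_tasks` is omitted or matches the tasks present in the training data, and raises the new `NotImplementedError` otherwise.

```python
import torch

from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP

# Toy data: 10 points in 2D plus a task-indicator column holding tasks {0, 1}.
X = torch.rand(10, 2, dtype=torch.double)
task_ids = torch.cat([torch.zeros(5, 1), torch.ones(5, 1)]).to(torch.double)
train_X = torch.cat([X, task_ids], dim=-1)
train_Y = torch.randn(10, 1, dtype=torch.double)

# Omitting `all_tasks` (or passing exactly the observed tasks) works as before.
model = SaasFullyBayesianMultiTaskGP(train_X=train_X, train_Y=train_Y, task_feature=2)

# Passing a superset of the observed tasks is rejected by the new check.
try:
    SaasFullyBayesianMultiTaskGP(
        train_X=train_X, train_Y=train_Y, task_feature=2, all_tasks=[0, 1, 2]
    )
except NotImplementedError as err:
    print(err)  # The `all_tasks` argument is not supported by SAAS MTGP. ...
```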
@@ -383,29 +390,6 @@ def forward(self, X: Tensor) -> MultivariateNormal:
         covar = covar_x.mul(covar_i)
         return MultivariateNormal(mean_x, covar)
 
-    @classmethod
-    def construct_inputs(
-        cls,
-        training_data: Union[SupervisedDataset, MultiTaskDataset],
-        task_feature: int,
-        rank: Optional[int] = None,
-        **kwargs: Any,
-    ) -> Dict[str, Any]:
-        r"""Construct `Model` keyword arguments from a dataset and other args.
-
-        Args:
-            training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
-            task_feature: Column index of embedded task indicator features.
-            rank: The rank of the cross-task covariance matrix.
-        """
-        inputs = super().construct_inputs(
-            training_data=training_data, task_feature=task_feature, rank=rank, **kwargs
-        )
-        inputs.pop("task_covar_prior")
-        if "train_Yvar" not in inputs:
-            inputs["train_Yvar"] = None
-        return inputs
-
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         r"""Custom logic for loading the state dict.
 
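For context, a hedged sketch of how the model is consumed downstream (behavior unchanged by this diff, reusing `model` from the sketch above): `fit_fully_bayesian_model_nuts` draws the MCMC hyperparameter samples, and `.posterior(...)` then returns the `GaussianMixturePosterior` imported at the top of the file. The tiny sample counts are illustrative only.

```python
import torch

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior

# Draw a (deliberately tiny) set of thinned NUTS samples of the SAAS hyperparameters.
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=32, thinning=16, disable_progbar=True
)

# The posterior mixes over the MCMC samples (stacked along MCMC_DIM internally).
# X for `posterior` excludes the task column; outputs cover all output tasks.
test_X = torch.rand(5, 2, dtype=torch.double)
posterior = model.posterior(test_X)
assert isinstance(posterior, GaussianMixturePosterior)
print(posterior.mixture_mean.shape)  # q x num_output_tasks
```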