Skip to content

Commit c9966e9

Browse files
saitcakmak authored and facebook-github-bot committed
Support custom all_tasks for MTGPs (#2271)
Summary: Pull Request resolved: #2271 This allows creation of MTGPs that support inference from tasks that don't appear in the training data. See #2265 for some discussion on how the task covariance behaves in the absence of task specific data. Reviewed By: esantorella Differential Revision: D53029681 fbshipit-source-id: 3df8c910ff03c828ee0a317c29a5067d90e7f769
1 parent 404d869 commit c9966e9

File tree

4 files changed

+55
-44
lines changed

4 files changed

+55
-44
lines changed

botorch/models/contextual_multioutput.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
context_emb_feature: Optional[Tensor] = None,
4646
embs_dim_list: Optional[List[int]] = None,
4747
output_tasks: Optional[List[int]] = None,
48+
all_tasks: Optional[List[int]] = None,
4849
input_transform: Optional[InputTransform] = None,
4950
outcome_transform: Optional[OutcomeTransform] = None,
5051
) -> None:
@@ -67,22 +68,31 @@ def __init__(
6768
for each categorical variable.
6869
output_tasks: A list of task indices for which to compute model
6970
outputs for. If omitted, return outputs for all task indices.
70-
71+
all_tasks: By default, multi-task GPs infer the list of all tasks from
72+
the task features in `train_X`. This is an experimental feature that
73+
enables creation of multi-task GPs with tasks that don't appear in the
74+
training data. Note that when a task is not observed, the corresponding
75+
task covariance will heavily depend on random initialization and may
76+
behave unexpectedly.
7177
"""
7278
super().__init__(
7379
train_X=train_X,
7480
train_Y=train_Y,
7581
task_feature=task_feature,
7682
train_Yvar=train_Yvar,
7783
output_tasks=output_tasks,
84+
all_tasks=all_tasks,
7885
input_transform=input_transform,
7986
outcome_transform=outcome_transform,
8087
)
8188
self.device = train_X.device
82-
# context indices
83-
all_tasks = train_X[:, task_feature].unique()
84-
self.all_tasks = all_tasks.to(dtype=torch.long).tolist()
85-
self.all_tasks.sort() # unique in python does automatic sort; add for safety
89+
if all_tasks is None:
90+
all_tasks = train_X[:, task_feature].unique()
91+
self.all_tasks = all_tasks.to(dtype=torch.long).tolist()
92+
else:
93+
all_tasks = torch.tensor(all_tasks, dtype=torch.long)
94+
self.all_tasks = all_tasks
95+
self.all_tasks.sort() # These are the context indices.
8696

8797
if context_cat_feature is None:
8898
context_cat_feature = all_tasks.unsqueeze(-1).to(device=self.device)

botorch/models/fully_bayesian_multitask.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,21 @@
88
"""
99

1010

11-
from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union
11+
from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple
1212

1313
import pyro
1414
import torch
1515
from botorch.acquisition.objective import PosteriorTransform
1616
from botorch.models.fully_bayesian import (
1717
matern52_kernel,
1818
MIN_INFERRED_NOISE_LEVEL,
19-
PyroModel,
2019
reshape_and_detach,
2120
SaasPyroModel,
2221
)
2322
from botorch.models.multitask import MultiTaskGP
2423
from botorch.models.transforms.input import InputTransform
2524
from botorch.models.transforms.outcome import OutcomeTransform
2625
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
27-
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
2826
from gpytorch.distributions.multivariate_normal import MultivariateNormal
2927
from gpytorch.kernels import MaternKernel
3028
from gpytorch.kernels.kernel import Kernel
@@ -200,9 +198,10 @@ def __init__(
200198
train_Yvar: Optional[Tensor] = None,
201199
output_tasks: Optional[List[int]] = None,
202200
rank: Optional[int] = None,
201+
all_tasks: Optional[List[int]] = None,
203202
outcome_transform: Optional[OutcomeTransform] = None,
204203
input_transform: Optional[InputTransform] = None,
205-
pyro_model: Optional[PyroModel] = None,
204+
pyro_model: Optional[MultitaskSaasPyroModel] = None,
206205
) -> None:
207206
r"""Initialize the fully Bayesian multi-task GP model.
208207
@@ -216,13 +215,15 @@ def __init__(
216215
outputs for. If omitted, return outputs for all task indices.
217216
rank: The num of learned task embeddings to be used in the task kernel.
218217
If omitted, use a full rank (i.e. number of tasks) kernel.
218+
all_tasks: NOT SUPPORTED!
219219
outcome_transform: An outcome transform that is applied to the
220220
training data during instantiation and to the posterior during
221221
inference (that is, the `Posterior` obtained by calling
222222
`.posterior` on the model will be on the original scale).
223223
input_transform: An input transform that is applied to the inputs `X`
224224
in the model's forward pass.
225-
pyro_model: Optional `PyroModel`, defaults to `MultitaskSaasPyroModel`.
225+
pyro_model: Optional `PyroModel` that has the same signature as
226+
`MultitaskSaasPyroModel`. Defaults to `MultitaskSaasPyroModel`.
226227
"""
227228
if not (
228229
train_X.ndim == train_Y.ndim == 2
@@ -253,6 +254,12 @@ def __init__(
253254
output_tasks=output_tasks,
254255
rank=rank,
255256
)
257+
if all_tasks is not None and self._expected_task_values != set(all_tasks):
258+
raise NotImplementedError(
259+
"The `all_tasks` argument is not supported by SAAS MTGP. "
260+
f"The training data includes tasks {self._expected_task_values}, "
261+
f"got {all_tasks=}."
262+
)
256263
self.to(train_X)
257264

258265
self.mean_module = None
@@ -383,29 +390,6 @@ def forward(self, X: Tensor) -> MultivariateNormal:
383390
covar = covar_x.mul(covar_i)
384391
return MultivariateNormal(mean_x, covar)
385392

386-
@classmethod
387-
def construct_inputs(
388-
cls,
389-
training_data: Union[SupervisedDataset, MultiTaskDataset],
390-
task_feature: int,
391-
rank: Optional[int] = None,
392-
**kwargs: Any,
393-
) -> Dict[str, Any]:
394-
r"""Construct `Model` keyword arguments from a dataset and other args.
395-
396-
Args:
397-
training_data: A `SupervisedDataset` or a `MultiTaskDataset`.
398-
task_feature: Column index of embedded task indicator features.
399-
rank: The rank of the cross-task covariance matrix.
400-
"""
401-
inputs = super().construct_inputs(
402-
training_data=training_data, task_feature=task_feature, rank=rank, **kwargs
403-
)
404-
inputs.pop("task_covar_prior")
405-
if "train_Yvar" not in inputs:
406-
inputs["train_Yvar"] = None
407-
return inputs
408-
409393
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
410394
r"""Custom logic for loading the state dict.
411395

botorch/models/multitask.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(
149149
task_covar_prior: Optional[Prior] = None,
150150
output_tasks: Optional[List[int]] = None,
151151
rank: Optional[int] = None,
152+
all_tasks: Optional[List[int]] = None,
152153
input_transform: Optional[InputTransform] = None,
153154
outcome_transform: Optional[OutcomeTransform] = None,
154155
) -> None:
@@ -176,6 +177,12 @@ def __init__(
176177
full rank (i.e. number of tasks) kernel.
177178
task_covar_prior : A Prior on the task covariance matrix. Must operate
178179
on p.s.d. matrices. A common prior for this is the `LKJ` prior.
180+
all_tasks: By default, multi-task GPs infer the list of all tasks from
181+
the task features in `train_X`. This is an experimental feature that
182+
enables creation of multi-task GPs with tasks that don't appear in the
183+
training data. Note that when a task is not observed, the corresponding
184+
task covariance will heavily depend on random initialization and may
185+
behave unexpectedly.
179186
input_transform: An input transform that is applied in the model's
180187
forward pass.
181188
outcome_transform: An outcome transform that is applied to the
@@ -197,9 +204,12 @@ def __init__(
197204
X=train_X, input_transform=input_transform
198205
)
199206
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
200-
all_tasks, task_feature, self.num_non_task_features = self.get_all_tasks(
201-
transformed_X, task_feature, output_tasks
202-
)
207+
(
208+
all_tasks_inferred,
209+
task_feature,
210+
self.num_non_task_features,
211+
) = self.get_all_tasks(transformed_X, task_feature, output_tasks)
212+
all_tasks = all_tasks or all_tasks_inferred
203213
self.num_tasks = len(all_tasks)
204214
if outcome_transform is not None:
205215
train_Y, train_Yvar = outcome_transform(Y=train_Y, Yvar=train_Yvar)
@@ -360,13 +370,16 @@ def construct_inputs(
360370
base_inputs = super().construct_inputs(
361371
training_data=training_data, task_feature=task_feature, **kwargs
362372
)
363-
return {
364-
**base_inputs,
365-
"task_feature": task_feature,
366-
"output_tasks": output_tasks,
367-
"task_covar_prior": task_covar_prior,
368-
"rank": rank,
369-
}
373+
if isinstance(training_data, MultiTaskDataset):
374+
all_tasks = list(range(len(training_data.datasets)))
375+
base_inputs["all_tasks"] = all_tasks
376+
if task_covar_prior is not None:
377+
base_inputs["task_covar_prior"] = task_covar_prior
378+
if rank is not None:
379+
base_inputs["rank"] = rank
380+
base_inputs["task_feature"] = task_feature
381+
base_inputs["output_tasks"] = output_tasks
382+
return base_inputs
370383

371384

372385
class FixedNoiseMultiTaskGP(MultiTaskGP):
@@ -428,6 +441,7 @@ def __init__(
428441
"When `train_Yvar` is specified, `MultiTaskGP` behaves the same "
429442
"as the `FixedNoiseMultiTaskGP`.",
430443
DeprecationWarning,
444+
stacklevel=2,
431445
)
432446
super().__init__(
433447
train_X=train_X,

test/models/test_fully_bayesian_multitask.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,10 @@ def test_construct_inputs(self):
583583
)
584584
self.assertTrue(torch.equal(data_dict["train_X"], train_X))
585585
self.assertTrue(torch.equal(data_dict["train_Y"], train_Y))
586-
self.assertAllClose(data_dict["train_Yvar"], train_Yvar)
586+
if train_Yvar is not None:
587+
self.assertAllClose(data_dict["train_Yvar"], train_Yvar)
588+
else:
589+
self.assertNotIn("train_Yvar", data_dict)
587590
self.assertEqual(data_dict["task_feature"], task_feature)
588591
self.assertEqual(data_dict["rank"], 1)
589592
self.assertTrue("task_covar_prior" not in data_dict)

0 commit comments

Comments
 (0)