
Commit 1e66b0d

Carl Hvarfner authored and facebook-github-bot committed
Refactor of MultiTask / FullyBayesianMultiTaskGP to use ProductKernel & IndexKernel (#2908)
Summary:
X-link: facebook/Ax#3992
X-link: facebookexternal/botorch_fb#23

Modified MultiTaskGP and FullyBayesianMultiTaskGP to use an IndexKernel instead of two separate covar modules. For large matrices, this constitutes a significant speed-up (2-3x anecdotally) and an even larger decrease in memory use. In addition, this makes MultiTaskFBGP and SingleTaskFBGPs share a lot of code; I'll enable more code sharing between them in a subsequent diff. With some additional functionality in IndexKernel (i.e., structured learning of the covar_matrix elements), this change would apply to other MTGPs as well.

NOTE: Providing negative indices to an IndexKernel is not supported: pytorch/pytorch#76347

Reviewed By: saitcakmak

Differential Revision: D76317553
1 parent: 223656a

File tree

6 files changed: +170 -90 lines
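For context on the structure this refactor adopts, here is a minimal, self-contained sketch (not part of the commit; the column layout, shapes, and variable names are illustrative assumptions) of how a GPyTorch ProductKernel of a data kernel and an IndexKernel evaluates an ICM-style multi-task covariance from a single input tensor whose last column holds the task index:

import torch
from gpytorch.kernels import IndexKernel, MaternKernel

# Assumed layout: 2 data columns followed by 1 task column, 3 tasks.
d, num_tasks, task_feature = 2, 3, 2

# The data kernel only sees the non-task columns ...
data_kernel = MaternKernel(nu=2.5, ard_num_dims=d, active_dims=[0, 1])
# ... and the IndexKernel only sees the (non-negative) task column.
task_kernel = IndexKernel(num_tasks=num_tasks, rank=2, active_dims=[task_feature])

# `*` builds a ProductKernel: K((x, i), (x', j)) = K_data(x, x') * K_task(i, j).
covar_module = data_kernel * task_kernel

X = torch.rand(5, d)
tasks = torch.randint(num_tasks, (5, 1)).to(X)  # task ids stored as floats, as in train_X
K = covar_module(torch.cat([X, tasks], dim=-1)).to_dense()  # 5 x 5 ICM covariance

A single kernel call on the combined tensor replaces the previous pattern of evaluating a data covar module and a task covar module separately and multiplying the results by hand in `forward`.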

botorch/models/contextual_multioutput.py

Lines changed: 27 additions & 0 deletions

@@ -19,9 +19,11 @@
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
+from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
 from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
 from botorch.utils.types import _DefaultType, DEFAULT
 from gpytorch.constraints import Interval
+from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels.rbf_kernel import RBFKernel
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.module import Module
@@ -107,6 +109,13 @@ def __init__(
             outcome_transform=outcome_transform,
             input_transform=input_transform,
         )
+        # Overwriting the covar_module created in the parent class
+        if covar_module is None:
+            self.covar_module = get_covar_module_with_dim_scaled_prior(
+                ard_num_dims=self.num_non_task_features
+            )
+        else:
+            self.covar_module = covar_module
         self.device = train_X.device
         if all_tasks is None:
             all_tasks_tensor = train_X[:, task_feature].unique()
@@ -188,6 +197,10 @@ def task_covar_module(self, task_idcs: Tensor) -> Tensor:
         Returns:
             Task covariance matrix of shape (b x n x n).
         """
+        # NOTE: This can probably be re-written more efficiently using
+        # IndexKernel (or an IndexKernel subclass) and `evaluate_task_covar`,
+        # and then have the forward pass evaluate a ProductKernel of the two.
+
         # This is a tensor of shape (num_tasks x num_tasks).
         covar_matrix = self._eval_context_covar().to_dense()
         # Here, we index into the base covar matrix to extract
@@ -208,6 +221,20 @@
             covar_matrix[base_idx].transpose(-1, -2).gather(index=expanded_idx, dim=-2)
         )

+    def forward(self, x: Tensor) -> MultivariateNormal:
+        if self.training:
+            x = self.transform_inputs(x)
+        x_basic_lead, task_idcs, x_basic_trail = self._split_inputs(x)
+        x_basic = torch.cat([x_basic_lead, x_basic_trail], dim=-1)
+        # Compute base mean and covariance
+        mean_x = self.mean_module(x_basic)
+        covar_x = self.covar_module(x_basic)
+        # Compute task covariances
+        covar_i = self.task_covar_module(task_idcs)
+        # Combine the two in an ICM fashion
+        covar = covar_x.mul(covar_i)
+        return MultivariateNormal(mean_x, covar)
+
     @classmethod
     def construct_inputs(
         cls,
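The gather-based lookup in `task_covar_module` above can be hard to parse. Here is a standalone illustration (assuming a dense `t x t` context covariance; variable names are hypothetical, not from the file) of the operation it performs:

import torch

t = 4  # number of tasks/contexts
A = torch.randn(t, t)
covar_matrix = A @ A.T  # stand-in for _eval_context_covar().to_dense()

task_idcs = torch.tensor([0, 2, 2, 3, 1])  # task of each of n = 5 points
n = task_idcs.shape[0]

# What the method computes: K[i, j] = covar_matrix[task_idcs[i], task_idcs[j]].
rows = covar_matrix.index_select(0, task_idcs)    # (n x t)
K_task = rows.gather(1, task_idcs.expand(n, n))   # (n x n)

# Equivalent direct advanced indexing:
assert torch.allclose(K_task, covar_matrix[task_idcs][:, task_idcs])

The gather formulation keeps the operation batch-compatible, which is why the model prefers it over direct indexing.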

botorch/models/fully_bayesian_multitask.py

Lines changed: 27 additions & 16 deletions

@@ -14,6 +14,7 @@
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.fully_bayesian import (
     matern52_kernel,
+    MCMC_DIM,
     MIN_INFERRED_NOISE_LEVEL,
     reshape_and_detach,
     SaasPyroModel,
@@ -22,7 +23,7 @@
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
-from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
+from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import MaternKernel
 from gpytorch.kernels.index_kernel import IndexKernel
@@ -66,6 +67,10 @@ def set_inputs(
             task_rank: The num of learned task embeddings to be used in the task kernel.
                 If omitted, use a full rank (i.e. number of tasks) kernel.
         """
+        # NOTE: PyTorch does not support negative indexing for tensors in index_select
+        # (https://github.com/pytorch/pytorch/issues/76347), so we have to make sure
+        # that the task feature is positive.
+        task_feature = task_feature % train_X.shape[-1]
         super().set_inputs(train_X, train_Y, train_Yvar)
         # obtain a list of task indicies
         all_tasks = train_X[:, task_feature].unique().to(dtype=torch.long).tolist()
@@ -140,15 +145,19 @@ def load_mcmc_samples(
         num_mcmc_samples = len(mcmc_samples["mean"])
         batch_shape = torch.Size([num_mcmc_samples])

-        mean_module, covar_module, likelihood, _ = super().load_mcmc_samples(
+        mean_module, data_covar_module, likelihood, _ = super().load_mcmc_samples(
             mcmc_samples=mcmc_samples
         )
+        data_indices = torch.arange(self.train_X.shape[-1] - 1)
+        data_indices[self.task_feature :] += 1  # exclude task feature

+        data_covar_module.active_dims = data_indices  # .to(tkwargs["device"])
         latent_covar_module = MaternKernel(
             nu=2.5,
             ard_num_dims=self.task_rank,
             batch_shape=batch_shape,
         ).to(**tkwargs)
+
         latent_covar_module.lengthscale = reshape_and_detach(
             target=latent_covar_module.lengthscale,
             new_value=mcmc_samples["task_lengthscale"],
@@ -159,22 +168,27 @@
             num_tasks=self.num_tasks,
             rank=self.task_rank,
             batch_shape=latent_features.shape[:-2],
-        ).to(**tkwargs)
+            active_dims=torch.tensor([self.task_feature]).to(tkwargs["device"]),
+        )
         task_covar_module.covar_factor = Parameter(
             task_covar.cholesky().to_dense().detach()
         )

-        # NOTE: 'var' is implicitly assumed to be zero from the sampling procedure in
-        # the FBMTGP model but not in the regular MTGP. I dont how if the var parameter
-        # affects predictions in practice, but setting it to zero is consistent with the
-        # previous implementation.
+        # NOTE: The IndexKernel has a learnable 'var' parameter in addition to the
+        # task covariances, corresponding to task-specific variances along the diagonal
+        # of the task covariance matrix. As this parameter is not sampled in `sample()`,
+        # we implicitly assume it to be zero. This is consistent with the previous
+        # SAASFBMTGP implementation, but means that the non-fully Bayesian and fully
+        # Bayesian models run on slightly different task covar modules.
+
+        # We set the aforementioned task covar module var parameter to zero here.
         task_covar_module.var = torch.zeros_like(task_covar_module.var)
-        return mean_module, covar_module, likelihood, task_covar_module
+        covar_module = data_covar_module * task_covar_module
+        return mean_module, covar_module, likelihood, None


 class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
     r"""A fully Bayesian multi-task GP model with the SAAS prior.
-
     This model assumes that the inputs have been normalized to [0, 1]^d and that the
     output has been stratified standardized to have zero mean and unit variance for
     each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data
@@ -286,8 +300,6 @@ def __init__(
         self.mean_module = None
         self.covar_module = None
         self.likelihood = None
-        self.task_covar_module = None
-        self.register_buffer("latent_features", None)
         if pyro_model is None:
             pyro_model = MultitaskSaasPyroModel()
         pyro_model.set_inputs(
@@ -321,21 +333,20 @@ def train(
         self.mean_module = None
         self.covar_module = None
         self.likelihood = None
-        self.task_covar_module = None
         return self

     @property
     def median_lengthscale(self) -> Tensor:
         r"""Median lengthscales across the MCMC samples."""
         self._check_if_fitted()
-        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
+        lengthscale = self.covar_module.kernels[0].base_kernel.lengthscale.clone()
         return lengthscale.median(0).values.squeeze(0)

     @property
     def num_mcmc_samples(self) -> int:
         r"""Number of MCMC samples in the model."""
         self._check_if_fitted()
-        return len(self.covar_module.outputscale)
+        return self.covar_module.kernels[0].batch_shape[0]

     @property
     def batch_shape(self) -> torch.Size:
@@ -367,7 +378,7 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.mean_module,
             self.covar_module,
             self.likelihood,
-            self.task_covar_module,
+            _,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

     def posterior(
@@ -438,7 +449,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
             self.mean_module,
             self.covar_module,
             self.likelihood,
-            self.task_covar_module,
+            _,  # Possibly space for input transform
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
         super().load_state_dict(state_dict=state_dict, strict=strict)
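To make the new `load_mcmc_samples` flow concrete, here is a rough sketch (hypothetical shapes and a simplified construction, not the model's actual code path) of how batched task covariances drawn via MCMC can be baked into a batched IndexKernel and combined with a batched data kernel:

import torch
from torch.nn import Parameter
from gpytorch.kernels import IndexKernel, MaternKernel

num_mcmc_samples, num_tasks, d = 16, 3, 2
batch_shape = torch.Size([num_mcmc_samples])

# Hypothetical per-sample task covariances (num_mcmc_samples x num_tasks x num_tasks).
L = torch.randn(*batch_shape, num_tasks, num_tasks)
task_covar = L @ L.transpose(-1, -2) + 0.1 * torch.eye(num_tasks)

task_covar_module = IndexKernel(
    num_tasks=num_tasks,
    rank=num_tasks,  # full rank here; the model uses its task_rank
    batch_shape=batch_shape,
    active_dims=[d],  # task column assumed to be last
)
# Overwrite the factor with the sampled (batched) Cholesky factors, and pin the
# unsampled 'var' diagonal to zero, as the diff does for consistency.
task_covar_module.covar_factor = Parameter(torch.linalg.cholesky(task_covar).detach())
task_covar_module.var = torch.zeros_like(task_covar_module.var)

data_covar_module = MaternKernel(
    nu=2.5, ard_num_dims=d, batch_shape=batch_shape, active_dims=[0, 1]
)
covar_module = data_covar_module * task_covar_module  # one batched ProductKernel

Each MCMC sample then corresponds to one batch entry of the product kernel, which is what `num_mcmc_samples` reads back via `covar_module.kernels[0].batch_shape[0]`.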

botorch/models/multitask.py

Lines changed: 47 additions & 33 deletions

@@ -172,6 +172,11 @@ def __init__(
             X=train_X, input_transform=input_transform
         )
         self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
+
+        # IndexKernel cannot work with negative task features, so we shift them to
+        # be positive here.
+        if task_feature < 0:
+            task_feature += transformed_X.shape[-1]
         (
             all_tasks_inferred,
             task_feature,
@@ -220,16 +225,29 @@
         )
         self.mean_module = mean_module or ConstantMean()
         if covar_module is None:
-            self.covar_module = get_covar_module_with_dim_scaled_prior(
-                ard_num_dims=self.num_non_task_features
+            data_covar_module = get_covar_module_with_dim_scaled_prior(
+                ard_num_dims=self.num_non_task_features,
+                active_dims=self._base_idxr,
             )
         else:
-            self.covar_module = covar_module
+            data_covar_module = covar_module
+            # This check enables models which don't adhere to the convention (e.g.
+            # adding additional feature dimensions, like HeteroMTGP) to be used.
+            if covar_module.active_dims is None:
+                # Since we no longer use the custom indexing which derived the
+                # task indexing in the forward pass, we need to explicitly set
+                # the active dims here to ensure that the forward pass works.
+                data_covar_module.active_dims = self._base_idxr

         self._rank = rank if rank is not None else self.num_tasks
-        self.task_covar_module = IndexKernel(
-            num_tasks=self.num_tasks, rank=self._rank, prior=task_covar_prior
+        task_covar_module = IndexKernel(
+            num_tasks=self.num_tasks,
+            rank=self._rank,
+            prior=task_covar_prior,
+            active_dims=[task_feature],
         )
+
+        self.covar_module = data_covar_module * task_covar_module
         task_mapper = get_task_value_remapping(
             task_values=torch.tensor(
                 all_tasks, dtype=torch.long, device=train_X.device
@@ -244,45 +262,41 @@
         self.outcome_transform = outcome_transform
         self.to(train_X)

-    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        r"""Extracts base features and task indices from input data.
+    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
+        r"""Extracts features before the task feature, task indices, and features
+        after the task feature.

         Args:
             x: The full input tensor with trailing dimension of size `d + 1`.
                 Should be of float/double data type.

         Returns:
-            2-element tuple containing
-
-            - A `q x d` or `b x q x d` (batch mode) tensor with trailing
-              dimension made up of the `d` non-task-index columns of `x`, arranged
-              in the order as specified by the indexer generated during model
-              instantiation.
-            - A `q` or `b x q` (batch mode) tensor of long data type containing
-              the task indices.
+            3-element tuple containing
+
+            - A `q x d` or `b x q x d` tensor with features before the task feature.
+            - A `q` or `b x q` tensor with mapped task indices.
+            - A `q x d` or `b x q x d` tensor with features after the task feature.
         """
-        batch_shape, d = x.shape[:-2], x.shape[-1]
-        x_basic = x[..., self._base_idxr].view(batch_shape + torch.Size([-1, d - 1]))
-        task_idcs = (
-            x[..., self._task_feature]
-            .view(batch_shape + torch.Size([-1, 1]))
-            .to(dtype=torch.long)
-        )
-        task_idcs = self._map_tasks(task_values=task_idcs)
-        return x_basic, task_idcs
+        batch_shape = x.shape[:-2]
+        # Extract task indices and convert to long
+        task_idcs = x[..., self._task_feature].view(batch_shape + torch.Size([-1, 1]))
+        task_idcs = self._map_tasks(task_values=task_idcs.to(dtype=torch.long))
+
+        # Extract features before and after the task feature
+        x_before = x[..., : self._task_feature]
+        x_after = x[..., (self._task_feature + 1) :]
+        return x_before, task_idcs, x_after

     def forward(self, x: Tensor) -> MultivariateNormal:
         if self.training:
             x = self.transform_inputs(x)
-        x_basic, task_idcs = self._split_inputs(x)
-        # Compute base mean and covariance
-        mean_x = self.mean_module(x_basic)
-        covar_x = self.covar_module(x_basic)
-        # Compute task covariances
-        covar_i = self.task_covar_module(task_idcs)
-        # Combine the two in an ICM fashion
-        covar = covar_x.mul(covar_i)
-        return MultivariateNormal(mean_x, covar)
+
+        # Get features before the task feature, task indices, and features after the
+        # task feature, with the task mapping applied to the task indices.
+        x = torch.cat(self._split_inputs(x), dim=-1)
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
+        return MultivariateNormal(mean_x, covar_x)

     @classmethod
     def get_all_tasks(
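One subtlety in the new forward pass above: `_split_inputs` returns three pieces that are concatenated back in their original column order, so the `active_dims` fixed at construction time still select the right columns. A toy check (column positions assumed; the model additionally casts the task column to long and remaps task values, which is omitted here):

import torch

x = torch.tensor([[0.10, 2.0, 0.50],
                  [0.30, 0.0, 0.90]])  # task ids (integral floats) in column 1
task_feature = 1

x_before = x[..., :task_feature]                     # columns before the task column
task_idcs = x[..., task_feature : task_feature + 1]  # the task column itself
x_after = x[..., task_feature + 1 :]                 # columns after the task column

recombined = torch.cat([x_before, task_idcs, x_after], dim=-1)
assert torch.equal(recombined, x)  # column order is preserved

Because positions are preserved, the data kernel's `active_dims` (here `[0, 2]`) and the IndexKernel's `active_dims` (here `[1]`) pick out the intended columns of the concatenated tensor.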

test/models/test_contextual_multioutput.py

Lines changed: 11 additions & 0 deletions

@@ -13,6 +13,7 @@
 from botorch.utils.test_helpers import gen_multi_task_dataset
 from botorch.utils.testing import BotorchTestCase
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
+from gpytorch.kernels import MaternKernel
 from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
 from linear_operator.operators import LinearOperator
 from linear_operator.operators.interpolated_linear_operator import (
@@ -101,6 +102,16 @@ def test_LCEMGP(self):
             right_interp_indices=task_idcs,
         ).to_dense()
         self.assertAllClose(previous_covar, model.task_covar_module(task_idcs))
+        custom_covar_module = MaternKernel()
+        model_custom_covar = LCEMGP(
+            train_X=train_x,
+            train_Y=train_y,
+            task_feature=task_feature,
+            embs_dim_list=[2],  # increase dim from 1 to 2
+            context_emb_feature=torch.tensor([[0.2], [0.3]]),
+            covar_module=custom_covar_module,
+        )
+        self.assertIsInstance(model_custom_covar.covar_module, MaternKernel)

     def test_construct_inputs(self) -> None:
         for with_embedding_inputs, yvar, skip_task_features_in_datasets in zip(
