
Commit e644dee
Carl Hvarfner authored and facebook-github-bot committed

Refactor of MultiTask / FullyBayesianMultiTaskGP to use ProductKernel & IndexKernel
Summary: Modified MultiTaskGP and FullyBayesianMultiTaskGP to use an IndexKernel (combined with the data kernel in a ProductKernel) instead of two separate covar modules. For large matrices, this yields a significant speed-up (2-3x anecdotally) and a seemingly even larger reduction in memory. In addition, it lets MultiTaskFBGP and SingleTaskFBGP share a lot of code; I'll enable more code sharing between them in a subsequent diff. With some additional functionality in IndexKernel (e.g. structured learning of the covar_matrix elements), this change would apply to other MTGPs as well.

NOTE: Providing negative indices to an IndexKernel is not supported: pytorch/pytorch#76347

Differential Revision: D76317553
1 parent d3d3c6f · commit e644dee

File tree: 5 files changed, +147 −86 lines

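Before the per-file diffs, here is a minimal sketch of the kernel structure this refactor moves to. It is illustrative only (the kernel choices, dimensions, and hyperparameters are made up), and it assumes a recent GPyTorch in which IndexKernel casts the task column to integer indices internally: the data kernel is restricted to the non-task columns via active_dims, the IndexKernel to the task column, and their product is a single ProductKernel that evaluates the ICM covariance in one pass.

import torch
from gpytorch.kernels import IndexKernel, MaternKernel

d, num_tasks = 3, 4  # 3 data columns, task indices in column 3
data_covar_module = MaternKernel(
    nu=2.5, ard_num_dims=d, active_dims=[0, 1, 2]  # non-task columns only
)
task_covar_module = IndexKernel(
    num_tasks=num_tasks, rank=2, active_dims=[3]  # task column only
)
# `*` on kernels builds a gpytorch ProductKernel: evaluated on the full input,
# it computes K_data(x, x') * K_task(i, i'), i.e. the ICM covariance.
covar_module = data_covar_module * task_covar_module

X = torch.rand(8, d + 1)
X[:, -1] = torch.randint(num_tasks, (8,)).to(X.dtype)  # integer-valued tasks
K = covar_module(X).to_dense()  # 8 x 8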

botorch/models/contextual_multioutput.py
Lines changed: 27 additions & 0 deletions

@@ -19,9 +19,11 @@
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
+from botorch.models.utils.gpytorch_modules import get_covar_module_with_dim_scaled_prior
 from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
 from botorch.utils.types import _DefaultType, DEFAULT
 from gpytorch.constraints import Interval
+from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels.rbf_kernel import RBFKernel
 from gpytorch.likelihoods.likelihood import Likelihood
 from gpytorch.module import Module

@@ -107,6 +109,13 @@ def __init__(
             outcome_transform=outcome_transform,
             input_transform=input_transform,
         )
+        # Overwriting the covar_module created in the parent class
+        if covar_module is None:
+            self.covar_module = get_covar_module_with_dim_scaled_prior(
+                ard_num_dims=self.num_non_task_features
+            )
+        else:
+            self.covar_module = covar_module
         self.device = train_X.device
         if all_tasks is None:
             all_tasks_tensor = train_X[:, task_feature].unique()

@@ -188,6 +197,10 @@ def task_covar_module(self, task_idcs: Tensor) -> Tensor:
         Returns:
             Task covariance matrix of shape (b x n x n).
         """
+        # NOTE: This can probably be re-written more efficiently using
+        # IndexKernel (or an IndexKernel subclass) and the `evaluate_task_covar`
+        # and then have the forward pass evaluate a ProductKernel of the two.
+
         # This is a tensor of shape (num_tasks x num_tasks).
         covar_matrix = self._eval_context_covar().to_dense()
         # Here, we index into the base covar matrix to extract

@@ -208,6 +221,20 @@ def task_covar_module(self, task_idcs: Tensor) -> Tensor:
             covar_matrix[base_idx].transpose(-1, -2).gather(index=expanded_idx, dim=-2)
         )

+    def forward(self, x: Tensor) -> MultivariateNormal:
+        if self.training:
+            x = self.transform_inputs(x)
+        x_basic_lead, task_idcs, x_basic_trail = self._split_inputs(x)
+        x_basic = torch.cat([x_basic_lead, x_basic_trail], dim=-1)
+        # Compute base mean and covariance
+        mean_x = self.mean_module(x_basic)
+        covar_x = self.covar_module(x_basic)
+        # Compute task covariances
+        covar_i = self.task_covar_module(task_idcs)
+        # Combine the two in an ICM fashion
+        covar = covar_x.mul(covar_i)
+        return MultivariateNormal(mean_x, covar)
+
     @classmethod
     def construct_inputs(
         cls,
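The new forward above mirrors the classic ICM construction: a data covariance over the non-task columns, elementwise-multiplied with a task covariance looked up from the task indices. A minimal, self-contained sketch of that combination, with a stock RBFKernel and IndexKernel standing in for the model's own modules and made-up shapes:

import torch
from gpytorch.kernels import IndexKernel, RBFKernel

n, d, num_tasks = 8, 3, 2
x_basic = torch.rand(n, d)                    # non-task features
task_idcs = torch.randint(num_tasks, (n, 1))  # task index column

covar_x = RBFKernel(ard_num_dims=d)(x_basic)                   # n x n data covariance
covar_i = IndexKernel(num_tasks=num_tasks, rank=2)(task_idcs)  # n x n task covariance
covar = covar_x.mul(covar_i)                  # Hadamard product: ICM covariance
print(covar.to_dense().shape)                 # torch.Size([8, 8])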

botorch/models/fully_bayesian_multitask.py
Lines changed: 16 additions & 12 deletions

@@ -14,6 +14,7 @@
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.fully_bayesian import (
     matern52_kernel,
+    MCMC_DIM,
     MIN_INFERRED_NOISE_LEVEL,
     reshape_and_detach,
     SaasPyroModel,

@@ -22,7 +23,7 @@
 from botorch.models.multitask import MultiTaskGP
 from botorch.models.transforms.input import InputTransform
 from botorch.models.transforms.outcome import OutcomeTransform
-from botorch.posteriors.fully_bayesian import GaussianMixturePosterior, MCMC_DIM
+from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.kernels import MaternKernel
 from gpytorch.kernels.index_kernel import IndexKernel

@@ -66,6 +67,8 @@ def set_inputs(
             task_rank: The num of learned task embeddings to be used in the task kernel.
                 If omitted, use a full rank (i.e. number of tasks) kernel.
         """
+        if task_feature < 0:
+            task_feature += train_X.shape[-1]
         super().set_inputs(train_X, train_Y, train_Yvar)
         # obtain a list of task indicies
         all_tasks = train_X[:, task_feature].unique().to(dtype=torch.long).tolist()

@@ -140,15 +143,18 @@ def load_mcmc_samples(
         num_mcmc_samples = len(mcmc_samples["mean"])
         batch_shape = torch.Size([num_mcmc_samples])

-        mean_module, covar_module, likelihood, _ = super().load_mcmc_samples(
+        mean_module, data_covar_module, likelihood, _ = super().load_mcmc_samples(
             mcmc_samples=mcmc_samples
         )
-
+        data_indices = torch.arange(self.train_X.shape[-1] - 1)
+        data_indices[self.task_feature :] += 1  # exclude task feature
+        data_covar_module.active_dims = data_indices
         latent_covar_module = MaternKernel(
             nu=2.5,
             ard_num_dims=self.task_rank,
             batch_shape=batch_shape,
         ).to(**tkwargs)
+
         latent_covar_module.lengthscale = reshape_and_detach(
             target=latent_covar_module.lengthscale,
             new_value=mcmc_samples["task_lengthscale"],

@@ -159,6 +165,7 @@ def load_mcmc_samples(
             num_tasks=self.num_tasks,
             rank=self.task_rank,
             batch_shape=latent_features.shape[:-2],
+            active_dims=[self.task_feature],
         ).to(**tkwargs)
         task_covar_module.covar_factor = Parameter(
             task_covar.cholesky().to_dense().detach()

@@ -169,12 +176,12 @@
         # affects predictions in practice, but setting it to zero is consistent with the
         # previous implementation.
         task_covar_module.var = torch.zeros_like(task_covar_module.var)
-        return mean_module, covar_module, likelihood, task_covar_module
+        covar_module = data_covar_module * task_covar_module
+        return mean_module, covar_module, likelihood, None


 class SaasFullyBayesianMultiTaskGP(MultiTaskGP):
     r"""A fully Bayesian multi-task GP model with the SAAS prior.
-
     This model assumes that the inputs have been normalized to [0, 1]^d and that the
     output has been stratified standardized to have zero mean and unit variance for
     each task. The SAAS model [Eriksson2021saasbo]_ with a Matern-5/2 is used as data

@@ -286,8 +293,6 @@ def __init__(
         self.mean_module = None
         self.covar_module = None
         self.likelihood = None
-        self.task_covar_module = None
-        self.register_buffer("latent_features", None)
         if pyro_model is None:
             pyro_model = MultitaskSaasPyroModel()
         pyro_model.set_inputs(

@@ -321,21 +326,20 @@ def train(
         self.mean_module = None
         self.covar_module = None
         self.likelihood = None
-        self.task_covar_module = None
         return self

     @property
     def median_lengthscale(self) -> Tensor:
         r"""Median lengthscales across the MCMC samples."""
         self._check_if_fitted()
-        lengthscale = self.covar_module.base_kernel.lengthscale.clone()
+        lengthscale = self.covar_module.kernels[0].base_kernel.lengthscale.clone()
         return lengthscale.median(0).values.squeeze(0)

     @property
     def num_mcmc_samples(self) -> int:
         r"""Number of MCMC samples in the model."""
         self._check_if_fitted()
-        return len(self.covar_module.outputscale)
+        return self.covar_module.kernels[0].batch_shape[0]

     @property
     def batch_shape(self) -> torch.Size:

@@ -367,7 +371,7 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None:
             self.mean_module,
             self.covar_module,
             self.likelihood,
-            self.task_covar_module,
+            _,
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)

     def posterior(

@@ -438,7 +442,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
             self.mean_module,
             self.covar_module,
             self.likelihood,
-            self.task_covar_module,
+            _,  # Possibly space for input transform
         ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples)
         # Load the actual samples from the state dict
         super().load_state_dict(state_dict=state_dict, strict=strict)
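The median_lengthscale and num_mcmc_samples properties above now index into the product's factors. For reference, a small sketch (module names and shapes are illustrative) of how multiplying two GPyTorch kernels exposes them again via .kernels:

from gpytorch.kernels import IndexKernel, MaternKernel, ScaleKernel

data_covar_module = ScaleKernel(MaternKernel(nu=2.5, ard_num_dims=3))
task_covar_module = IndexKernel(num_tasks=4, rank=2, active_dims=[3])

# Kernel.__mul__ returns a gpytorch ProductKernel; the factors are stored in
# order, so kernels[0] is the data kernel and kernels[1] the IndexKernel.
covar_module = data_covar_module * task_covar_module
assert covar_module.kernels[0] is data_covar_module

lengthscale = covar_module.kernels[0].base_kernel.lengthscale  # shape 1 x 3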

botorch/models/multitask.py
Lines changed: 46 additions & 33 deletions

@@ -172,6 +172,11 @@ def __init__(
             X=train_X, input_transform=input_transform
         )
         self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
+
+        # IndexKernel cannot work with negative task features, so we shift them to
+        # be positive here.
+        if task_feature < 0:
+            task_feature += transformed_X.shape[-1]
         (
             all_tasks_inferred,
             task_feature,

@@ -220,16 +225,29 @@
         )
         self.mean_module = mean_module or ConstantMean()
         if covar_module is None:
-            self.covar_module = get_covar_module_with_dim_scaled_prior(
-                ard_num_dims=self.num_non_task_features
+            data_covar_module = get_covar_module_with_dim_scaled_prior(
+                ard_num_dims=self.num_non_task_features,
+                active_dims=self._base_idxr,
             )
         else:
-            self.covar_module = covar_module
+            data_covar_module = covar_module
+            # This check enables models which don't adhere to the convention (e.g.
+            # adding additional feature dimensions, like HeteroMTGP) to be used.
+            if covar_module.active_dims is None:
+                # Since we no longer use the custom indexing which derived the
+                # task indexing in the forward pass, we need to explicitly set
+                # the active dims here to ensure that the forward pass works.
+                data_covar_module.active_dims = self._base_idxr

         self._rank = rank if rank is not None else self.num_tasks
-        self.task_covar_module = IndexKernel(
-            num_tasks=self.num_tasks, rank=self._rank, prior=task_covar_prior
+        task_covar_module = IndexKernel(
+            num_tasks=self.num_tasks,
+            rank=self._rank,
+            prior=task_covar_prior,
+            active_dims=[task_feature],
         )
+
+        self.covar_module = data_covar_module * task_covar_module
         task_mapper = get_task_value_remapping(
             task_values=torch.tensor(
                 all_tasks, dtype=torch.long, device=train_X.device

@@ -244,45 +262,40 @@
         self.outcome_transform = outcome_transform
         self.to(train_X)

-    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        r"""Extracts base features and task indices from input data.
+    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
+        r"""Extracts features before task feature, task indices, and features after task feature.

         Args:
             x: The full input tensor with trailing dimension of size `d + 1`.
                 Should be of float/double data type.

         Returns:
-            2-element tuple containing
-
-            - A `q x d` or `b x q x d` (batch mode) tensor with trailing
-              dimension made up of the `d` non-task-index columns of `x`, arranged
-              in the order as specified by the indexer generated during model
-              instantiation.
-            - A `q` or `b x q` (batch mode) tensor of long data type containing
-              the task indices.
+            3-element tuple containing
+
+            - A `q x d` or `b x q x d` tensor with features before the task feature
+            - A `q` or `b x q` tensor with mapped task indices
+            - A `q x d` or `b x q x d` tensor with features after the task feature
         """
-        batch_shape, d = x.shape[:-2], x.shape[-1]
-        x_basic = x[..., self._base_idxr].view(batch_shape + torch.Size([-1, d - 1]))
-        task_idcs = (
-            x[..., self._task_feature]
-            .view(batch_shape + torch.Size([-1, 1]))
-            .to(dtype=torch.long)
-        )
-        task_idcs = self._map_tasks(task_values=task_idcs)
-        return x_basic, task_idcs
+        batch_shape = x.shape[:-2]
+        # Extract task indices and convert to long
+        task_idcs = x[..., self._task_feature].view(batch_shape + torch.Size([-1, 1]))
+        task_idcs = self._map_tasks(task_values=task_idcs.to(dtype=torch.long))
+
+        # Extract features before and after task feature
+        x_before = x[..., : self._task_feature]
+        x_after = x[..., (self._task_feature + 1) :]
+        return x_before, task_idcs, x_after

     def forward(self, x: Tensor) -> MultivariateNormal:
         if self.training:
             x = self.transform_inputs(x)
-        x_basic, task_idcs = self._split_inputs(x)
-        # Compute base mean and covariance
-        mean_x = self.mean_module(x_basic)
-        covar_x = self.covar_module(x_basic)
-        # Compute task covariances
-        covar_i = self.task_covar_module(task_idcs)
-        # Combine the two in an ICM fashion
-        covar = covar_x.mul(covar_i)
-        return MultivariateNormal(mean_x, covar)
+
+        # Get features before task feature, task indices, and features after task feature
+        # split features applies the feature mapping (and is thus not a no-op)
+        x = torch.cat(self._split_inputs(x), dim=-1)
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
+        return MultivariateNormal(mean_x, covar_x)

     @classmethod
     def get_all_tasks(
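To make the new control flow concrete, here is a small standalone sketch of the before/task/after split that _split_inputs now performs. It is illustrative only: a plain slice stands in for _map_tasks, so no task-value remapping happens here.

import torch

x = torch.rand(8, 5)      # d + 1 = 5 columns
task_feature = 2          # task indices live in column 2
x[:, task_feature] = torch.randint(3, (8,)).to(x.dtype)

# Split around the task column, as in the new _split_inputs.
x_before = x[..., :task_feature]
task_idcs = x[..., task_feature].view(-1, 1)  # _map_tasks would remap these
x_after = x[..., task_feature + 1 :]

# forward() re-concatenates the pieces; the data kernel's active_dims and the
# IndexKernel's active_dims then select the right columns of this tensor.
x_joined = torch.cat([x_before, task_idcs, x_after], dim=-1)
assert x_joined.shape == x.shape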
