@@ -172,6 +172,11 @@ def __init__(
                 X=train_X, input_transform=input_transform
             )
         self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
+
+        # IndexKernel cannot work with negative task features, so we shift them to
+        # be positive here.
+        if task_feature < 0:
+            task_feature += transformed_X.shape[-1]
         (
             all_tasks_inferred,
             task_feature,
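Note: the shift above is just Python-style negative-index normalization, made explicit because the task column index is later handed to `IndexKernel` via `active_dims`. A minimal sketch with made-up shapes (not taken from the diff):

```python
import torch

transformed_X = torch.rand(10, 4)  # hypothetical: 3 data features + 1 task column
task_feature = -1
if task_feature < 0:
    task_feature += transformed_X.shape[-1]
print(task_feature)  # 3, i.e. the last column, now a valid `active_dims` entry
```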
@@ -220,16 +225,29 @@ def __init__(
             )
         self.mean_module = mean_module or ConstantMean()
         if covar_module is None:
-            self.covar_module = get_covar_module_with_dim_scaled_prior(
-                ard_num_dims=self.num_non_task_features
+            data_covar_module = get_covar_module_with_dim_scaled_prior(
+                ard_num_dims=self.num_non_task_features,
+                active_dims=self._base_idxr,
             )
         else:
-            self.covar_module = covar_module
+            data_covar_module = covar_module
+            # This check enables models which don't adhere to the convention (e.g.
+            # adding additional feature dimensions, like HeteroMTGP) to be used.
+            if covar_module.active_dims is None:
+                # Since we no longer use the custom indexing which derived the
+                # task indexing in the forward pass, we need to explicitly set
+                # the active dims here to ensure that the forward pass works.
+                data_covar_module.active_dims = self._base_idxr
 
         self._rank = rank if rank is not None else self.num_tasks
-        self.task_covar_module = IndexKernel(
-            num_tasks=self.num_tasks, rank=self._rank, prior=task_covar_prior
+        task_covar_module = IndexKernel(
+            num_tasks=self.num_tasks,
+            rank=self._rank,
+            prior=task_covar_prior,
+            active_dims=[task_feature],
         )
+
+        self.covar_module = data_covar_module * task_covar_module
         task_mapper = get_task_value_remapping(
             task_values=torch.tensor(
                 all_tasks, dtype=torch.long, device=train_X.device
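The substance of this hunk is that the data kernel and the task `IndexKernel` are restricted to disjoint `active_dims` and multiplied into a single `covar_module`, so the ICM-style covariance is evaluated directly on the full `(..., d + 1)` input. A minimal GPyTorch sketch of that composition (illustrative kernel choice and shapes, not the BoTorch code; it assumes a GPyTorch version whose `IndexKernel` casts a float-valued index column to long internally):

```python
import torch
from gpytorch.kernels import IndexKernel, RBFKernel

d, num_tasks, rank = 3, 2, 2
# The data kernel only sees the first d columns; the task kernel only sees the task column.
data_covar_module = RBFKernel(ard_num_dims=d, active_dims=[0, 1, 2])
task_covar_module = IndexKernel(num_tasks=num_tasks, rank=rank, active_dims=[d])
covar_module = data_covar_module * task_covar_module  # gpytorch ProductKernel

# Full input: d data columns plus one task-index column.
X = torch.cat([torch.rand(5, d), torch.randint(num_tasks, (5, 1)).float()], dim=-1)

# Elementwise product of data covariance and task covariance (the ICM structure).
K = covar_module(X).to_dense()
print(K.shape)  # torch.Size([5, 5])
```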
@@ -244,45 +262,40 @@ def __init__(
             self.outcome_transform = outcome_transform
         self.to(train_X)
 
-    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        r"""Extracts base features and task indices from input data.
+    def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
266+ r"""Extracts features before task feature, task indices, and features after task feature .
 
         Args:
             x: The full input tensor with trailing dimension of size `d + 1`.
                 Should be of float/double data type.
 
         Returns:
-            2-element tuple containing
-
-            - A `q x d` or `b x q x d` (batch mode) tensor with trailing
-              dimension made up of the `d` non-task-index columns of `x`, arranged
-              in the order as specified by the indexer generated during model
-              instantiation.
-            - A `q` or `b x q` (batch mode) tensor of long data type containing
-              the task indices.
+            3-element tuple containing
+
+            - A `q x d` or `b x q x d` tensor with features before the task feature.
+            - A `q` or `b x q` tensor with mapped task indices.
+            - A `q x d` or `b x q x d` tensor with features after the task feature.
         """
-        batch_shape, d = x.shape[:-2], x.shape[-1]
-        x_basic = x[..., self._base_idxr].view(batch_shape + torch.Size([-1, d - 1]))
-        task_idcs = (
-            x[..., self._task_feature]
-            .view(batch_shape + torch.Size([-1, 1]))
-            .to(dtype=torch.long)
-        )
-        task_idcs = self._map_tasks(task_values=task_idcs)
-        return x_basic, task_idcs
+        batch_shape = x.shape[:-2]
+        # Extract task indices and convert to long.
+        task_idcs = x[..., self._task_feature].view(batch_shape + torch.Size([-1, 1]))
+        task_idcs = self._map_tasks(task_values=task_idcs.to(dtype=torch.long))
+
+        # Extract features before and after the task feature.
+        x_before = x[..., : self._task_feature]
+        x_after = x[..., (self._task_feature + 1) :]
+        return x_before, task_idcs, x_after
 
     def forward(self, x: Tensor) -> MultivariateNormal:
         if self.training:
             x = self.transform_inputs(x)
-        x_basic, task_idcs = self._split_inputs(x)
-        # Compute base mean and covariance
-        mean_x = self.mean_module(x_basic)
-        covar_x = self.covar_module(x_basic)
-        # Compute task covariances
-        covar_i = self.task_covar_module(task_idcs)
-        # Combine the two in an ICM fashion
-        covar = covar_x.mul(covar_i)
-        return MultivariateNormal(mean_x, covar)
+
+        # Get features before the task feature, task indices, and features after
+        # the task feature; `_split_inputs` applies the task mapping (not a no-op).
+        x = torch.cat(self._split_inputs(x), dim=-1)
+        mean_x = self.mean_module(x)
+        covar_x = self.covar_module(x)
+        return MultivariateNormal(mean_x, covar_x)
 
     @classmethod
     def get_all_tasks(
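Finally, a small self-contained illustration of the split-and-reconcatenate pattern in the new `forward` (hypothetical `task_feature` and remapping table, not the model's actual attributes): the concatenation restores the original column order, and only the task column is rewritten to the remapped task values, which is why the comment stresses that `_split_inputs` is not a no-op.

```python
import torch

# Hypothetical input: two features, a task column at index 2, then one more feature.
x = torch.tensor([[0.1, 0.2, 5.0, 0.3],
                  [0.4, 0.5, 7.0, 0.6]])
task_feature = 2
raw_to_mapped = {5: 0.0, 7: 1.0}  # illustrative stand-in for the task value remapping

x_before = x[..., :task_feature]
task_idcs = x[..., task_feature : task_feature + 1].long()
task_idcs = torch.tensor([[raw_to_mapped[int(t)]] for t in task_idcs.flatten()])
x_after = x[..., task_feature + 1 :]

x_full = torch.cat([x_before, task_idcs, x_after], dim=-1)
# Columns back in original order with the task column remapped:
# [[0.1, 0.2, 0.0, 0.3], [0.4, 0.5, 1.0, 0.6]]
```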