@@ -115,6 +115,7 @@ def __init__(
115115 all_tasks : list [int ] | None = None ,
116116 outcome_transform : OutcomeTransform | _DefaultType | None = DEFAULT ,
117117 input_transform : InputTransform | None = None ,
118+ validate_task_values : bool = True ,
118119 ) -> None :
119120 r"""Multi-Task GP model using an ICM kernel.
120121
@@ -157,6 +158,9 @@ def __init__(
157158 instantiation of the model.
158159 input_transform: An input transform that is applied in the model's
159160 forward pass.
161+ validate_task_values: If True, validate that the task values supplied in the
162+ input are expected tasks values. If false, unexpected task values
163+ will be mapped to the first output_task if supplied.
160164
161165 Example:
162166 >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -189,7 +193,7 @@ def __init__(
189193 "This is not allowed as it will lead to errors during model training."
190194 )
191195 all_tasks = all_tasks or all_tasks_inferred
192- self .num_tasks = len (all_tasks )
196+ self .num_tasks = len (all_tasks_inferred )
193197 if outcome_transform == DEFAULT :
194198 outcome_transform = Standardize (m = 1 , batch_shape = train_X .shape [:- 2 ])
195199 if outcome_transform is not None :
@@ -249,19 +253,61 @@ def __init__(
249253
250254 self .covar_module = data_covar_module * task_covar_module
251255 task_mapper = get_task_value_remapping (
252- task_values = torch .tensor (
253- all_tasks , dtype = torch .long , device = train_X .device
256+ observed_task_values = torch .tensor (
257+ all_tasks_inferred , dtype = torch .long , device = train_X .device
258+ ),
259+ all_task_values = torch .tensor (
260+ sorted (all_tasks ), dtype = torch .long , device = train_X .device
254261 ),
255262 dtype = train_X .dtype ,
263+ default_task_value = None if output_tasks is None else output_tasks [0 ],
256264 )
257265 self .register_buffer ("_task_mapper" , task_mapper )
258- self ._expected_task_values = set (all_tasks )
266+ self ._expected_task_values = set (all_tasks_inferred )
259267 if input_transform is not None :
260268 self .input_transform = input_transform
261269 if outcome_transform is not None :
262270 self .outcome_transform = outcome_transform
271+ self ._validate_task_values = validate_task_values
263272 self .to (train_X )
264273
274+ def _map_tasks (self , task_values : Tensor ) -> Tensor :
275+ """Map raw task values to the task indices used by the model.
276+
277+ Args:
278+ task_values: A tensor of task values.
279+
280+ Returns:
281+ A tensor of task indices with the same shape as the input
282+ tensor.
283+ """
284+ long_task_values = task_values .long ()
285+ if self ._validate_task_values :
286+ if self ._task_mapper is None :
287+ if not (
288+ torch .all (0 <= task_values )
289+ and torch .all (task_values < self .num_tasks )
290+ ):
291+ raise ValueError (
292+ "Expected all task features in `X` to be between 0 and "
293+ f"self.num_tasks - 1. Got { task_values } ."
294+ )
295+ else :
296+ unexpected_task_values = set (
297+ long_task_values .unique ().tolist ()
298+ ).difference (self ._expected_task_values )
299+ if len (unexpected_task_values ) > 0 :
300+ raise ValueError (
301+ "Received invalid raw task values. Expected raw value to be in"
302+ f" { self ._expected_task_values } , but got unexpected task"
303+ f" values: { unexpected_task_values } ."
304+ )
305+ task_values = self ._task_mapper [long_task_values ]
306+ elif self ._task_mapper is not None :
307+ task_values = self ._task_mapper [long_task_values ]
308+
309+ return task_values
310+
265311 def _split_inputs (self , x : Tensor ) -> tuple [Tensor , Tensor , Tensor ]:
266312 r"""Extracts features before task feature, task indices, and features after
267313 the task feature.
@@ -274,7 +320,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
274320 3-element tuple containing
275321
276322 - A `q x d` or `b x q x d` tensor with features before the task feature
277- - A `q` or `b x q` tensor with mapped task indices
323+ - A `q` or `b x q x 1 ` tensor with mapped task indices
278324 - A `q x d` or `b x q x d` tensor with features after the task feature
279325 """
280326 batch_shape = x .shape [:- 2 ]
@@ -314,7 +360,7 @@ def get_all_tasks(
314360 raise ValueError (f"Must have that -{ d } <= task_feature <= { d } " )
315361 task_feature = task_feature % (d + 1 )
316362 all_tasks = (
317- train_X [..., task_feature ].unique ( sorted = True ). to (dtype = torch .long ).tolist ()
363+ train_X [..., task_feature ].to (dtype = torch .long ). unique ( sorted = True ).tolist ()
318364 )
319365 return all_tasks , task_feature , d
320366
0 commit comments