@@ -115,7 +115,6 @@ def __init__(
115115 all_tasks : list [int ] | None = None ,
116116 outcome_transform : OutcomeTransform | _DefaultType | None = DEFAULT ,
117117 input_transform : InputTransform | None = None ,
118- validate_task_values : bool = True ,
119118 ) -> None :
120119 r"""Multi-Task GP model using an ICM kernel.
121120
@@ -158,9 +157,6 @@ def __init__(
158157 instantiation of the model.
159158 input_transform: An input transform that is applied in the model's
160159 forward pass.
161- validate_task_values: If True, validate that the task values supplied in the
162- input are expected tasks values. If false, unexpected task values
163- will be mapped to the first output_task if supplied.
164160
165161 Example:
166162 >>> X1, X2 = torch.rand(10, 2), torch.rand(20, 2)
@@ -193,7 +189,7 @@ def __init__(
193189 "This is not allowed as it will lead to errors during model training."
194190 )
195191 all_tasks = all_tasks or all_tasks_inferred
196- self .num_tasks = len (all_tasks_inferred )
192+ self .num_tasks = len (all_tasks )
197193 if outcome_transform == DEFAULT :
198194 outcome_transform = Standardize (m = 1 , batch_shape = train_X .shape [:- 2 ])
199195 if outcome_transform is not None :
@@ -263,61 +259,19 @@ def __init__(
263259
264260 self .covar_module = data_covar_module * task_covar_module
265261 task_mapper = get_task_value_remapping (
266- observed_task_values = torch .tensor (
267- all_tasks_inferred , dtype = torch .long , device = train_X .device
268- ),
269- all_task_values = torch .tensor (
270- sorted (all_tasks ), dtype = torch .long , device = train_X .device
262+ task_values = torch .tensor (
263+ all_tasks , dtype = torch .long , device = train_X .device
271264 ),
272265 dtype = train_X .dtype ,
273- default_task_value = None if output_tasks is None else output_tasks [0 ],
274266 )
275267 self .register_buffer ("_task_mapper" , task_mapper )
276- self ._expected_task_values = set (all_tasks_inferred )
268+ self ._expected_task_values = set (all_tasks )
277269 if input_transform is not None :
278270 self .input_transform = input_transform
279271 if outcome_transform is not None :
280272 self .outcome_transform = outcome_transform
281- self ._validate_task_values = validate_task_values
282273 self .to (train_X )
283274
284- def _map_tasks (self , task_values : Tensor ) -> Tensor :
285- """Map raw task values to the task indices used by the model.
286-
287- Args:
288- task_values: A tensor of task values.
289-
290- Returns:
291- A tensor of task indices with the same shape as the input
292- tensor.
293- """
294- long_task_values = task_values .long ()
295- if self ._validate_task_values :
296- if self ._task_mapper is None :
297- if not (
298- torch .all (0 <= task_values )
299- and torch .all (task_values < self .num_tasks )
300- ):
301- raise ValueError (
302- "Expected all task features in `X` to be between 0 and "
303- f"self.num_tasks - 1. Got { task_values } ."
304- )
305- else :
306- unexpected_task_values = set (
307- long_task_values .unique ().tolist ()
308- ).difference (self ._expected_task_values )
309- if len (unexpected_task_values ) > 0 :
310- raise ValueError (
311- "Received invalid raw task values. Expected raw value to be in"
312- f" { self ._expected_task_values } , but got unexpected task"
313- f" values: { unexpected_task_values } ."
314- )
315- task_values = self ._task_mapper [long_task_values ]
316- elif self ._task_mapper is not None :
317- task_values = self ._task_mapper [long_task_values ]
318-
319- return task_values
320-
321275 def _split_inputs (self , x : Tensor ) -> tuple [Tensor , Tensor , Tensor ]:
322276 r"""Extracts features before task feature, task indices, and features after
323277 the task feature.
@@ -330,7 +284,7 @@ def _split_inputs(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
330284 3-element tuple containing
331285
332286 - A `q x d` or `b x q x d` tensor with features before the task feature
333- - A `q` or `b x q x 1 ` tensor with mapped task indices
287+ - A `q` or `b x q` tensor with mapped task indices
334288 - A `q x d` or `b x q x d` tensor with features after the task feature
335289 """
336290 batch_shape = x .shape [:- 2 ]
@@ -370,7 +324,7 @@ def get_all_tasks(
370324 raise ValueError (f"Must have that -{ d } <= task_feature <= { d } " )
371325 task_feature = task_feature % (d + 1 )
372326 all_tasks = (
373- train_X [..., task_feature ].to (dtype = torch .long ). unique ( sorted = True ).tolist ()
327+ train_X [..., task_feature ].unique ( sorted = True ). to (dtype = torch .long ).tolist ()
374328 )
375329 return all_tasks , task_feature , d
376330
0 commit comments