Skip to content

Commit 50dc3ba

Browse files
Carl Hvarfner, facebook-github-bot
authored and committed
Preserving train inputs and targets through transforms (#3044)
Summary: This PR preserves botorch transforms (specifically outcome_transforms, like Standardize) through state_dict loading. The fix also ensures that the train_targets of a leave-one-out model with outcome transforms will, in the default case, be the same as the targets of the base model, minus the point left out.

__Longer explanation:__ Transforms, and specifically learnable outcome transforms like Standardize, will currently:
a. Learn the parameters at initialization of the GP
b. Transform the train_Ys to the normalized space

Then, when we load a state dict, we will:
a. Impose new standardization parameters on already standardized data
b. Potentially make the transforms re-learnable, nullifying the change made by the state dict

This has undesired consequences for cross-validation, as all cross-validated models will effectively have different training data. In essence, _we don't simply leave one point out, but instead we leave one out and re-standardize_. When we have outliers in the data, this will lead to substantially different predictions when the outlier is left out, since the outlier will substantially impact the outcome transform parameters.

Notebook explaining the effect with some plots: N8342965

Reviewed By: Balandat

Differential Revision: D84571407
1 parent 09502f9 commit 50dc3ba

File tree

2 files changed

+361
-1
lines changed

2 files changed

+361
-1
lines changed

botorch/models/gpytorch.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import warnings
1818
from abc import ABC
1919
from copy import deepcopy
20-
from typing import Any, TYPE_CHECKING
20+
from typing import Any, Mapping, TYPE_CHECKING
2121

2222
import torch
2323
from botorch.acquisition.objective import PosteriorTransform
@@ -283,6 +283,111 @@ def condition_on_observations(
283283
).detach()
284284
return fantasy_model
285285

286+
def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
    r"""Return the stored training targets and fixed noise with an output dim.

    For ``num_outputs > 1`` the last two dimensions of the stored tensors are
    swapped; otherwise a trailing singleton dimension is appended.

    Returns:
        A tuple ``(Y, Yvar)`` where Y and Yvar have shape
        [batch_shape] x n x m, with batch_shape included only if the
        training data initially contained it. ``Yvar`` is ``None`` unless the
        likelihood is a ``FixedNoiseGaussianLikelihood``.
    """
    is_multi_output = self.num_outputs > 1

    def _outputs_last(stored: Tensor) -> Tensor:
        # Stored layout -> layout with the output dimension last.
        if is_multi_output:
            return stored.transpose(-1, -2)
        return stored.unsqueeze(-1)

    targets = _outputs_last(self.train_targets)
    if not isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
        # Only fixed-noise likelihoods carry per-observation noise to extract.
        return targets, None
    return targets, _outputs_last(self.likelihood.noise_covar.noise)
304+
305+
def _restore_targets_and_noise(
    self, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
    r"""Write targets (and fixed noise, if applicable) back onto the model.

    The tensors are converted from the ``[batch_shape] x n x m`` layout back to
    the layout the model stores: for ``num_outputs > 1`` the last two
    dimensions are swapped, otherwise the trailing dimension is squeezed out.

    Args:
        Y: Targets tensor in shape [batch_shape] x n x m.
        Yvar: Optional noise variance tensor in shape [batch_shape] x n x m.
            Only written back when the likelihood is a
            ``FixedNoiseGaussianLikelihood``.
        strict: Whether to strictly enforce shape constraints; forwarded to
            ``set_train_data``.
    """
    is_multi_output = self.num_outputs > 1

    def _to_stored_layout(expanded: Tensor) -> Tensor:
        # Layout with the output dimension last -> stored layout.
        if is_multi_output:
            return expanded.transpose(-1, -2)
        return expanded.squeeze(-1)

    stored_targets = _to_stored_layout(Y)

    has_fixed_noise = Yvar is not None and isinstance(
        self.likelihood, FixedNoiseGaussianLikelihood
    )
    if has_fixed_noise:
        self.likelihood.noise_covar.noise = _to_stored_layout(Yvar)

    self.set_train_data(targets=stored_targets, strict=strict)
331+
332+
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    keep_transforms: bool = True,
) -> None:
    r"""Load the model state, keeping stored targets consistent with it.

    With ``keep_transforms=True`` and an outcome transform present, the
    targets currently stored on the model are first mapped back through
    ``outcome_transform.untransform`` (using the transform state from *before*
    loading), the state dict is loaded, and the untransformed targets are then
    pushed through the loaded outcome transform in eval mode and written back
    via ``_restore_targets_and_noise``.

    If the outcome transform raises ``NotImplementedError`` from
    ``untransform``, a warning is issued and the state dict is loaded as if
    ``keep_transforms=False`` had been passed.

    Args:
        state_dict: A dict containing the state of the model.
        strict: A boolean indicating whether to strictly enforce that the keys
            in ``state_dict`` match the keys expected by this model. Forwarded
            to the parent ``load_state_dict`` and, when targets are restored,
            to ``_restore_targets_and_noise``.
        keep_transforms: A boolean indicating whether to keep the input and outcome
            transforms. Doing so is useful when loading a model that was trained on
            a full set of data, and is later loaded with a subset of the data.
    """
    if not keep_transforms:
        super().load_state_dict(state_dict, strict)
        return

    should_outcome_transform = (
        hasattr(self, "train_targets")
        and getattr(self, "outcome_transform", None) is not None
    )

    # Only touch the stored training data when it will actually be
    # re-transformed. Models without an outcome transform (or without stored
    # targets) then load without requiring `train_targets` / `train_inputs`.
    if should_outcome_transform:
        with torch.no_grad():
            untransformed_Y, untransformed_Yvar = self._extract_targets_and_noise()
            X = self.train_inputs[0]
            try:
                # Must happen before loading: this uses the transform state
                # the stored targets were originally produced with.
                untransformed_Y, untransformed_Yvar = (
                    self.outcome_transform.untransform(
                        Y=untransformed_Y,
                        Yvar=untransformed_Yvar,
                        X=X,
                    )
                )
            except NotImplementedError:
                warnings.warn(
                    "Outcome transform does not support untransforming. "
                    "Cannot load the state dict with transforms preserved. "
                    "Setting keep_transforms=False.",
                    stacklevel=3,
                )
                super().load_state_dict(state_dict, strict)
                return

    super().load_state_dict(state_dict, strict)

    # NOTE(review): eval mode presumably keeps learnable transforms from
    # re-fitting their parameters and overriding the loaded state -- confirm.
    if getattr(self, "input_transform", None) is not None:
        self.input_transform.eval()

    if should_outcome_transform:
        self.outcome_transform.eval()
        retransformed_Y, retransformed_Yvar = self.outcome_transform(
            Y=untransformed_Y, Yvar=untransformed_Yvar, X=X
        )
        self._restore_targets_and_noise(retransformed_Y, retransformed_Yvar, strict)
390+
286391

287392
# pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape,
288393
# _aug_batch_shape
@@ -803,6 +908,38 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC):
803908
"long-format" multi-task GP in the style of `MultiTaskGP`.
804909
"""
805910

911+
def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
    r"""Return multi-task training targets and fixed noise with an output dim.

    A trailing singleton dimension is appended to the stored tensors,
    regardless of the number of outputs.

    Returns:
        A tuple ``(Y, Yvar)`` where Y and Yvar have shape
        [batch_shape] x n x m, with batch_shape included only if the
        training data initially contained it. ``Yvar`` is ``None`` unless the
        likelihood is a ``FixedNoiseGaussianLikelihood``.
    """
    targets = self.train_targets.unsqueeze(-1)
    if not isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
        # Only fixed-noise likelihoods carry per-observation noise to extract.
        return targets, None
    noise = self.likelihood.noise_covar.noise
    return targets, noise.unsqueeze(-1)
923+
924+
def _restore_targets_and_noise(
    self, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
    r"""Write multi-task targets (and fixed noise, if applicable) to the model.

    The trailing dimension of each tensor is squeezed out before it is stored.

    Args:
        Y: Targets tensor in shape [batch_shape] x n x m.
        Yvar: Optional noise variance tensor in shape [batch_shape] x n x m.
            Only written back when the likelihood is a
            ``FixedNoiseGaussianLikelihood``.
        strict: Whether to strictly enforce shape constraints; forwarded to
            ``set_train_data``.
    """
    stored_targets = Y.squeeze(-1)

    has_fixed_noise = Yvar is not None and isinstance(
        self.likelihood, FixedNoiseGaussianLikelihood
    )
    if has_fixed_noise:
        self.likelihood.noise_covar.noise = Yvar.squeeze(-1)

    self.set_train_data(targets=stored_targets, strict=strict)
942+
806943
def _apply_noise(
807944
self,
808945
X: Tensor,

0 commit comments

Comments
 (0)