
Commit a4163d0

Carl Hvarfner authored and facebook-github-bot committed
Preserving train inputs and targets through transforms (#3044)
Summary:

This PR preserves BoTorch transforms (specifically outcome transforms, like Standardize) through state_dict loading. The fix also ensures that the train_targets of a leave-one-out model with outcome transforms will, in the default case, match the targets of the base model, minus the point left out.

__Longer explanation:__ Transforms, and specifically learnable outcome transforms like Standardize, currently:

a. learn their parameters at initialization of the GP, and
b. transform the train_Ys into the normalized space.

Then, when we load a state dict, we:

a. impose new standardization parameters on already-standardized data, and
b. potentially make the transforms re-learnable, nullifying the change made by the state dict.

This has undesired consequences for cross-validation, as all cross-validated models effectively end up with different training data. In essence, _we don't simply leave one point out; we leave one out and re-standardize_. When the data contains outliers, this leads to substantially different predictions when the outlier is left out, since the outlier substantially impacts the outcome transform parameters.

Notebook explaining the effect with some plots: N8342965

Reviewed By: Balandat

Differential Revision: D84571407
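To make the re-standardization effect concrete, here is a minimal sketch (not part of the diff) of how re-learning Standardize on leave-one-out data shifts the effective training targets; the tensor values and variable names are illustrative only.

import torch
from botorch.models.transforms.outcome import Standardize

# Targets with one outlier; standardizing with vs. without it yields very
# different transform parameters.
Y_full = torch.tensor([[0.1], [0.2], [0.3], [5.0]], dtype=torch.double)
Y_loo = Y_full[:-1]  # leave the outlier out

full_tf = Standardize(m=1)
full_tf(Y_full)  # train mode: learns means/stdvs from all four points
full_tf.eval()

loo_tf = Standardize(m=1)
loo_tf(Y_loo)  # train mode: learns different means/stdvs without the outlier
loo_tf.eval()

# The same three raw targets land in different standardized spaces, so a
# leave-one-out model is effectively trained on different data.
print(full_tf(Y_loo)[0].squeeze(-1))
print(loo_tf(Y_loo)[0].squeeze(-1))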
1 parent 09502f9 commit a4163d0

File tree

2 files changed (+306, -1 lines changed)


botorch/models/gpytorch.py

Lines changed: 110 additions & 1 deletion
@@ -17,7 +17,7 @@
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Any, TYPE_CHECKING
+from typing import Any, Mapping, TYPE_CHECKING
 
 import torch
 from botorch.acquisition.objective import PosteriorTransform
@@ -283,6 +283,115 @@ def condition_on_observations(
         ).detach()
         return fantasy_model
 
+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        keep_transforms: bool = True,
+    ) -> None:
+        r"""Load the model state.
+
+        Args:
+            state_dict: A dict containing the state of the model.
+            strict: A boolean indicating whether to strictly enforce that the keys
+                in `state_dict` match the keys of this module.
+            keep_transforms: A boolean indicating whether to keep the input and outcome
+                transforms. Doing so is useful when loading a model that was trained on
+                a full set of data, and is later loaded with a subset of the data.
+        """
+        # If `keep_transforms` is False, the transforms are reset to the default values
+        # and re-trained when the model is evaluated, which may lead to different
+        # behavior than when the initial model was trained, pre-loading.
+        if not keep_transforms:
+            super().load_state_dict(state_dict, strict)
+            return
+
+        # Checks that
+        # 1. the model has train targets (not necessarily true, e.g. for ApproximateGP),
+        # 2. the model accepts an outcome transform, and that it is not None.
+        should_outcome_transform = (
+            hasattr(self, "train_targets")
+            and getattr(self, "outcome_transform", None) is not None
+        )
+        with torch.no_grad():
+            untransformed_Yvar = None
+            # This becomes necessary when we have model batch_shapes,
+            # e.g. in FullyBayesianSingleTaskGP/MultiTaskGP. Then, we have a
+            # batch dimension in the noise, but not in the train_targets.
+            # Thus, we get this nested structure of if-statements to ensure
+            # train_targets and Yvar are of shape [batch_shape] x n x m,
+            # with batch_shape included only if the training data initially
+            # contained it.
+            if self.num_outputs > 1 and not isinstance(self, MultiTaskGPyTorchModel):
+                untransformed_Y = self.train_targets.transpose(-1, -2)
+                if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
+                    untransformed_Yvar = self.likelihood.noise_covar.noise.transpose(
+                        -1, -2
+                    )
+            else:
+                untransformed_Y = self.train_targets.unsqueeze(-1)
+                if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
+                    untransformed_Yvar = self.likelihood.noise_covar.noise.unsqueeze(-1)
+
+            # NOTE: Some outcome transforms require an X, but the untransformed X's cannot
+            # generally be extracted without transformations & adding batch dimensions
+            # (e.g. in Warp). Thus, we use the train inputs.
+            X = self.train_inputs[0]
+
+            # We obtain the untransformed Y (the train_Y's) by untransforming the train
+            # targets.
+            if should_outcome_transform:
+                try:
+                    untransformed_Y, untransformed_Yvar = (
+                        self.outcome_transform.untransform(
+                            Y=untransformed_Y,
+                            Yvar=untransformed_Yvar,
+                            X=X,
+                        )
+                    )
+                except NotImplementedError:
+                    # If the outcome transform does not support untransforming, we
+                    # fall back to loading without preserving the transforms.
+                    warnings.warn(
+                        "Outcome transform does not support untransforming. "
+                        "Cannot load the state dict with transforms preserved. "
+                        "Setting keep_transforms=False.",
+                        stacklevel=3,
+                    )
+                    super().load_state_dict(state_dict, strict)
+                    return
+
+            super().load_state_dict(state_dict, strict)
+
+            # If we want to keep the transforms, we cannot have them in train mode.
+            # If we do, the transforms will be re-trained when the model is evaluated.
+            if getattr(self, "input_transform", None) is not None:
+                self.input_transform.eval()
+
+            # Now, the outcome transform is identical to the state_dict'ed model, so we may
+            # once again transform the train targets.
+            if should_outcome_transform:
+                self.outcome_transform.eval()
+                retransformed_Y, retransformed_Yvar = self.outcome_transform(
+                    Y=untransformed_Y, Yvar=untransformed_Yvar, X=X
+                )
+
+                # Not all models have self._transform_tensor_args, so we do this instead.
+                if self.num_outputs > 1 and not isinstance(self, MultiTaskGPyTorchModel):
+                    retransformed_Y = retransformed_Y.transpose(-1, -2)
+                    if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
+                        retransformed_Yvar = retransformed_Yvar.transpose(-1, -2)
+                        self.likelihood.noise_covar.noise = retransformed_Yvar
+                else:
+                    retransformed_Y = retransformed_Y.squeeze(-1)
+                    if isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
+                        retransformed_Yvar = retransformed_Yvar.squeeze(-1)
+                        self.likelihood.noise_covar.noise = retransformed_Yvar
+
+                self.set_train_data(
+                    targets=retransformed_Y,
+                    strict=strict,
+                )
 
 # pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape,
 # _aug_batch_shape
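For context, here is a minimal usage sketch of the new keep_transforms argument in a leave-one-out setting, mirroring what the tests below exercise; model fitting is omitted and the data is illustrative only.

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize

train_X = torch.rand(5, 2, dtype=torch.double)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)

# Base model on the full data; Standardize learns its parameters here.
base_model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)
state_dict = base_model.state_dict()

# Leave-one-out model: same transform classes, one point removed.
loo_model = SingleTaskGP(
    train_X=train_X[:-1],
    train_Y=train_Y[:-1],
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)

# With keep_transforms=True (the default), the loaded transform parameters are
# preserved, so the leave-one-out train_targets match the base model's targets
# minus the held-out point instead of being re-standardized on the subset.
loo_model.load_state_dict(state_dict, keep_transforms=True)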

test/models/test_gpytorch.py

Lines changed: 196 additions & 0 deletions
@@ -17,6 +17,7 @@
 from botorch.exceptions.errors import DeprecationError, InputDataError
 from botorch.exceptions.warnings import InputDataWarning
 from botorch.fit import fit_gpytorch_mll
+from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.gpytorch import (
     BatchedMultiOutputGPyTorchModel,
     GPyTorchModel,
@@ -28,6 +29,7 @@
 from botorch.models.transforms.input import (
     ChainedInputTransform,
     InputTransform,
+    Normalize,
     NumericToCategoricalEncoding,
 )
 from botorch.models.utils import fantasize
@@ -870,3 +872,197 @@ def test_condition_on_observations_train_input_shapes(self):
             fantasy_model._original_train_inputs.shape[0], original_size + 1
         )
         self.assertEqual(model2._original_train_inputs.shape[0], original_size)
+
+
+class NonUntransformableOutcomeTransform(Standardize):
+    def untransform(self, **kwargs):
+        raise NotImplementedError
+
+
+def _get_input_output_transform(
+    d: int, m: int, use_transforms: bool = True
+) -> dict[str, torch.nn.Module]:
+    return {
+        "input_transform": Normalize(d=d) if use_transforms else None,
+        "outcome_transform": Standardize(m=m) if use_transforms else None,
+    }
+
+
+class TestTransformWarnings(BotorchTestCase):
+    def test_set_transformed_inputs_warning_no_train_inputs(self):
+        from botorch.models.model import Model
+
+        class NotSoAbstractBaseModel(Model):
+            def posterior(self, X, output_indices, observation_noise, **kwargs):
+                pass
+
+        model = NotSoAbstractBaseModel()
+        model.input_transform = Normalize(d=2)
+
+        with self.assertWarnsRegex(
+            RuntimeWarning,
+            "Could not update `train_inputs` with transformed inputs "
+            "since NotSoAbstractBaseModel does not have a `train_inputs` "
+            "attribute. Make sure that the `input_transform` is applied to "
+            "both the train inputs and test inputs.",
+        ):
+            model._set_transformed_inputs()
+
+    def test_load_state_dict_output_warnings(self):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_Y = torch.rand(3, 1, **tkwargs)
+
+        model = SingleTaskGP(
+            train_X=train_X,
+            train_Y=train_Y,
+            input_transform=Normalize(d=2),
+            outcome_transform=NonUntransformableOutcomeTransform(m=1),
+        )
+        state_dict = model.state_dict()
+
+        with self.assertWarnsRegex(
+            UserWarning,
+            "Outcome transform does not support untransforming.*",
+        ):
+            model.load_state_dict(state_dict, keep_transforms=True)
+
+
+class TestLoadStateDict(BotorchTestCase):
+    def _test_load_state_dict_base(
+        self, num_outputs: int, include_yvar: bool = True
+    ) -> None:
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_X = torch.cat(
+            [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0
+        )
+        train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).repeat(1, num_outputs)
+
+        model_kwargs = {
+            "train_X": train_X,
+            "train_Y": train_Y,
+        }
+
+        if include_yvar:
+            train_Yvar = 0.1 * torch.rand_like(train_Y)
+            model_kwargs["train_Yvar"] = train_Yvar
+
+        base_model = SingleTaskGP(
+            **model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+
+        original_train_inputs = base_model.input_transform(base_model.train_inputs[0])
+        original_train_targets = base_model.train_targets.clone()
+        original_train_yvar = base_model.likelihood.noise_covar.noise.clone()
+
+        state_dict = base_model.state_dict()
+
+        cv_model_kwargs = model_kwargs.copy()
+        cv_model_kwargs["train_X"] = train_X[:-1]
+        cv_model_kwargs["train_Y"] = train_Y[:-1]
+        if include_yvar:
+            cv_model_kwargs["train_Yvar"] = train_Yvar[:-1]
+        cv_model = SingleTaskGP(
+            **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+
+        cv_model.load_state_dict(state_dict, keep_transforms=True)
+
+        sd_mean = cv_model.outcome_transform.means
+        cv_model.outcome_transform(train_Y[:-1])
+        self.assertTrue(torch.all(cv_model.outcome_transform.means == sd_mean))
+
+        self.assertTrue(
+            torch.allclose(
+                cv_model.input_transform._offset,
+                state_dict["input_transform._offset"],
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                cv_model.outcome_transform.means,
+                state_dict["outcome_transform.means"],
+            )
+        )
+
+        self.assertAllClose(cv_model.train_targets, original_train_targets[..., :-1])
+        self.assertTrue(
+            torch.equal(
+                cv_model.input_transform(cv_model.train_inputs[0]),
+                original_train_inputs[..., :-1, :],
+            )
+        )
+        if include_yvar:
+            self.assertAllClose(
+                cv_model.likelihood.noise_covar.noise, original_train_yvar[..., :-1]
+            )
+
+        cv_model = SingleTaskGP(
+            **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+        cv_model.load_state_dict(state_dict, keep_transforms=False)
+
+        sd_mean = cv_model.outcome_transform.means
+        cv_model.outcome_transform(train_Y[:-1])
+        self.assertTrue(torch.all(cv_model.outcome_transform.means != sd_mean))
+
+        self.assertFalse(
+            torch.equal(
+                cv_model.input_transform(cv_model.train_inputs[0]),
+                original_train_inputs[..., :-1, :],
+            )
+        )
+        self.assertFalse(
+            torch.equal(cv_model.train_targets, original_train_targets[..., :-1])
+        )
+        self.assertFalse(
+            torch.equal(
+                cv_model.input_transform._offset,
+                state_dict["input_transform._offset"],
+            )
+        )
+        self.assertFalse(
+            torch.equal(
+                cv_model.outcome_transform.means,
+                state_dict["outcome_transform.means"],
+            )
+        )
+
+    def test_load_state_dict_with_transforms(self):
+        self._test_load_state_dict_base(num_outputs=1, include_yvar=True)
+
+    def test_load_state_dict_with_transforms_no_yvar(self):
+        self._test_load_state_dict_base(num_outputs=1, include_yvar=False)
+
+    def test_load_state_dict_multi_output_with_transforms(self):
+        self._test_load_state_dict_base(num_outputs=3, include_yvar=True)
+
+    def test_load_state_dict_multi_output_with_transforms_no_yvar(self):
+        self._test_load_state_dict_base(num_outputs=3, include_yvar=False)
+
+    def test_load_state_dict_no_transforms(self):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_X = torch.cat(
+            [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0
+        )
+        train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
+
+        base_model = SingleTaskGP(
+            train_X=train_X, train_Y=train_Y, outcome_transform=None
+        )
+        original_train_targets = base_model.train_targets.clone()
+        state_dict = base_model.state_dict()
+
+        cv_model = SingleTaskGP(
+            train_X=train_X[:-1], train_Y=train_Y[:-1], outcome_transform=None
+        )
+        cv_model.load_state_dict(state_dict, keep_transforms=False)
+
+        self.assertTrue(
+            torch.equal(cv_model.train_targets, original_train_targets[:-1])
+        )

0 commit comments
