
Commit 9460560

Carl Hvarfner authored and facebook-github-bot committed
Preserving train inputs and targets through transforms (#3044)
Summary: This PR preserves BoTorch transforms (specifically outcome transforms, like Standardize) through state_dict loading. The fix also ensures that the train_targets of a leave-one-out model with outcome transforms will, in the default case, be the same as those of the base model, minus the point left out.

__Longer explanation:__ Transforms, and specifically learnable outcome transforms like Standardize, currently:

a. learn their parameters at initialization of the GP, and
b. transform the train_Ys into the normalized space.

Then, when we load a state dict, we:

a. impose new standardization parameters on already-standardized data, and
b. potentially make the transforms re-learnable, nullifying the change made by the state dict.

This has undesired consequences for cross-validation, as all cross-validated models effectively end up with different training data. In essence, _we don't simply leave one point out, but instead leave one out and re-standardize_. When the data contain outliers, this leads to substantially different predictions when the outlier is left out, since the outlier substantially impacts the outcome transform parameters.

Notebook explaining the effect with some plots: N8342965

Reviewed By: Balandat

Differential Revision: D84571407
1 parent f122efc commit 9460560
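To make the failure mode concrete, here is a minimal sketch (not part of the commit; the data and model sizes are illustrative) of how the keep_transforms=True path added by this PR keeps a leave-one-out model in the base model's standardized space:

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

# Toy data with one outlier; the outlier dominates the mean/std that
# Standardize learns when the model is constructed.
train_X = torch.rand(10, 2, dtype=torch.double)
train_Y = torch.randn(10, 1, dtype=torch.double)
train_Y[-1] = 50.0  # outlier

base = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
state_dict = base.state_dict()

# A leave-one-out model built without the outlier. At construction, its
# Standardize learns different parameters, so its train_targets live in a
# different normalized space than the base model's.
loo = SingleTaskGP(train_X[:-1], train_Y[:-1], outcome_transform=Standardize(m=1))

# With keep_transforms=True, loading the state dict untransforms and then
# re-transforms the targets, so they match the base model's targets minus
# the left-out point, per the tests added in this commit.
loo.load_state_dict(state_dict, keep_transforms=True)
assert torch.allclose(loo.train_targets, base.train_targets[:-1])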

File tree

3 files changed (+328, −2)

botorch/models/gpytorch.py

Lines changed: 106 additions & 2 deletions
@@ -17,7 +17,7 @@
 import warnings
 from abc import ABC
 from copy import deepcopy
-from typing import Any, TYPE_CHECKING
+from typing import Any, Mapping, TYPE_CHECKING
 
 import torch
 from botorch.acquisition.objective import PosteriorTransform
@@ -45,7 +45,10 @@
 from botorch.utils.multitask import separate_mtmvn
 from botorch.utils.transforms import is_ensemble
 from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
-from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
+from gpytorch.likelihoods.gaussian_likelihood import (
+    _GaussianLikelihoodBase,
+    FixedNoiseGaussianLikelihood,
+)
 from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator
 from torch import broadcast_shapes, Tensor
 
@@ -283,6 +286,107 @@ def condition_on_observations(
         ).detach()
         return fantasy_model
 
+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        keep_transforms: bool = True,
+    ) -> None:
+        r"""Load the model state.
+
+        Args:
+            state_dict: A dict containing the state of the model.
+            strict: A boolean indicating whether to strictly enforce that the
+                keys in ``state_dict`` match the keys of this module.
+            keep_transforms: A boolean indicating whether to keep the input and
+                outcome transforms. Doing so is useful when loading a model that
+                was trained on a full set of data, and is later loaded with a
+                subset of the data.
+        """
+        # If `keep_transforms` is False, the transforms are reset to the default
+        # values and re-trained when the model is evaluated, which may lead to
+        # different behavior than when the initial model was trained, pre-loading.
+        if not keep_transforms:
+            super().load_state_dict(state_dict, strict)
+            return
+
+        # Checks that
+        # 1. the model has train targets (not necessarily true, e.g. for
+        #    ApproximateGP), and
+        # 2. the model accepts an outcome transform, and that it is not None.
+        should_outcome_transform = (
+            hasattr(self, "train_targets")
+            and getattr(self, "outcome_transform", None) is not None
+        )
+        with torch.no_grad():
+            retransformed_Y = None
+            untransformed_Yvar = None
+            # Multi-output models store the train targets with the output
+            # dimension first, so transpose them back to `n x m` here.
+            if self.num_outputs > 1:
+                untransformed_Y = self.train_targets.transpose(-1, -2)
+                if isinstance(self.likelihood, _GaussianLikelihoodBase):
+                    untransformed_Yvar = self.likelihood.noise_covar.noise.transpose(
+                        -1, -2
+                    )
+            else:
+                untransformed_Y = self.train_targets.unsqueeze(-1)
+                if isinstance(self.likelihood, _GaussianLikelihoodBase):
+                    untransformed_Yvar = self.likelihood.noise_covar.noise.unsqueeze(-1)
+
+            # NOTE: Some outcome transforms require an X, but the untransformed
+            # X's cannot generally be extracted without transformations & adding
+            # batch dimensions (e.g. in Warp). Thus, we use the train inputs.
+            X = self.train_inputs[0]
+
+            # We obtain the untransformed Y (the train_Y's) by untransforming
+            # the train targets.
+            if should_outcome_transform:
+                try:
+                    untransformed_Y, untransformed_Yvar = (
+                        self.outcome_transform.untransform(
+                            Y=untransformed_Y,
+                            Yvar=untransformed_Yvar,
+                            X=X,
+                        )
+                    )
+                except NotImplementedError:
+                    # If the outcome transform does not support untransforming,
+                    # we fall back to loading without preserving the transforms.
+                    warnings.warn(
+                        "Outcome transform does not support untransforming. "
+                        "Cannot load the state dict with transforms preserved. "
+                        "Setting keep_transforms=False.",
+                        stacklevel=3,
+                    )
+                    super().load_state_dict(state_dict, strict)
+                    return
+
+            super().load_state_dict(state_dict, strict)
+
+            # If we want to keep the transforms, we cannot have them in train
+            # mode. If we do, the transforms will be re-trained when the model
+            # is evaluated.
+            if getattr(self, "input_transform", None) is not None:
+                self.input_transform.eval()
+
+            # Now, the outcome transform is identical to the state_dict'ed
+            # model, so we may once again transform the train targets.
+            if should_outcome_transform:
+                self.outcome_transform.eval()
+                retransformed_Y, retransformed_Yvar = self.outcome_transform(
+                    Y=untransformed_Y, Yvar=untransformed_Yvar, X=X
+                )
+
+                # Not all models have self._transform_tensor_args, so we do
+                # this instead.
+                if self.num_outputs > 1:
+                    retransformed_Y = retransformed_Y.transpose(-1, -2)
+                    retransformed_Yvar = retransformed_Yvar.transpose(-1, -2)
+                else:
+                    retransformed_Y = retransformed_Y.squeeze(-1)
+                    retransformed_Yvar = retransformed_Yvar.squeeze(-1)
+                self.set_train_data(
+                    targets=retransformed_Y,
+                    strict=strict,
+                )
+                if isinstance(self.likelihood, _GaussianLikelihoodBase):
+                    self.likelihood.noise_covar.noise = retransformed_Yvar
+
 
 # pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape,
 # _aug_batch_shape
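The transpose/unsqueeze bookkeeping above exists because BoTorch's batched multi-output models store train targets with the output dimension first, while outcome transforms operate on `n x m` arrays. A small sketch of that convention (shapes are illustrative, not from the commit):

import torch
from botorch.models.gp_regression import SingleTaskGP

train_X = torch.rand(5, 2, dtype=torch.double)
train_Y = torch.rand(5, 3, dtype=torch.double)  # n=5 points, m=3 outputs
model = SingleTaskGP(train_X, train_Y)

# The batched representation puts the output dimension first ...
print(model.train_targets.shape)  # torch.Size([3, 5])
# ... so load_state_dict transposes back to `n x m` before calling
# outcome_transform.untransform, and transposes again afterwards.
print(model.train_targets.transpose(-1, -2).shape)  # torch.Size([5, 3])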

botorch/models/likelihoods/sparse_outlier_noise.py

Lines changed: 8 additions & 0 deletions
@@ -479,3 +479,11 @@ def _optimal_rhos(self, mll: ExactMarginalLogLikelihood) -> Tensor:
         loo_error = loo_mean - Y
         optimal_rho_deltas = loo_error.square() - loo_var
         return (optimal_rho_deltas - self.rho).clamp(0)[~self.is_active]
+
+    @property
+    def noise(self) -> Tensor:
+        return self.base_noise.noise
+
+    @noise.setter
+    def noise(self, value: Tensor) -> None:
+        self.base_noise.initialize(noise=value)
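This property is needed because load_state_dict above assigns to likelihood.noise_covar.noise, and SparseOutlierNoise wraps a base noise module rather than holding the noise itself. A stripped-down sketch of the delegation pattern (the NoiseWrapper class is hypothetical, not the actual SparseOutlierNoise):

import torch
from torch import Tensor
from gpytorch.likelihoods.noise_models import HomoskedasticNoise

class NoiseWrapper(torch.nn.Module):
    """Forwards the `noise` attribute to a wrapped base noise module, so that
    generic code like `likelihood.noise_covar.noise = value` keeps working."""

    def __init__(self, base_noise: HomoskedasticNoise) -> None:
        super().__init__()
        self.base_noise = base_noise

    @property
    def noise(self) -> Tensor:
        return self.base_noise.noise

    @noise.setter
    def noise(self, value: Tensor) -> None:
        # initialize() is the standard GPyTorch way to set a constrained
        # parameter from a (transformed) value.
        self.base_noise.initialize(noise=value)

wrapper = NoiseWrapper(HomoskedasticNoise())
wrapper.noise = torch.tensor([0.1])
print(wrapper.noise)  # tensor([0.1000], ...)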

test/models/test_model.py

Lines changed: 214 additions & 0 deletions
@@ -9,12 +9,21 @@
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.exceptions.errors import InputDataError
 from botorch.models.deterministic import GenericDeterministicModel
+from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.model import Model, ModelDict, ModelList
+from botorch.models.transforms.input import Normalize, Round
+from botorch.models.transforms.outcome import Standardize
 from botorch.posteriors.ensemble import EnsemblePosterior
 from botorch.posteriors.posterior_list import PosteriorList
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from torch import rand
+from torch.nn import Module
+
+
+class NonUntransformableOutcomeTransform(Standardize):
+    def untransform(self, **kwargs):
+        raise NotImplementedError
 
 
 class NotSoAbstractBaseModel(Model):
@@ -138,6 +147,211 @@ def test_posterior_transform(self):
         )
 
 
+def _get_input_output_transform(
+    d: int, m: int, use_transforms: bool = True
+) -> dict[str, Module]:
+    return {
+        "input_transform": Normalize(d=d) if use_transforms else None,
+        "outcome_transform": Standardize(m=m) if use_transforms else None,
+    }
+
+
+class TestTransformWarnings(BotorchTestCase):
+    def test_set_transformed_inputs_warning_no_train_inputs(self):
+        """Test warning when model has input_transform but no train_inputs."""
+        # Setup: Create a model with input_transform but without a train_inputs
+        # attribute.
+        model = NotSoAbstractBaseModel()
+        model.input_transform = Normalize(d=2)
+
+        # Execute: Call _set_transformed_inputs, which should trigger a warning.
+        with self.assertWarnsRegex(
+            RuntimeWarning,
+            "Could not update `train_inputs` with transformed inputs "
+            "since NotSoAbstractBaseModel does not have a `train_inputs` "
+            "attribute. Make sure that the `input_transform` is applied to "
+            "both the train inputs and test inputs.",
+        ):
+            model._set_transformed_inputs()
+
+    def test_load_state_dict_output_warnings(self):
+        """Test warning when outcome transform doesn't support untransforming."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_Y = torch.rand(3, 1, **tkwargs)
+
+        # Setup: Create a model with an untransformable outcome transform.
+        model = SingleTaskGP(
+            train_X=train_X,
+            train_Y=train_Y,
+            input_transform=Normalize(d=2),
+            outcome_transform=NonUntransformableOutcomeTransform(m=1),
+        )
+        state_dict = model.state_dict()
+
+        # Assert: Verify a warning is raised for the untransformable outcome
+        # transform.
+        with self.assertWarnsRegex(
+            UserWarning,
+            "Outcome transform does not support untransforming.*",
+        ):
+            model.load_state_dict(state_dict, keep_transforms=True)
+
+
+class TestLoadStateDict(BotorchTestCase):
+    def _test_load_state_dict_base(
+        self, num_outputs: int, include_yvar: bool = True
+    ) -> None:
+        """Base test helper for load_state_dict with transforms."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_X = torch.cat(
+            [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0
+        )
+        train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).repeat(1, num_outputs)
+
+        model_kwargs = {
+            "train_X": train_X,
+            "train_Y": train_Y,
+        }
+
+        if include_yvar:
+            train_Yvar = 0.1 * torch.rand_like(train_Y)
+            model_kwargs["train_Yvar"] = train_Yvar
+
+        base_model = SingleTaskGP(
+            **model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+
+        original_train_inputs = base_model.input_transform(base_model.train_inputs[0])
+        original_train_targets = base_model.train_targets.clone()
+        original_train_yvar = base_model.likelihood.noise_covar.noise.clone()
+
+        state_dict = base_model.state_dict()
+
+        cv_model_kwargs = model_kwargs.copy()
+        cv_model_kwargs["train_X"] = train_X[:-1]
+        cv_model_kwargs["train_Y"] = train_Y[:-1]
+        if include_yvar:
+            cv_model_kwargs["train_Yvar"] = train_Yvar[:-1]
+        cv_model = SingleTaskGP(
+            **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+
+        # Test keep_transforms=True.
+        cv_model.load_state_dict(state_dict, keep_transforms=True)
+
+        # Ensure the outcome transform is in eval mode and doesn't change its
+        # parameters.
+        sd_mean = cv_model.outcome_transform.means
+        cv_model.outcome_transform(train_Y[:-1])
+        self.assertTrue(torch.all(cv_model.outcome_transform.means == sd_mean))
+
+        # Check that the transform parameters match the state_dict.
+        self.assertTrue(
+            torch.allclose(
+                cv_model.input_transform._offset,
+                state_dict["input_transform._offset"],
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                cv_model.outcome_transform.means,
+                state_dict["outcome_transform.means"],
+            )
+        )
+
+        # Verify that the train data is preserved in the transformed space.
+        self.assertAllClose(cv_model.train_targets, original_train_targets[..., :-1])
+        self.assertTrue(
+            torch.equal(
+                cv_model.input_transform(cv_model.train_inputs[0]),
+                original_train_inputs[..., :-1, :],
+            )
+        )
+        if include_yvar:
+            self.assertAllClose(
+                cv_model.likelihood.noise_covar.noise, original_train_yvar[..., :-1]
+            )
+
+        # Test keep_transforms=False (allows refitting).
+        cv_model = SingleTaskGP(
+            **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+        cv_model.load_state_dict(state_dict, keep_transforms=False)
+
+        # The transforms should refit on the new data.
+        sd_mean = cv_model.outcome_transform.means
+        cv_model.outcome_transform(train_Y[:-1])
+        self.assertTrue(torch.all(cv_model.outcome_transform.means != sd_mean))
+
+        self.assertFalse(
+            torch.equal(
+                cv_model.input_transform(cv_model.train_inputs[0]),
+                original_train_inputs[..., :-1, :],
+            )
+        )
+        self.assertFalse(
+            torch.equal(cv_model.train_targets, original_train_targets[..., :-1])
+        )
+        self.assertFalse(
+            torch.equal(
+                cv_model.input_transform._offset,
+                state_dict["input_transform._offset"],
+            )
+        )
+        self.assertFalse(
+            torch.equal(
+                cv_model.outcome_transform.means,
+                state_dict["outcome_transform.means"],
+            )
+        )
+
+    def test_load_state_dict_with_transforms(self):
+        """Test load_state_dict with input and outcome transforms."""
+        self._test_load_state_dict_base(num_outputs=1, include_yvar=True)
+
+    def test_load_state_dict_with_transforms_no_yvar(self):
+        """Test load_state_dict with input and outcome transforms without Yvar."""
+        self._test_load_state_dict_base(num_outputs=1, include_yvar=False)
+
+    def test_load_state_dict_multi_output_with_transforms(self):
+        """Test load_state_dict with a multi-output model and transforms."""
+        self._test_load_state_dict_base(num_outputs=3, include_yvar=True)
+
+    def test_load_state_dict_multi_output_with_transforms_no_yvar(self):
+        """Test load_state_dict with a multi-output model and transforms, no Yvar."""
+        self._test_load_state_dict_base(num_outputs=3, include_yvar=False)
+
+    def test_load_state_dict_no_transforms(self):
+        """Test load_state_dict without any transforms."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_X = torch.cat(
+            [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0
+        )
+        train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
+
+        base_model = SingleTaskGP(
+            train_X=train_X, train_Y=train_Y, outcome_transform=None
+        )
+        original_train_targets = base_model.train_targets.clone()
+        state_dict = base_model.state_dict()
+
+        cv_model = SingleTaskGP(
+            train_X=train_X[:-1], train_Y=train_Y[:-1], outcome_transform=None
+        )
+        cv_model.load_state_dict(state_dict, keep_transforms=False)
+
+        # Verify that the train targets are preserved.
+        self.assertTrue(
+            torch.equal(cv_model.train_targets, original_train_targets[:-1])
+        )
+
+
 class TestModelDict(BotorchTestCase):
     def test_model_dict(self):
         models = {"m1": MockModel(MockPosterior()), "m2": MockModel(MockPosterior())}
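Taken together, the workflow these tests exercise is the cross-validation pattern from the commit message. A sketch of what that might look like outside the test suite (the loo_models helper and the data are illustrative, not part of BoTorch):

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

def loo_models(train_X, train_Y, state_dict):
    """Yield leave-one-out models that share the base model's transforms."""
    n = train_X.shape[0]
    for i in range(n):
        mask = torch.arange(n) != i
        model = SingleTaskGP(
            train_X[mask], train_Y[mask], outcome_transform=Standardize(m=1)
        )
        # keep_transforms=True keeps each fold's targets in the base model's
        # standardized space instead of re-standardizing per fold.
        model.load_state_dict(state_dict, keep_transforms=True)
        yield i, model

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.randn(8, 1, dtype=torch.double)
base = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
for i, model in loo_models(train_X, train_Y, base.state_dict()):
    with torch.no_grad():
        # Prediction at the held-out point, in the original outcome space.
        loo_mean = model.posterior(train_X[i : i + 1]).mean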
