Commit e17f003

Carl Hvarfner authored and facebook-github-bot committed
Preserving train inputs and targets through transforms (#3044)
Summary:
This PR preserves botorch transforms (specifically outcome_transforms, like Standardize) through state_dict loading. The fix also ensures that the train_targets of a leave-one-out model with outcome transforms will, in the default case, have the same targets as the base model, minus the point left out.

__Longer explanation:__
Transforms, and specifically learnable outcome transforms like Standardize, currently:
a. learn their parameters at initialization of the GP, and
b. transform the train_Ys into the normalized space.

Then, when we load a state dict, we:
a. impose new standardization parameters on already-standardized data, and
b. potentially leave the transforms re-learnable, nullifying the change made by the state dict.

This has undesired consequences for cross-validation, as all cross-validated models effectively end up with different training data. In essence, _we don't simply leave one point out; we leave one out and re-standardize_. When the data contains outliers, this leads to substantially different predictions when an outlier is left out, since the outlier substantially impacts the outcome transform parameters.

TODO:
- Account for non-invertible transforms

Differential Revision: D84571407
1 parent f122efc commit e17f003
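To make the failure mode concrete, here is a minimal sketch of the leave-one-out scenario the summary describes. It is not part of the diff; the data and the outlier value are made up, while `SingleTaskGP`, `Standardize`, and the new `keep_transforms` argument follow the API changed in this PR:

```python
import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

# Toy data with one extreme outlier in Y (values are made up).
train_X = torch.rand(5, 2, dtype=torch.double)
train_Y = torch.randn(5, 1, dtype=torch.double)
train_Y[-1] = 100.0

base = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
state_dict = base.state_dict()

# Leave the outlier out: this model's Standardize learns very different
# means/stdvs than the base model's, so its train_targets live in a
# different normalized space.
loo = SingleTaskGP(train_X[:-1], train_Y[:-1], outcome_transform=Standardize(m=1))

# With keep_transforms=True (the new default), loading un-standardizes the
# fold's targets, restores the base model's transform parameters, and
# re-standardizes, so loo.train_targets ~= base.train_targets[:-1].
loo.load_state_dict(state_dict, keep_transforms=True)
```

With `keep_transforms=False`, the fold's `Standardize` instead refits on the remaining four points, so every cross-validation fold effectively trains on differently scaled targets.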

File tree

2 files changed: +348 -1 lines changed

botorch/models/model.py
test/models/test_model.py

botorch/models/model.py

Lines changed: 110 additions & 1 deletion
@@ -34,7 +34,10 @@
 from botorch.utils.containers import BotorchContainer
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.transforms import is_fully_bayesian
-from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
+from gpytorch.likelihoods.gaussian_likelihood import (
+    _GaussianLikelihoodBase,
+    FixedNoiseGaussianLikelihood,
+)
 from torch import Tensor
 from torch.nn import Module, ModuleDict, ModuleList
 from typing_extensions import Self
@@ -268,6 +271,112 @@ def train(self, mode: bool = True) -> Model:
     def dtypes_of_buffers(self) -> set[torch.dtype]:
         return {t.dtype for t in self.buffers() if t is not None}
 
+    def load_state_dict(
+        self,
+        state_dict: Mapping[str, Any],
+        strict: bool = True,
+        keep_transforms: bool = True,
+    ) -> None:
+        r"""Load the model state.
+
+        Args:
+            state_dict: A dict containing the state of the model.
+            strict: A boolean indicating whether to strictly enforce that the keys
+                in `state_dict` match the keys returned by this module's
+                `state_dict()` function.
+            keep_transforms: A boolean indicating whether to keep the input and
+                outcome transforms. Doing so is useful when loading a model that
+                was trained on a full set of data and is later loaded with a
+                subset of the data. If `keep_transforms=False`, the transforms are
+                reset to their default values and re-trained when the model is
+                evaluated, which may lead to different behavior than when the
+                initial model was trained, pre-loading. Yvar does not need to be
+                transformed, since it is saved as part of the state dict and will
+                thus always be in the right format.
+        """
+        if not keep_transforms:
+            super().load_state_dict(state_dict, strict)
+            return
+        should_input_transform = (
+            hasattr(self, "input_transform") and self.input_transform is not None
+        )
+        should_outcome_transform = (
+            hasattr(self, "outcome_transform") and self.outcome_transform is not None
+        )
+        with torch.no_grad():
+            retransformed_Y = None
+            untransformed_Yvar = None
+            if self._num_outputs > 1:
+                untransformed_Y = self.train_targets.transpose(-1, -2)
+                if isinstance(self.likelihood, _GaussianLikelihoodBase):
+                    untransformed_Yvar = self.likelihood.noise_covar.noise.transpose(
+                        -1, -2
+                    )
+            else:
+                untransformed_Y = self.train_targets.unsqueeze(-1)
+                if isinstance(self.likelihood, _GaussianLikelihoodBase):
+                    untransformed_Yvar = self.likelihood.noise_covar.noise.unsqueeze(-1)
+
+            untransformed_X = self.train_inputs[0]
+            if should_input_transform:
+                try:
+                    untransformed_X = self.input_transform.untransform(untransformed_X)
+                except NotImplementedError:
+                    # If the input transform does not support untransforming, we
+                    # keep the (already transformed) train inputs as they are.
+                    warnings.warn(
+                        "Input transform does not support untransforming. Cannot load "
+                        "the state dict with input transforms preserved. If the outcome "
+                        "transform requires the inputs to be computed, it will be "
+                        "computed on the transformed inputs."
+                    )
+
+            # We obtain the untransformed Y (the train_Y's) by untransforming the
+            # train targets.
+            if should_outcome_transform:
+                try:
+                    untransformed_Y, untransformed_Yvar = (
+                        self.outcome_transform.untransform(
+                            Y=untransformed_Y,
+                            Yvar=untransformed_Yvar,
+                            X=untransformed_X,
+                        )
+                    )
+                except NotImplementedError:
+                    # If the outcome transform does not support untransforming, we
+                    # fall back to loading the state dict without preserving the
+                    # transforms.
+                    warnings.warn(
+                        "Outcome transform does not support untransforming. Cannot "
+                        "load the state dict with transforms preserved. Setting "
+                        "keep_transforms=False.",
+                    )
+                    super().load_state_dict(state_dict, strict)
+                    return
+
+        super().load_state_dict(state_dict, strict)
+        # If we want to keep the transforms, we cannot have them in train mode when
+        # the state dict is loaded.
+        if should_input_transform:
+            self.input_transform.eval()
+
+        # Now, the outcome transform is identical to the state_dict'ed model, so we
+        # may once again transform the train targets.
+        if should_outcome_transform:
+            self.outcome_transform.eval()
+            retransformed_Y, retransformed_Yvar = self.outcome_transform(
+                Y=untransformed_Y, Yvar=untransformed_Yvar, X=untransformed_X
+            )
+
+            # Not all models have self._transform_tensor_args, so we do this instead.
+            if self._num_outputs > 1:
+                retransformed_Y = retransformed_Y.transpose(-1, -2)
+                retransformed_Yvar = retransformed_Yvar.transpose(-1, -2)
+            else:
+                retransformed_Y = retransformed_Y.squeeze(-1)
+                retransformed_Yvar = retransformed_Yvar.squeeze(-1)
+            self.set_train_data(
+                targets=retransformed_Y,
+                strict=strict,
+            )
+            if isinstance(self.likelihood, _GaussianLikelihoodBase):
+                self.likelihood.noise_covar.noise = retransformed_Yvar
 
 class FantasizeMixin(ABC):
     """

test/models/test_model.py

Lines changed: 238 additions & 0 deletions
@@ -9,12 +9,21 @@
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.exceptions.errors import InputDataError
 from botorch.models.deterministic import GenericDeterministicModel
+from botorch.models.gp_regression import SingleTaskGP
 from botorch.models.model import Model, ModelDict, ModelList
+from botorch.models.transforms.input import Normalize, Round
+from botorch.models.transforms.outcome import Standardize
 from botorch.posteriors.ensemble import EnsemblePosterior
 from botorch.posteriors.posterior_list import PosteriorList
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from torch import rand
+from torch.nn import Module
+
+
+class NonUntransformableOutcomeTransform(Standardize):
+    def untransform(self, **kwargs):
+        raise NotImplementedError
 
 
 class NotSoAbstractBaseModel(Model):
@@ -138,6 +147,235 @@ def test_posterior_transform(self):
         )
 
 
+def _get_input_output_transform(
+    d: int, m: int, use_transforms: bool = True
+) -> dict[str, Module]:
+    return {
+        "input_transform": Normalize(d=d) if use_transforms else None,
+        "outcome_transform": Standardize(m=m) if use_transforms else None,
+    }
+
+
+class TestTransformWarnings(BotorchTestCase):
+    def test_set_transformed_inputs_warning_no_train_inputs(self):
+        """Test warning when model has input_transform but no train_inputs."""
+        # Setup: Create a model with input_transform but without a train_inputs
+        # attribute.
+        model = NotSoAbstractBaseModel()
+        model.input_transform = Normalize(d=2)
+
+        # Execute: Call _set_transformed_inputs, which should trigger a warning.
+        # Assert: Verify the warning is raised.
+        with self.assertWarnsRegex(
+            RuntimeWarning,
+            "Could not update `train_inputs` with transformed inputs "
+            "since NotSoAbstractBaseModel does not have a `train_inputs` "
+            "attribute. Make sure that the `input_transform` is applied to "
+            "both the train inputs and test inputs.",
+        ):
+            model._set_transformed_inputs()
+
+    def test_load_state_dict_input_warnings(self):
+        """Test warning when the input transform doesn't support untransforming."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_Y = torch.rand(3, 1, **tkwargs)
+
+        # Setup: Create a model with an untransformable input transform.
+        model = SingleTaskGP(
+            train_X=train_X,
+            train_Y=train_Y,
+            input_transform=Round(),
+            outcome_transform=Standardize(m=1),
+        )
+        state_dict = model.state_dict()
+
+        # Execute: Load the state dict with keep_transforms=True.
+        # Assert: Verify a warning is raised for the untransformable input
+        # transform.
+        with self.assertWarnsRegex(
+            UserWarning,
+            "Input transform does not support untransforming.*",
+        ):
+            model.load_state_dict(state_dict, keep_transforms=True)
+
+    def test_load_state_dict_output_warnings(self):
+        """Test warning when the outcome transform doesn't support untransforming."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_Y = torch.rand(3, 1, **tkwargs)
+
+        # Setup: Create a model with an untransformable outcome transform.
+        model = SingleTaskGP(
+            train_X=train_X,
+            train_Y=train_Y,
+            input_transform=Normalize(d=2),
+            outcome_transform=NonUntransformableOutcomeTransform(m=1),
+        )
+        state_dict = model.state_dict()
+
+        # Assert: Verify a warning is raised for the untransformable outcome
+        # transform.
+        with self.assertWarnsRegex(
+            UserWarning,
+            "Outcome transform does not support untransforming.*",
+        ):
+            model.load_state_dict(state_dict, keep_transforms=True)
+
+
+class TestLoadStateDict(BotorchTestCase):
+    def _test_load_state_dict_base(
+        self, num_outputs: int, include_yvar: bool = True
+    ) -> None:
+        """Base test helper for load_state_dict with transforms."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_X = torch.cat(
+            [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0
+        )
+        train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).repeat(1, num_outputs)
+
+        model_kwargs = {
+            "train_X": train_X,
+            "train_Y": train_Y,
+        }
+
+        if include_yvar:
+            train_Yvar = 0.1 * torch.rand_like(train_Y)
+            model_kwargs["train_Yvar"] = train_Yvar
+
+        base_model = SingleTaskGP(
+            **model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+
+        original_train_inputs = base_model.input_transform(base_model.train_inputs[0])
+        original_train_targets = base_model.train_targets.clone()
+        original_train_yvar = base_model.likelihood.noise_covar.noise.clone()
+
+        state_dict = base_model.state_dict()
+
+        cv_model_kwargs = model_kwargs.copy()
+        cv_model_kwargs["train_X"] = train_X[:-1]
+        cv_model_kwargs["train_Y"] = train_Y[:-1]
+        if include_yvar:
+            cv_model_kwargs["train_Yvar"] = train_Yvar[:-1]
+        cv_model = SingleTaskGP(
+            **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+
+        # Test keep_transforms=True
+        cv_model.load_state_dict(state_dict, keep_transforms=True)
+
+        # Ensure the outcome transform is in eval mode and doesn't change parameters
+        sd_mean = cv_model.outcome_transform.means
+        cv_model.outcome_transform(train_Y[:-1])
+        self.assertTrue(torch.all(cv_model.outcome_transform.means == sd_mean))
+
+        # Check that the transform parameters match the state_dict
+        self.assertTrue(
+            torch.allclose(
+                cv_model.input_transform._offset,
+                state_dict["input_transform._offset"],
+            )
+        )
+        self.assertTrue(
+            torch.allclose(
+                cv_model.outcome_transform.means,
+                state_dict["outcome_transform.means"],
+            )
+        )
+
+        # Verify train data preservation in the transformed space
+        self.assertAllClose(cv_model.train_targets, original_train_targets[..., :-1])
+        self.assertTrue(
+            torch.equal(
+                cv_model.input_transform(cv_model.train_inputs[0]),
+                original_train_inputs[..., :-1, :],
+            )
+        )
+        if include_yvar:
+            self.assertAllClose(
+                cv_model.likelihood.noise_covar.noise, original_train_yvar[..., :-1]
+            )
+
+        # Test keep_transforms=False (allows refitting)
+        cv_model = SingleTaskGP(
+            **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs)
+        )
+        cv_model.load_state_dict(state_dict, keep_transforms=False)
+
+        # The transforms should refit on the new data
+        sd_mean = cv_model.outcome_transform.means
+        cv_model.outcome_transform(train_Y[:-1])
+        self.assertTrue(torch.all(cv_model.outcome_transform.means != sd_mean))
+
+        self.assertFalse(
+            torch.equal(
+                cv_model.input_transform(cv_model.train_inputs[0]),
+                original_train_inputs[..., :-1, :],
+            )
+        )
+        self.assertFalse(
+            torch.equal(cv_model.train_targets, original_train_targets[..., :-1])
+        )
+        self.assertFalse(
+            torch.equal(
+                cv_model.input_transform._offset,
+                state_dict["input_transform._offset"],
+            )
+        )
+        self.assertFalse(
+            torch.equal(
+                cv_model.outcome_transform.means,
+                state_dict["outcome_transform.means"],
+            )
+        )
+
+    def test_load_state_dict_with_transforms(self):
+        """Test load_state_dict with input and outcome transforms."""
+        self._test_load_state_dict_base(num_outputs=1, include_yvar=True)
+
+    def test_load_state_dict_with_transforms_no_yvar(self):
+        """Test load_state_dict with input and outcome transforms without Yvar."""
+        self._test_load_state_dict_base(num_outputs=1, include_yvar=False)
+
+    def test_load_state_dict_multi_output_with_transforms(self):
+        """Test load_state_dict with a multi-output model and transforms."""
+        self._test_load_state_dict_base(num_outputs=3, include_yvar=True)
+
+    def test_load_state_dict_multi_output_with_transforms_no_yvar(self):
+        """Test load_state_dict with a multi-output model and transforms, no Yvar."""
+        self._test_load_state_dict_base(num_outputs=3, include_yvar=False)
+
+    def test_load_state_dict_no_transforms(self):
+        """Test load_state_dict without any transforms."""
+        tkwargs = {"device": self.device, "dtype": torch.double}
+
+        train_X = torch.rand(3, 2, **tkwargs)
+        train_X = torch.cat(
+            [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0
+        )
+        train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)
+
+        base_model = SingleTaskGP(
+            train_X=train_X, train_Y=train_Y, outcome_transform=None
+        )
+        original_train_targets = base_model.train_targets.clone()
+        state_dict = base_model.state_dict()
+
+        cv_model = SingleTaskGP(
+            train_X=train_X[:-1], train_Y=train_Y[:-1], outcome_transform=None
+        )
+        cv_model.load_state_dict(state_dict, keep_transforms=False)
+
+        # Verify the train targets are preserved
+        self.assertTrue(
+            torch.equal(cv_model.train_targets, original_train_targets[:-1])
+        )
+
+
 class TestModelDict(BotorchTestCase):
     def test_model_dict(self):
         models = {"m1": MockModel(MockPosterior()), "m2": MockModel(MockPosterior())}
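As a usage sketch of the behavior these tests pin down (hypothetical data; the explicit fold loop below is not part of this PR, which only changes `load_state_dict`), leave-one-out cross-validation can now share the base model's transform parameters across all folds:

```python
import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = torch.sin(train_X).sum(dim=1, keepdim=True)

base = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
state_dict = base.state_dict()

loo_models = []
for i in range(train_X.shape[0]):
    mask = torch.arange(train_X.shape[0]) != i
    loo = SingleTaskGP(
        train_X[mask], train_Y[mask], outcome_transform=Standardize(m=1)
    )
    # keep_transforms=True: every fold reuses the base model's Standardize
    # parameters instead of re-learning them on the n-1 remaining points.
    loo.load_state_dict(state_dict, keep_transforms=True)
    loo_models.append(loo)
```

This assumes each fold's model has the same architecture as the base model, so that strict state-dict loading succeeds; per the TODO above, non-invertible transforms still fall back to `keep_transforms=False` with a warning.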
