Closed
Labels: bug (Something isn't working)
Description
🐛 Bug
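`Standardize` initializes its buffers (`means`, `stdvs`, `_stdvs_sq`) to `None`. This makes `Standardize.load_state_dict` fail whenever the state dict holds non-`None` values for them: a buffer whose value is `None` is not treated as a loadable key, so the saved statistics are reported as unexpected keys. Setting `strict=False` does not help either, because the error is suppressed but the statistics are still not loaded.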
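The root cause can be reproduced with plain PyTorch. The sketch below uses a hypothetical toy module, `Stats`, that registers a `None` buffer the same way `Standardize` does. It relies on PyTorch leaving `None` buffers out of `state_dict()` and out of the set of keys `load_state_dict()` will accept.

```python
import torch
from torch import nn


class Stats(nn.Module):
    """Hypothetical toy module that mimics Standardize's None buffers."""

    def __init__(self) -> None:
        super().__init__()
        # A buffer registered as None is skipped by state_dict() and is not
        # a key that load_state_dict() can copy into.
        self.register_buffer("means", None)


fitted = Stats()
fitted.means = torch.tensor([0.5])  # after "training" the buffer holds a tensor
state = fitted.state_dict()         # now contains the key "means"

fresh = Stats()
try:
    fresh.load_state_dict(state)    # strict=True by default
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)
print(strict_error)  # reports "means" as an unexpected key

# strict=False silences the error, but the buffer is left as None.
fresh.load_state_dict(state, strict=False)
print(fresh.means)
```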
To reproduce
**Code snippet to reproduce**
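```python
import torch
from botorch.models.transforms.outcome import Standardize

# Fit the transform on some training targets so its buffers are populated.
transform = Standardize(m=1)
targets = torch.rand(2, 1)
transform.train()
transform(targets)
state_dict = transform.state_dict()
print(type(transform.means))  # <class 'torch.Tensor'>

# A freshly constructed transform still has its buffers set to None.
new_transform = Standardize(m=1)
print(type(new_transform.means))  # <class 'NoneType'>
new_transform.load_state_dict(state_dict)  # raises RuntimeError
```

**Stack trace/error message**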
```
RuntimeError Traceback (most recent call last)
Cell In [8], line 16
14 new_transform = Standardize(m=1)
15 print(type(new_transform.means))
---> 16 new_transform.load_state_dict(state_dict)
File ~/scratch/miniconda/envs/cortex-env/lib/python3.9/site-packages/torch/nn/modules/module.py:1667, in Module.load_state_dict(self, state_dict, strict)
1662 error_msgs.insert(
1663 0, 'Missing key(s) in state_dict: {}. '.format(
1664 ', '.join('"{}"'.format(k) for k in missing_keys)))
1666 if len(error_msgs) > 0:
-> 1667 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1668 self.__class__.__name__, "\n\t".join(error_msgs)))
1669 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for Standardize:
Unexpected key(s) in state_dict: "means", "stdvs", "_stdvs_sq".
```
Expected Behavior
If I have a model checkpoint with a `Standardize` submodule, I expect `torch.nn.Module.load_state_dict` to restore it without errors, including the fitted `means`, `stdvs` and `_stdvs_sq` buffers.
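Until the class itself is fixed, one possible workaround for such checkpoints is to give every `None` buffer a tensor of the saved shape before loading. The helper `prime_none_buffers` and the toy modules `Stats` and `Model` below are hypothetical (not part of BoTorch or PyTorch). The helper reads the private `Module._buffers` dict because `named_buffers()` skips `None` entries, so treat this as a sketch rather than a supported API.

```python
import torch
from torch import nn


def prime_none_buffers(model: nn.Module, state_dict: dict) -> None:
    """Hypothetical helper: allocate None buffers that the checkpoint has values for."""
    for key, value in state_dict.items():
        module_path, _, name = key.rpartition(".")
        module = model.get_submodule(module_path)  # "" returns the model itself
        # Private API: _buffers still lists buffers whose value is None.
        if name in module._buffers and module._buffers[name] is None:
            # Same shape/dtype/device as the saved tensor; load_state_dict fills it.
            setattr(module, name, torch.empty_like(value))


class Stats(nn.Module):
    """Hypothetical stand-in for a transform with a None buffer."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("means", None)


class Model(nn.Module):
    """Hypothetical model that owns the transform as a submodule."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(1, 1)
        self.outcome_transform = Stats()


fitted = Model()
fitted.outcome_transform.means = torch.full((1, 1), 0.5)
checkpoint = fitted.state_dict()

fresh = Model()
prime_none_buffers(fresh, checkpoint)
fresh.load_state_dict(checkpoint)  # strict loading now succeeds
print(fresh.outcome_transform.means)
```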
System information
- BoTorch Version: 0.8.3
- GPyTorch Version: 1.9.1
- PyTorch Version: 1.13.0+cu117
- Computer OS: Red Hat Enterprise Linux 8.6
Potential Fix
```python
from typing import List, Optional

import torch
from botorch.models.transforms.outcome import Standardize


class LoadableStandardize(Standardize):
    def __init__(
        self,
        m: int,
        outputs: Optional[List[int]] = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_stdv: float = 1e-8,
    ) -> None:
        r"""Standardize outcomes (zero mean, unit variance).

        Args:
            m: The output dimension.
            outputs: Which of the outputs to standardize. If omitted, all
                outputs will be standardized.
            batch_shape: The batch_shape of the training targets.
            min_stdv: The minimum standard deviation for which to perform
                standardization (if lower, only de-mean the data).
        """
        # Let the parent store outputs, m, batch_shape and min_stdv as usual.
        super().__init__(
            m=m, outputs=outputs, batch_shape=batch_shape, min_stdv=min_stdv
        )
        # Re-register the buffers (registered as None by the parent) with real
        # tensors. forward() computes the statistics with keepdim=True, i.e. with
        # shape batch_shape x 1 x m, so the placeholders use that shape to pass
        # load_state_dict's size check.
        self.register_buffer("means", torch.zeros(*batch_shape, 1, m))
        self.register_buffer("stdvs", torch.ones(*batch_shape, 1, m))
        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, 1, m))
```
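The same idea can be checked without BoTorch. `ToyStandardize` below is a hypothetical, stripped-down stand-in (no `outputs` or `min_stdv` handling) that pre-allocates its statistics with the shape its `forward` produces, so a freshly constructed instance can load a fitted state dict.

```python
import torch
from torch import nn


class ToyStandardize(nn.Module):
    """Hypothetical stand-in for Standardize that pre-allocates its statistics."""

    def __init__(self, m: int, batch_shape: torch.Size = torch.Size()) -> None:
        super().__init__()
        # Placeholders already have the fitted shape (batch_shape x 1 x m),
        # so load_state_dict has concrete buffers to copy into.
        self.register_buffer("means", torch.zeros(*batch_shape, 1, m))
        self.register_buffer("stdvs", torch.ones(*batch_shape, 1, m))

    def forward(self, Y: torch.Tensor) -> torch.Tensor:
        if self.training:
            # keepdim=True keeps the singleton sample dimension.
            self.means = Y.mean(dim=-2, keepdim=True)
            self.stdvs = Y.std(dim=-2, keepdim=True)
        return (Y - self.means) / self.stdvs


fitted = ToyStandardize(m=1)
fitted.train()
fitted(torch.tensor([[1.0], [3.0]]))
state = fitted.state_dict()

restored = ToyStandardize(m=1)
restored.load_state_dict(state)  # no unexpected-key or size-mismatch error
restored.eval()
print(restored.means, restored.stdvs)
```

If the placeholders were allocated as `torch.zeros(*batch_shape, m)` instead, `load_state_dict` would report a size mismatch, because the fitted statistics carry the extra singleton dimension.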