[Bug] Standardize modules can't be loaded once trained #1874

@samuelstanton

Description

🐛 Bug

Standardize initializes its buffers to None, which causes Standardize.load_state_dict to fail whenever the state dict contains values that are not None. Setting strict=False also fails.

To reproduce

**Code snippet to reproduce**

import torch
from botorch.models.transforms.outcome import Standardize

transform = Standardize(m=1)
targets = torch.rand(2, 1)
transform.train()
transform(targets)  # fitting populates the buffers
state_dict = transform.state_dict()
print(type(transform.means))  # <class 'torch.Tensor'>

new_transform = Standardize(m=1)
print(type(new_transform.means))  # <class 'NoneType'>
new_transform.load_state_dict(state_dict)  # raises RuntimeError

**Stack trace/error message**

RuntimeError                              Traceback (most recent call last)
Cell In [8], line 16
     14 new_transform = Standardize(m=1)
     15 print(type(new_transform.means))
---> 16 new_transform.load_state_dict(state_dict)

File ~/scratch/miniconda/envs/cortex-env/lib/python3.9/site-packages/torch/nn/modules/module.py:1667, in Module.load_state_dict(self, state_dict, strict)
   1662         error_msgs.insert(
   1663             0, 'Missing key(s) in state_dict: {}. '.format(
   1664                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1666 if len(error_msgs) > 0:
-> 1667     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1668                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1669 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Standardize:
	Unexpected key(s) in state_dict: "means", "stdvs", "_stdvs_sq".

Expected Behavior

If I have a model checkpoint with a Standardize submodule, I expect torch.nn.Module.load_state_dict to work without issues.

System information

Please complete the following information:

  • BoTorch Version: 0.8.3
  • GPyTorch Version: 1.9.1
  • PyTorch Version: 1.13.0+cu117
  • Computer OS: Red Hat Enterprise Linux 8.6

Potential Fix

from typing import List, Optional

import torch
from botorch.models.transforms.outcome import Standardize
from botorch.utils.transforms import normalize_indices


class LoadableStandardize(Standardize):
    def __init__(
        self,
        m: int,
        outputs: Optional[List[int]] = None,
        batch_shape: torch.Size = torch.Size(),  # noqa: B008
        min_stdv: float = 1e-8,
    ) -> None:
        r"""Standardize outcomes (zero mean, unit variance).

        Args:
            m: The output dimension.
            outputs: Which of the outputs to standardize. If omitted, all
                outputs will be standardized.
            batch_shape: The batch_shape of the training targets.
            min_stdv: The minimum standard deviation for which to perform
                standardization (if lower, only de-mean the data).
        """
        super().__init__(m=m)
        # Register real tensors (rather than None) so that load_state_dict
        # has existing buffers to copy the checkpoint values into.
        self.register_buffer("means", torch.zeros(*batch_shape, m))
        self.register_buffer("stdvs", torch.ones(*batch_shape, m))
        self.register_buffer("_stdvs_sq", torch.ones(*batch_shape, m))
        self._outputs = normalize_indices(outputs, d=m)
        self._m = m
        self._batch_shape = batch_shape
        self._min_stdv = min_stdv
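The root cause is independent of BoTorch: a buffer registered as None is excluded from a module's local state, so load_state_dict treats its key as unexpected. A minimal PyTorch-only sketch of the failure and why the fix works (the module names here are illustrative, not BoTorch classes):

```python
import torch
from torch import nn


class NoneBuffer(nn.Module):
    """Mimics Standardize: buffer starts out as None."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("means", None)


class TensorBuffer(nn.Module):
    """Mimics the fix: buffer starts out as a placeholder tensor."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("means", torch.zeros(1))


src = NoneBuffer()
src.means = torch.rand(1)      # populate the buffer, as training would
state = src.state_dict()       # now contains the "means" key

# Loading into a fresh module whose buffer is still None fails,
# because None buffers are excluded from the module's local state.
broken = NoneBuffer()
try:
    broken.load_state_dict(state)
    loaded_into_none = True
except RuntimeError:
    loaded_into_none = False

# Loading into a module whose buffer is a real tensor succeeds.
fixed = TensorBuffer()
fixed.load_state_dict(state)
```

Here `loaded_into_none` ends up False, while `fixed.means` equals the saved value, which is exactly the behavior difference between Standardize and LoadableStandardize above.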
