|
17 | 17 | from botorch.exceptions.errors import DeprecationError, InputDataError |
18 | 18 | from botorch.exceptions.warnings import InputDataWarning |
19 | 19 | from botorch.fit import fit_gpytorch_mll |
| 20 | +from botorch.models.gp_regression import SingleTaskGP |
20 | 21 | from botorch.models.gpytorch import ( |
21 | 22 | BatchedMultiOutputGPyTorchModel, |
22 | 23 | GPyTorchModel, |
|
28 | 29 | from botorch.models.transforms.input import ( |
29 | 30 | ChainedInputTransform, |
30 | 31 | InputTransform, |
| 32 | + Normalize, |
31 | 33 | NumericToCategoricalEncoding, |
32 | 34 | ) |
33 | 35 | from botorch.models.utils import fantasize |
@@ -870,3 +872,197 @@ def test_condition_on_observations_train_input_shapes(self): |
870 | 872 | fantasy_model._original_train_inputs.shape[0], original_size + 1 |
871 | 873 | ) |
872 | 874 | self.assertEqual(model2._original_train_inputs.shape[0], original_size) |
| 875 | + |
| 876 | + |
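| | +# Standardize variant whose untransform always raises, used to trigger the |
| | +# "does not support untransforming" warning path in load_state_dict below. |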
| 877 | +class NonUntransformableOutcomeTransform(Standardize): |
| 878 | + def untransform(self, *args, **kwargs): |
| 879 | + raise NotImplementedError |
| 880 | + |
| 881 | + |
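| | +# Helper building the input/outcome transform kwargs shared by the tests below. |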
| 882 | +def _get_input_output_transform( |
| 883 | + d: int, m: int, use_transforms: bool = True |
| 884 | +) -> dict[str, torch.nn.Module | None]: |
| 885 | + return { |
| 886 | + "input_transform": Normalize(d=d) if use_transforms else None, |
| 887 | + "outcome_transform": Standardize(m=m) if use_transforms else None, |
| 888 | + } |
| 889 | + |
| 890 | + |
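| | +# Warnings emitted when transforms cannot be applied to train inputs or |
| | +# cannot be untransformed while loading a state dict. |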
| 891 | +class TestTransformWarnings(BotorchTestCase): |
| 892 | + def test_set_transformed_inputs_warning_no_train_inputs(self): |
| 893 | + from botorch.models.model import Model |
| 894 | + |
| 895 | + class NotSoAbstractBaseModel(Model): |
| 896 | + def posterior(self, X, output_indices, observation_noise, **kwargs): |
| 897 | + pass |
| 898 | + |
| 899 | + model = NotSoAbstractBaseModel() |
| 900 | + model.input_transform = Normalize(d=2) |
| 901 | + |
| 902 | + with self.assertWarnsRegex( |
| 903 | + RuntimeWarning, |
| 904 | + "Could not update `train_inputs` with transformed inputs " |
| 905 | + "since NotSoAbstractBaseModel does not have a `train_inputs` " |
| 906 | + "attribute. Make sure that the `input_transform` is applied to " |
| 907 | + "both the train inputs and test inputs.", |
| 908 | + ): |
| 909 | + model._set_transformed_inputs() |
| 910 | + |
| 911 | + def test_load_state_dict_output_warnings(self): |
| 912 | + tkwargs = {"device": self.device, "dtype": torch.double} |
| 913 | + |
| 914 | + train_X = torch.rand(3, 2, **tkwargs) |
| 915 | + train_Y = torch.rand(3, 1, **tkwargs) |
| 916 | + |
| 917 | + model = SingleTaskGP( |
| 918 | + train_X=train_X, |
| 919 | + train_Y=train_Y, |
| 920 | + input_transform=Normalize(d=2), |
| 921 | + outcome_transform=NonUntransformableOutcomeTransform(m=1), |
| 922 | + ) |
| 923 | + state_dict = model.state_dict() |
| 924 | + |
| 925 | + with self.assertWarnsRegex( |
| 926 | + UserWarning, |
| 927 | + "Outcome transform does not support untransforming.*", |
| 928 | + ): |
| 929 | + model.load_state_dict(state_dict, keep_transforms=True) |
| 930 | + |
| 931 | + |
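| | +# Tests for load_state_dict with the keep_transforms flag: with keep_transforms=True |
| | +# the loaded transform state is kept; with keep_transforms=False the saved transform |
| | +# state is discarded and the transforms re-fit on the model's own training data. |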
| 932 | +class TestLoadStateDict(BotorchTestCase): |
| 933 | + def _test_load_state_dict_base( |
| 934 | + self, num_outputs: int, include_yvar: bool = True |
| 935 | + ) -> None: |
| 936 | + tkwargs = {"device": self.device, "dtype": torch.double} |
| 937 | + |
| 938 | + train_X = torch.rand(3, 2, **tkwargs) |
| 939 | + train_X = torch.cat( |
| 940 | + [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0 |
| 941 | + ) |
| 942 | + train_Y = torch.sin(train_X).sum(dim=1, keepdim=True).repeat(1, num_outputs) |
| 943 | + |
| 944 | + model_kwargs = { |
| 945 | + "train_X": train_X, |
| 946 | + "train_Y": train_Y, |
| 947 | + } |
| 948 | + |
| 949 | + if include_yvar: |
| 950 | + train_Yvar = 0.1 * torch.rand_like(train_Y) |
| 951 | + model_kwargs["train_Yvar"] = train_Yvar |
| 952 | + |
| 953 | + base_model = SingleTaskGP( |
| 954 | + **model_kwargs, **_get_input_output_transform(d=2, m=num_outputs) |
| 955 | + ) |
| 956 | + |
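| | + # Record the base model's transformed train inputs, train targets and noise |
| | + # levels for comparison after reloading below. |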
| 957 | + original_train_inputs = base_model.input_transform(base_model.train_inputs[0]) |
| 958 | + original_train_targets = base_model.train_targets.clone() |
| 959 | + original_train_yvar = base_model.likelihood.noise_covar.noise.clone() |
| 960 | + |
| 961 | + state_dict = base_model.state_dict() |
| 962 | + |
| 963 | + cv_model_kwargs = model_kwargs.copy() |
| 964 | + cv_model_kwargs["train_X"] = train_X[:-1] |
| 965 | + cv_model_kwargs["train_Y"] = train_Y[:-1] |
| 966 | + if include_yvar: |
| 967 | + cv_model_kwargs["train_Yvar"] = train_Yvar[:-1] |
| 968 | + cv_model = SingleTaskGP( |
| 969 | + **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs) |
| 970 | + ) |
| 971 | + |
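| | + # With keep_transforms=True the transform buffers should match the saved state |
| | + # dict, and calling the outcome transform again should not re-fit them. |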
| 972 | + cv_model.load_state_dict(state_dict, keep_transforms=True) |
| 973 | + |
| 974 | + sd_mean = cv_model.outcome_transform.means |
| 975 | + cv_model.outcome_transform(train_Y[:-1]) |
| 976 | + self.assertTrue(torch.all(cv_model.outcome_transform.means == sd_mean)) |
| 977 | + |
| 978 | + self.assertTrue( |
| 979 | + torch.allclose( |
| 980 | + cv_model.input_transform._offset, |
| 981 | + state_dict["input_transform._offset"], |
| 982 | + ) |
| 983 | + ) |
| 984 | + self.assertTrue( |
| 985 | + torch.allclose( |
| 986 | + cv_model.outcome_transform.means, |
| 987 | + state_dict["outcome_transform.means"], |
| 988 | + ) |
| 989 | + ) |
| 990 | + |
| 991 | + self.assertAllClose(cv_model.train_targets, original_train_targets[..., :-1]) |
| 992 | + self.assertTrue( |
| 993 | + torch.equal( |
| 994 | + cv_model.input_transform(cv_model.train_inputs[0]), |
| 995 | + original_train_inputs[..., :-1, :], |
| 996 | + ) |
| 997 | + ) |
| 998 | + if include_yvar: |
| 999 | + self.assertAllClose( |
| 1000 | + cv_model.likelihood.noise_covar.noise, original_train_yvar[..., :-1] |
| 1001 | + ) |
| 1002 | + |
| 1003 | + cv_model = SingleTaskGP( |
| 1004 | + **cv_model_kwargs, **_get_input_output_transform(d=2, m=num_outputs) |
| 1005 | + ) |
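| | + # With keep_transforms=False the saved transform buffers are not kept, so the |
| | + # transform parameters and the derived train inputs/targets should all differ. |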
| 1006 | + cv_model.load_state_dict(state_dict, keep_transforms=False) |
| 1007 | + |
| 1008 | + sd_mean = cv_model.outcome_transform.means |
| 1009 | + cv_model.outcome_transform(train_Y[:-1]) |
| 1010 | + self.assertTrue(torch.all(cv_model.outcome_transform.means != sd_mean)) |
| 1011 | + |
| 1012 | + self.assertFalse( |
| 1013 | + torch.equal( |
| 1014 | + cv_model.input_transform(cv_model.train_inputs[0]), |
| 1015 | + original_train_inputs[..., :-1, :], |
| 1016 | + ) |
| 1017 | + ) |
| 1018 | + self.assertFalse( |
| 1019 | + torch.equal(cv_model.train_targets, original_train_targets[..., :-1]) |
| 1020 | + ) |
| 1021 | + self.assertFalse( |
| 1022 | + torch.equal( |
| 1023 | + cv_model.input_transform._offset, |
| 1024 | + state_dict["input_transform._offset"], |
| 1025 | + ) |
| 1026 | + ) |
| 1027 | + self.assertFalse( |
| 1028 | + torch.equal( |
| 1029 | + cv_model.outcome_transform.means, |
| 1030 | + state_dict["outcome_transform.means"], |
| 1031 | + ) |
| 1032 | + ) |
| 1033 | + |
| 1034 | + def test_load_state_dict_with_transforms(self): |
| 1035 | + self._test_load_state_dict_base(num_outputs=1, include_yvar=True) |
| 1036 | + |
| 1037 | + def test_load_state_dict_with_transforms_no_yvar(self): |
| 1038 | + self._test_load_state_dict_base(num_outputs=1, include_yvar=False) |
| 1039 | + |
| 1040 | + def test_load_state_dict_multi_output_with_transforms(self): |
| 1041 | + self._test_load_state_dict_base(num_outputs=3, include_yvar=True) |
| 1042 | + |
| 1043 | + def test_load_state_dict_multi_output_with_transforms_no_yvar(self): |
| 1044 | + self._test_load_state_dict_base(num_outputs=3, include_yvar=False) |
| 1045 | + |
| 1046 | + def test_load_state_dict_no_transforms(self): |
| 1047 | + tkwargs = {"device": self.device, "dtype": torch.double} |
| 1048 | + |
| 1049 | + train_X = torch.rand(3, 2, **tkwargs) |
| 1050 | + train_X = torch.cat( |
| 1051 | + [train_X, torch.tensor([[-0.02, 11.1], [17.1, -2.5]], **tkwargs)], dim=0 |
| 1052 | + ) |
| 1053 | + train_Y = torch.sin(train_X).sum(dim=1, keepdim=True) |
| 1054 | + |
| 1055 | + base_model = SingleTaskGP( |
| 1056 | + train_X=train_X, train_Y=train_Y, outcome_transform=None |
| 1057 | + ) |
| 1058 | + original_train_targets = base_model.train_targets.clone() |
| 1059 | + state_dict = base_model.state_dict() |
| 1060 | + |
| 1061 | + cv_model = SingleTaskGP( |
| 1062 | + train_X=train_X[:-1], train_Y=train_Y[:-1], outcome_transform=None |
| 1063 | + ) |
| 1064 | + cv_model.load_state_dict(state_dict, keep_transforms=False) |
| 1065 | + |
| 1066 | + self.assertTrue( |
| 1067 | + torch.equal(cv_model.train_targets, original_train_targets[:-1]) |
| 1068 | + ) |