|
17 | 17 | import warnings |
18 | 18 | from abc import ABC |
19 | 19 | from copy import deepcopy |
20 | | -from typing import Any, TYPE_CHECKING |
| 20 | +from typing import Any, Mapping, TYPE_CHECKING |
21 | 21 |
|
22 | 22 | import torch |
23 | 23 | from botorch.acquisition.objective import PosteriorTransform |
@@ -283,6 +283,111 @@ def condition_on_observations( |
283 | 283 | ).detach() |
284 | 284 | return fantasy_model |
285 | 285 |
|
def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
    r"""Read the training targets (and fixed noise, if any) off the model.

    The stored tensors are rearranged so the output dimension comes last:
    multi-output models (`num_outputs > 1`) have their last two dimensions
    swapped, single-output models get a trailing singleton dimension.

    Returns:
        A tuple `(Y, Yvar)`, each intended to have shape
        `[batch_shape] x n x m`, with `batch_shape` present only if the
        training data initially contained it. `Yvar` is `None` unless the
        likelihood is a `FixedNoiseGaussianLikelihood`.
    """
    is_multi_output = self.num_outputs > 1

    def _outputs_last(stored: Tensor) -> Tensor:
        # Same layout change for targets and noise: swap the trailing two
        # dims for multi-output models, otherwise append an output dim.
        if is_multi_output:
            return stored.transpose(-1, -2)
        return stored.unsqueeze(-1)

    targets = _outputs_last(self.train_targets)
    if not isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
        # Only fixed-noise likelihoods carry per-observation noise.
        return targets, None
    return targets, _outputs_last(self.likelihood.noise_covar.noise)
| 304 | + |
def _restore_targets_and_noise(
    self, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
    r"""Write targets (and fixed noise, if any) back onto the model.

    Inverse of the layout change applied when extracting: multi-output
    models (`num_outputs > 1`) get their last two dimensions swapped back,
    single-output models have the trailing dimension squeezed away.

    Args:
        Y: Targets tensor in shape `[batch_shape] x n x m`.
        Yvar: Optional noise variance tensor in shape
            `[batch_shape] x n x m`. Ignored unless the likelihood is a
            `FixedNoiseGaussianLikelihood`.
        strict: Forwarded to `set_train_data` (whether to strictly enforce
            shape constraints).
    """
    is_multi_output = self.num_outputs > 1

    def _model_layout(outputs_last: Tensor) -> Tensor:
        if is_multi_output:
            return outputs_last.transpose(-1, -2)
        return outputs_last.squeeze(-1)

    targets = _model_layout(Y)

    has_fixed_noise = Yvar is not None and isinstance(
        self.likelihood, FixedNoiseGaussianLikelihood
    )
    if has_fixed_noise:
        # The noise is updated before the targets, matching the order in
        # which the two pieces of state were originally written.
        self.likelihood.noise_covar.noise = _model_layout(Yvar)

    self.set_train_data(targets=targets, strict=strict)
| 331 | + |
def load_state_dict(
    self,
    state_dict: Mapping[str, Any],
    strict: bool = True,
    keep_transforms: bool = True,
) -> None:
    r"""Load the model state.

    With `keep_transforms=True` the current training targets are first
    mapped back to the untransformed space using the outcome transform as it
    is *before* loading. After the state dict is loaded, the input and
    outcome transforms are put in eval mode and the untransformed targets
    are pushed through the freshly loaded outcome transform and written back
    to the model.

    If the outcome transform raises `NotImplementedError` from
    `untransform`, a warning is emitted and the state dict is loaded as if
    `keep_transforms=False`.

    Args:
        state_dict: A dict containing the state of the model.
        strict: A boolean indicating whether to strictly enforce that the
            keys in `state_dict` match the keys of this model's state. It is
            also forwarded to `_restore_targets_and_noise` when the
            re-transformed targets are written back.
        keep_transforms: A boolean indicating whether to keep the input and
            outcome transforms. Doing so is useful when loading a model that
            was trained on a full set of data, and is later loaded with a
            subset of the data.
    """
    if not keep_transforms:
        super().load_state_dict(state_dict, strict)
        return

    # Re-transforming only makes sense when there are training targets to
    # re-transform AND an outcome transform to apply. A missing attribute
    # and a `None` value are treated alike.
    should_outcome_transform = (
        getattr(self, "train_targets", None) is not None
        and getattr(self, "outcome_transform", None) is not None
    )

    untransformed_Y: Tensor | None = None
    untransformed_Yvar: Tensor | None = None
    X: Tensor | None = None

    if should_outcome_transform:
        # Only touch the training data when it is actually needed, so that
        # models without an outcome transform (or without data) fall
        # straight through to a plain load.
        with torch.no_grad():
            untransformed_Y, untransformed_Yvar = self._extract_targets_and_noise()
            X = self.train_inputs[0]
            try:
                untransformed_Y, untransformed_Yvar = (
                    self.outcome_transform.untransform(
                        Y=untransformed_Y,
                        Yvar=untransformed_Yvar,
                        X=X,
                    )
                )
            except NotImplementedError:
                # Adjacent string literals are concatenated verbatim, so
                # each sentence needs its own trailing space.
                warnings.warn(
                    "Outcome transform does not support untransforming. "
                    "Cannot load the state dict with transforms preserved. "
                    "Setting keep_transforms=False.",
                    stacklevel=3,
                )
                super().load_state_dict(state_dict, strict)
                return

    super().load_state_dict(state_dict, strict)

    # Eval mode keeps the loaded transform state from being re-learned on
    # the current data.
    if getattr(self, "input_transform", None) is not None:
        self.input_transform.eval()

    if should_outcome_transform:
        self.outcome_transform.eval()
        retransformed_Y, retransformed_Yvar = self.outcome_transform(
            Y=untransformed_Y, Yvar=untransformed_Yvar, X=X
        )
        self._restore_targets_and_noise(retransformed_Y, retransformed_Yvar, strict)
| 390 | + |
286 | 391 |
|
287 | 392 | # pyre-fixme[13]: uninitialized attributes _num_outputs, _input_batch_shape, |
288 | 393 | # _aug_batch_shape |
@@ -803,6 +908,38 @@ class MultiTaskGPyTorchModel(GPyTorchModel, ABC): |
803 | 908 | "long-format" multi-task GP in the style of `MultiTaskGP`. |
804 | 909 | """ |
805 | 910 |
|
def _extract_targets_and_noise(self) -> tuple[Tensor, Tensor | None]:
    r"""Read targets and fixed noise off a multi-task model.

    Both tensors get a trailing singleton dimension appended, regardless of
    the number of outputs.

    Returns:
        A tuple `(Y, Yvar)`, each intended to have shape
        `[batch_shape] x n x m`, with `batch_shape` present only if the
        training data initially contained it. `Yvar` is `None` unless the
        likelihood is a `FixedNoiseGaussianLikelihood`.
    """
    targets = self.train_targets.unsqueeze(-1)
    if not isinstance(self.likelihood, FixedNoiseGaussianLikelihood):
        # Only fixed-noise likelihoods carry per-observation noise.
        return targets, None
    return targets, self.likelihood.noise_covar.noise.unsqueeze(-1)
| 923 | + |
def _restore_targets_and_noise(
    self, Y: Tensor, Yvar: Tensor | None, strict: bool
) -> None:
    r"""Write targets and fixed noise back onto a multi-task model.

    The trailing dimension is squeezed from both tensors before they are
    stored.

    Args:
        Y: Targets tensor in shape `[batch_shape] x n x m`.
        Yvar: Optional noise variance tensor in shape
            `[batch_shape] x n x m`. Ignored unless the likelihood is a
            `FixedNoiseGaussianLikelihood`.
        strict: Forwarded to `set_train_data` (whether to strictly enforce
            shape constraints).
    """
    targets = Y.squeeze(-1)

    has_fixed_noise = Yvar is not None and isinstance(
        self.likelihood, FixedNoiseGaussianLikelihood
    )
    if has_fixed_noise:
        # Noise is updated before the targets are handed to set_train_data.
        self.likelihood.noise_covar.noise = Yvar.squeeze(-1)

    self.set_train_data(targets=targets, strict=strict)
| 942 | + |
806 | 943 | def _apply_noise( |
807 | 944 | self, |
808 | 945 | X: Tensor, |
|
0 commit comments