|
31 | 31 | """ |
32 | 32 |
|
33 | 33 | import math |
34 | | -from abc import abstractmethod |
| 34 | +from abc import ABC, abstractmethod |
35 | 35 | from collections.abc import Mapping |
36 | 36 | from typing import Any |
37 | 37 |
|
@@ -311,14 +311,13 @@ def load_mcmc_samples( |
311 | 311 | return mean_module, covar_module, likelihood |
312 | 312 |
|
313 | 313 |
|
314 | | -class SaasFullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel): |
315 | | - r"""A fully Bayesian single-task GP model with the SAAS prior. |
| 314 | +class FullyBayesianSingleTaskGP(ExactGP, BatchedMultiOutputGPyTorchModel, ABC): |
| 315 | + r"""An abstract fully Bayesian single-task GP model. |
316 | 316 |
|
317 | 317 | This model assumes that the inputs have been normalized to [0, 1]^d and that |
318 | 318 | the output has been standardized to have zero mean and unit variance. You can |
319 | 319 | either normalize and standardize the data before constructing the model or use |
320 | | - an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_ |
321 | | - with a Matern-5/2 kernel is used by default. |
| 320 | + an `input_transform` and `outcome_transform`. |
322 | 321 |
|
323 | 322 | You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it |
324 | 323 | isn't compatible with `fit_gpytorch_mll`. |
@@ -412,17 +411,9 @@ def _check_if_fitted(self): |
412 | 411 | ) |
413 | 412 |
|
414 | 413 | @property |
415 | | - def median_lengthscale(self) -> Tensor: |
416 | | - r"""Median lengthscales across the MCMC samples.""" |
417 | | - self._check_if_fitted() |
418 | | - lengthscale = self.covar_module.base_kernel.lengthscale.clone() |
419 | | - return lengthscale.median(0).values.squeeze(0) |
420 | | - |
421 | | - @property |
| 414 | + @abstractmethod |
422 | 415 | def num_mcmc_samples(self) -> int: |
423 | 416 | r"""Number of MCMC samples in the model.""" |
424 | | - self._check_if_fitted() |
425 | | - return len(self.covar_module.outputscale) |
426 | 417 |
|
427 | 418 | @property |
428 | 419 | def batch_shape(self) -> torch.Size: |
@@ -459,41 +450,6 @@ def load_mcmc_samples(self, mcmc_samples: dict[str, Tensor]) -> None: |
459 | 450 | self.likelihood, |
460 | 451 | ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples) |
461 | 452 |
|
462 | | - def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): |
463 | | - r"""Custom logic for loading the state dict. |
464 | | -
|
465 | | - The standard approach of calling `load_state_dict` currently doesn't play well |
466 | | - with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module` |
467 | | - and `likelihood` aren't initialized until the model has been fitted. The reason |
468 | | - for this is that we don't know the number of MCMC samples until NUTS is called. |
469 | | - Given the state dict, we can initialize a new model with some dummy samples and |
470 | | - then load the state dict into this model. This currently only works for a |
471 | | - `SaasPyroModel` and supporting more Pyro models likely requires moving the model |
472 | | - construction logic into the Pyro model itself. |
473 | | - """ |
474 | | - |
475 | | - if not isinstance(self.pyro_model, SaasPyroModel): |
476 | | - raise NotImplementedError("load_state_dict only works for SaasPyroModel") |
477 | | - raw_mean = state_dict["mean_module.raw_constant"] |
478 | | - num_mcmc_samples = len(raw_mean) |
479 | | - dim = self.pyro_model.train_X.shape[-1] |
480 | | - tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype} |
481 | | - # Load some dummy samples |
482 | | - mcmc_samples = { |
483 | | - "mean": torch.ones(num_mcmc_samples, **tkwargs), |
484 | | - "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs), |
485 | | - "outputscale": torch.ones(num_mcmc_samples, **tkwargs), |
486 | | - } |
487 | | - if self.pyro_model.train_Yvar is None: |
488 | | - mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs) |
489 | | - ( |
490 | | - self.mean_module, |
491 | | - self.covar_module, |
492 | | - self.likelihood, |
493 | | - ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples) |
494 | | - # Load the actual samples from the state dict |
495 | | - super().load_state_dict(state_dict=state_dict, strict=strict) |
496 | | - |
497 | 453 | def forward(self, X: Tensor) -> MultivariateNormal: |
498 | 454 | """ |
499 | 455 | Unlike in other classes' `forward` methods, there is no `if self.training` |
@@ -579,3 +535,70 @@ def condition_on_observations( |
579 | 535 | X = X.repeat(*(Y.shape[:-2] + (1, 1))) |
580 | 536 |
|
581 | 537 | return super().condition_on_observations(X, Y, **kwargs) |
| 538 | + |
| 539 | + |
| 540 | +class SaasFullyBayesianSingleTaskGP(FullyBayesianSingleTaskGP): |
| 541 | + r"""A fully Bayesian single-task GP model with the SAAS prior. |
| 542 | +
|
| 543 | + This model assumes that the inputs have been normalized to [0, 1]^d and that |
| 544 | + the output has been standardized to have zero mean and unit variance. You can |
| 545 | + either normalize and standardize the data before constructing the model or use |
| 546 | + an `input_transform` and `outcome_transform`. The SAAS model [Eriksson2021saasbo]_ |
| 547 | + with a Matern-5/2 kernel is used by default. |
| 548 | +
|
| 549 | + You are expected to use `fit_fully_bayesian_model_nuts` to fit this model as it |
| 550 | + isn't compatible with `fit_gpytorch_mll`. |
| 551 | +
|
| 552 | + Example: |
| 553 | + >>> saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y) |
| 554 | + >>> fit_fully_bayesian_model_nuts(saas_gp) |
| 555 | + >>> posterior = saas_gp.posterior(test_X) |
| 556 | + """ |
| 557 | + |
| 558 | + @property |
| 559 | + def num_mcmc_samples(self) -> int: |
| 560 | + r"""Number of MCMC samples in the model.""" |
| 561 | + self._check_if_fitted() |
| 562 | + return len(self.covar_module.outputscale) |
| 563 | + |
| 564 | + @property |
| 565 | + def median_lengthscale(self) -> Tensor: |
| 566 | + r"""Median lengthscales across the MCMC samples.""" |
| 567 | + self._check_if_fitted() |
| 568 | + lengthscale = self.covar_module.base_kernel.lengthscale.clone() |
| 569 | + return lengthscale.median(0).values.squeeze(0) |
| 570 | + |
| 571 | + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): |
| 572 | + r"""Custom logic for loading the state dict. |
| 573 | +
|
| 574 | + The standard approach of calling `load_state_dict` currently doesn't play well |
| 575 | + with the `SaasFullyBayesianSingleTaskGP` since the `mean_module`, `covar_module` |
| 576 | + and `likelihood` aren't initialized until the model has been fitted. The reason |
| 577 | + for this is that we don't know the number of MCMC samples until NUTS is called. |
| 578 | + Given the state dict, we can initialize a new model with some dummy samples and |
| 579 | + then load the state dict into this model. This currently only works for a |
| 580 | + `SaasPyroModel` and supporting more Pyro models likely requires moving the model |
| 581 | + construction logic into the Pyro model itself. |
| 582 | + """ |
| 583 | + |
| 584 | + if not isinstance(self.pyro_model, SaasPyroModel): |
| 585 | + raise NotImplementedError("load_state_dict only works for SaasPyroModel") |
| 586 | + raw_mean = state_dict["mean_module.raw_constant"] |
| 587 | + num_mcmc_samples = len(raw_mean) |
| 588 | + dim = self.pyro_model.train_X.shape[-1] |
| 589 | + tkwargs = {"device": raw_mean.device, "dtype": raw_mean.dtype} |
| 590 | + # Load some dummy samples |
| 591 | + mcmc_samples = { |
| 592 | + "mean": torch.ones(num_mcmc_samples, **tkwargs), |
| 593 | + "lengthscale": torch.ones(num_mcmc_samples, dim, **tkwargs), |
| 594 | + "outputscale": torch.ones(num_mcmc_samples, **tkwargs), |
| 595 | + } |
| 596 | + if self.pyro_model.train_Yvar is None: |
| 597 | + mcmc_samples["noise"] = torch.ones(num_mcmc_samples, **tkwargs) |
| 598 | + ( |
| 599 | + self.mean_module, |
| 600 | + self.covar_module, |
| 601 | + self.likelihood, |
| 602 | + ) = self.pyro_model.load_mcmc_samples(mcmc_samples=mcmc_samples) |
| 603 | + # Load the actual samples from the state dict |
| 604 | + super().load_state_dict(state_dict=state_dict, strict=strict) |
0 commit comments