Skip to content

Commit

Permalink
Supporting input_transform in SparseOutlierLikelihood
Browse files Browse the repository at this point in the history
Summary:
Adding support to using the `SparseOutlierLikelihood` in conjunction with input transforms. 

In `eval` model, BoTorch applies input transforms in the posterior call. For this reason, the likelihood will receive un-transformed inputs during training, but transformed inputs during inference. So we need to make sure to store the transformed inputs in the training data cache of `SparseOutlierNoise` for inference comparisons.

Differential Revision: D67605578
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Dec 23, 2024
1 parent 7715ff4 commit 3ca453d
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
16 changes: 16 additions & 0 deletions botorch/models/likelihoods/sparse_outlier_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.model import Model
from botorch.models.relevance_pursuit import RelevancePursuitMixin
from botorch.models.transforms.input import InputTransform
from botorch.utils.constraints import NonTransformedInterval
from gpytorch.distributions import MultivariateNormal
from gpytorch.likelihoods import _GaussianLikelihoodBase
Expand All @@ -28,6 +29,7 @@ def __init__(
self,
base_noise: Noise | FixedGaussianNoise,
dim: int,
input_transform: InputTransform | None,
outlier_indices: list[int] | None = None,
rho_prior: Prior | None = None,
rho_constraint: NonTransformedInterval | None = None,
Expand Down Expand Up @@ -68,6 +70,8 @@ def __init__(
base_noise: The base noise model.
dim: The number of training observations, which determines the maximum
number of data-point-specific noise variances of the noise model.
input_transform: An input transform to be applied to the input data. This
should be the same transform that is used in the Gaussian process model.
outlier_indices: The indices of the outliers.
rho_prior: Prior for `self.noise_covar`'s rho parameter.
rho_constraint: Constraint for `self.noise_covar`'s rho parameter. Needs to
Expand All @@ -82,6 +86,7 @@ def __init__(
noise_covar = SparseOutlierNoise(
base_noise=base_noise,
dim=dim,
input_transform=input_transform,
outlier_indices=outlier_indices,
rho_prior=rho_prior,
rho_constraint=rho_constraint,
Expand Down Expand Up @@ -122,6 +127,7 @@ def __init__(
self,
base_noise: Noise | FixedGaussianNoise,
dim: int,
input_transform: InputTransform | None,
outlier_indices: list[int] | None = None,
rho_prior: Prior | None = None,
rho_constraint: NonTransformedInterval | None = None,
Expand Down Expand Up @@ -155,6 +161,8 @@ def __init__(
base_noise: The base noise model.
dim: The number of training observations, which determines the maximum
number of data-point-specific noise variances of the noise model.
input_transform: An input transform to be applied to the input data. This
should be the same transform that is used in the Gaussian process model.
outlier_indices: The indices of the outliers.
rho_prior: Prior for the rho parameter.
rho_constraint: Constraint for the rho parameter. Needs to be a
Expand Down Expand Up @@ -232,6 +240,7 @@ def _rho_param(m):
# with the rho constraints.
self._convex_parameterization = convex_parameterization
self.loo = loo
self.input_transform = input_transform
self._cached_train_inputs = None

@property
Expand Down Expand Up @@ -401,6 +410,13 @@ def forward(
)
elif self.training or self._cached_train_inputs is None:
apply_robust_variances = True
# NOTE: BoTorch input transforms are applied in the model's `forward`
# in `train` mode and in `posterior` in `eval` mode. For this reason,
# the likelihood will receive un-transformed inputs during training,
# but transformed inputs during inference, so we need to make sure to
# store the transformed inputs in the cache for inference comparisons.
if self.input_transform is not None:
X = self.input_transform.transform(X)
self._cached_train_inputs = X
warning_reason = "" # will not warn when applying robust variances
else:
Expand Down
29 changes: 24 additions & 5 deletions test/models/test_relevance_pursuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations

import itertools
import warnings

from functools import partial

Expand All @@ -26,7 +27,7 @@
get_posterior_over_support,
RelevancePursuitMixin,
)
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.test_functions.base import constant_outlier_generator, CorruptedTestProblem

Expand Down Expand Up @@ -94,6 +95,7 @@ def _get_robust_model(
X: Tensor,
Y: Tensor,
likelihood: SparseOutlierGaussianLikelihood,
input_transform: InputTransform | None,
) -> SingleTaskGP:
min_lengthscale = 0.1
lengthscale_constraint = NonTransformedInterval(
Expand All @@ -113,7 +115,7 @@ def _get_robust_model(
train_Y=Y,
mean_module=ZeroMean(),
covar_module=kernel,
input_transform=Normalize(d=X.shape[-1]),
input_transform=input_transform,
outcome_transform=Standardize(m=Y.shape[-1]),
likelihood=likelihood,
)
Expand Down Expand Up @@ -145,14 +147,17 @@ def _test_robust_gp_end_to_end(
min_noise, max_noise, initial_value=1e-3
)
).to(dtype=dtype, device=self.device)

input_transform = Normalize(d=X.shape[-1])
rp_likelihood = SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=input_transform,
convex_parameterization=convex_parameterization,
)

model = self._get_robust_model(X=X, Y=Y, likelihood=rp_likelihood)
model = self._get_robust_model(
X=X, Y=Y, likelihood=rp_likelihood, input_transform=input_transform
)

X_test = torch.rand(3, 1, dtype=dtype, device=self.device)
with self.assertWarnsRegex(InputDataWarning, "SparseOutlierNoise"):
Expand Down Expand Up @@ -197,6 +202,11 @@ def _test_robust_gp_end_to_end(
undetected_outliers = set(outlier_indices) - set(sparse_module.support)
self.assertEqual(len(undetected_outliers), 0)

# testing that posterior inference on training set does not throw warnings
with warnings.catch_warnings(record=True) as warnings_log:
map_model.posterior(X)
self.assertEqual(warnings_log, [])

def test_robust_relevance_pursuit(self) -> None:
for optimizer, convex_parameterization, dtype in itertools.product(
[forward_relevance_pursuit, backward_relevance_pursuit],
Expand Down Expand Up @@ -249,6 +259,7 @@ def _test_robust_relevance_pursuit(
SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=None,
convex_parameterization=convex_parameterization,
rho_constraint=Interval(0.0, 1.0), # pyre-ignore[6]
)
Expand All @@ -257,6 +268,7 @@ def _test_robust_relevance_pursuit(
SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=None,
convex_parameterization=convex_parameterization,
rho_constraint=NonTransformedInterval(-1.0, 1.0),
)
Expand All @@ -266,6 +278,7 @@ def _test_robust_relevance_pursuit(
SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=None,
convex_parameterization=convex_parameterization,
rho_constraint=NonTransformedInterval(0.0, 2.0),
loo=loo,
Expand All @@ -274,6 +287,7 @@ def _test_robust_relevance_pursuit(
likelihood_with_other_bounds = SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=None,
convex_parameterization=convex_parameterization,
rho_constraint=NonTransformedInterval(0.0, 2.0),
loo=loo,
Expand All @@ -285,6 +299,7 @@ def _test_robust_relevance_pursuit(
rp_likelihood = SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=None,
convex_parameterization=convex_parameterization,
loo=loo,
)
Expand All @@ -303,9 +318,11 @@ def _test_robust_relevance_pursuit(
rp_likelihood.expected_log_prob(target=None, input=None) # pyre-ignore[6]

# testing prior initialization
input_transform = None
likelihood_with_prior = SparseOutlierGaussianLikelihood(
base_noise=base_noise,
dim=X.shape[0],
input_transform=input_transform,
convex_parameterization=convex_parameterization,
rho_prior=gpytorch.priors.NormalPrior(loc=1 / 2, scale=0.1),
loo=loo,
Expand All @@ -316,7 +333,9 @@ def _test_robust_relevance_pursuit(

# combining likelihood with rho prior and full GP model
# this will test the prior code paths when computing the marginal likelihood
model = self._get_robust_model(X=X, Y=Y, likelihood=likelihood_with_prior)
model = self._get_robust_model(
X=X, Y=Y, likelihood=likelihood_with_prior, input_transform=input_transform
)

# testing the _from_model method
with self.assertRaisesRegex(
Expand Down

0 comments on commit 3ca453d

Please sign in to comment.