Skip to content

Make (Log)NoisyExpectedImprovement create a correct fantasy model with non-default SingleTaskGP #2414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
69 changes: 52 additions & 17 deletions botorch/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
import math

from abc import ABC

from contextlib import nullcontext
from copy import deepcopy

from typing import Dict, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -617,15 +615,17 @@ def __init__(
r"""Single-outcome Noisy Log Expected Improvement (via fantasies).

Args:
model: A fitted single-outcome model.
model: A fitted single-outcome model. Only `SingleTaskGP` models with
known observation noise are currently supported.
X_observed: A `n x d` Tensor of observed points that are likely to
be the best observed points so far.
num_fantasies: The number of fantasies to generate. The higher this
number the more accurate the model (at the expense of model
complexity and performance).
maximize: If True, consider the problem a maximization problem.
"""
# sample fantasies
_check_noisy_ei_model(model=model)
# Sample fantasies.
from botorch.sampling.normal import SobolQMCNormalSampler

# Drop gradients from model.posterior if X_observed does not require gradients
Expand Down Expand Up @@ -699,16 +699,18 @@ def __init__(
r"""Single-outcome Noisy Expected Improvement (via fantasies).

Args:
model: A fitted single-outcome model.
model: A fitted single-outcome model. Only `SingleTaskGP` models with
known observation noise are currently supported.
X_observed: A `n x d` Tensor of observed points that are likely to
be the best observed points so far.
num_fantasies: The number of fantasies to generate. The higher this
number the more accurate the model (at the expense of model
complexity and performance).
maximize: If True, consider the problem a maximization problem.
"""
_check_noisy_ei_model(model=model)
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
# sample fantasies
# Sample fantasies.
from botorch.sampling.normal import SobolQMCNormalSampler

# Drop gradients from model.posterior if X_observed does not require gradients
Expand Down Expand Up @@ -1055,6 +1057,21 @@ def logerfcx(x: Tensor) -> Tensor:
return torch.log(torch.special.erfcx(a * u) * u.abs()) + b


def _check_noisy_ei_model(model: GPyTorchModel) -> None:
    r"""Validate that a model can be used for fantasy-based NEI & LogNEI.

    The checks run in a fixed order (model type, likelihood type, number of
    outputs) and the first failing check determines the error message.

    Args:
        model: The model to validate.

    Raises:
        UnsupportedError: If `model` is not a `SingleTaskGP`, if its likelihood
            is not a `FixedNoiseGaussianLikelihood`, or if `model.num_outputs`
            is not 1.
    """
    prefix = (
        "Only single-output `SingleTaskGP` models with known observation noise "
        "are currently supported for fantasy-based NEI & LogNEI."
    )
    # Determine the first violated requirement, if any; later checks are only
    # evaluated when the earlier ones pass.
    if not isinstance(model, SingleTaskGP):
        reason = "Model is not a `SingleTaskGP`."
    elif not isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
        reason = "Model likelihood is not a `FixedNoiseGaussianLikelihood`."
    elif model.num_outputs != 1:
        reason = f"Model has {model.num_outputs} outputs."
    else:
        return
    raise UnsupportedError(f"{prefix} {reason}")


def _get_noiseless_fantasy_model(
model: SingleTaskGP, batch_X_observed: Tensor, Y_fantasized: Tensor
) -> SingleTaskGP:
Expand All @@ -1073,31 +1090,49 @@ def _get_noiseless_fantasy_model(
Returns:
The fantasy model.
"""
if not isinstance(model, SingleTaskGP) or not isinstance(
model.likelihood, FixedNoiseGaussianLikelihood
):
raise UnsupportedError(
"Only SingleTaskGP models with known observation noise "
"are currently supported for fantasy-based NEI & LogNEI."
)
# initialize a copy of SingleTaskGP on the original training inputs
# this makes SingleTaskGP a non-batch GP, so that the same hyperparameters
# are used across all batches (by default, a GP with batched training data
# uses independent hyperparameters for each batch).

# Don't apply `outcome_transform` and `input_transform` here,
# since the data being passed has already been transformed.
# So we will instead set them afterwards.
fantasy_model = SingleTaskGP(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change would allow for supporting more model types. However, @saitcakmak had a good question: Why do we need _get_noiseless_fantasy_model at all? Can we use the fantasize method on the model instead? I'm a bit afraid of the change I'm suggesting, since this instantiation logic won't be right for every model.

Suggested change
fantasy_model = SingleTaskGP(
fantasy_model = cls(model)(

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it'd be a great simplification (& removal of duplicate logic) if we can simply use model.fantasize(...) rather than defining a custom _get_noiseless_fantasy_model.

Copy link
Contributor Author

@71c 71c Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you mean to write model.__class__(...) or type(model)(...); cls(model) is not a thing in Python.
But anyway, how about we skip that change for now? The model is currently assumed to be a SingleTaskGP anyway, and if we later find a way to use fantasize for a wider variety of models, we can switch to that then.

train_X=model.train_inputs[0],
train_Y=model.train_targets.unsqueeze(-1),
train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
covar_module=deepcopy(model.covar_module),
mean_module=deepcopy(model.mean_module),
)

Yvar = torch.full_like(Y_fantasized, 1e-7)

# Set the outcome and input transforms of the fantasy model.
# The transforms should already be in eval mode but just set them to be sure
outcome_transform = getattr(model, "outcome_transform", None)
if outcome_transform is not None:
outcome_transform = deepcopy(outcome_transform).eval()
fantasy_model.outcome_transform = outcome_transform
# Need to transform the outcome just as in the SingleTaskGP constructor.
# Need to unsqueeze for BoTorch and then squeeze again for GPyTorch.
# Not transforming Yvar because 1e-7 is already close to 0 and it is a
# relative, not absolute, value.
Y_fantasized, _ = outcome_transform(
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
)
Y_fantasized = Y_fantasized.squeeze(-1)
input_transform = getattr(model, "input_transform", None)
if input_transform is not None:
fantasy_model.input_transform = deepcopy(input_transform).eval()

# update training inputs/targets to be batch mode fantasies
fantasy_model.set_train_data(
inputs=batch_X_observed, targets=Y_fantasized, strict=False
)
# use noiseless fantasies
fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
# load hyperparameters from original model
state_dict = deepcopy(model.state_dict())
fantasy_model.load_state_dict(state_dict)
fantasy_model.likelihood.noise_covar.noise = Yvar

return fantasy_model


Expand Down
151 changes: 132 additions & 19 deletions test/acquisition/test_analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import math
from warnings import catch_warnings, simplefilter

import torch
from botorch.acquisition import qAnalyticProbabilityOfImprovement
from botorch.acquisition.analytic import (
_check_noisy_ei_model,
_compute_log_prob_feas,
_ei_helper,
_log_ei_helper,
Expand All @@ -33,11 +35,19 @@
)
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import NumericsWarning
from botorch.models import SingleTaskGP
from botorch.models import ModelListGP, SingleTaskGP
from botorch.models.transforms import ChainedOutcomeTransform, Normalize, Standardize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.pathwise.utils import get_train_inputs
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods.gaussian_likelihood import (
FixedNoiseGaussianLikelihood,
GaussianLikelihood,
)
from gpytorch.module import Module
from gpytorch.priors.torch_priors import GammaPrior


NEI_NOISE = [
Expand Down Expand Up @@ -831,7 +841,15 @@ def _test_constrained_expected_improvement_batch(self, dtype: torch.dtype) -> No


class TestNoisyExpectedImprovement(BotorchTestCase):
def _get_model(self, dtype=torch.float):
def _get_model(
self,
dtype=torch.float,
outcome_transform=None,
input_transform=None,
low_x=0.0,
hi_x=1.0,
covar_module=None,
) -> SingleTaskGP:
state_dict = {
"mean_module.raw_constant": torch.tensor([-0.0066]),
"covar_module.raw_outputscale": torch.tensor(1.0143),
Expand All @@ -843,20 +861,31 @@ def _get_model(self, dtype=torch.float):
"covar_module.outputscale_prior.concentration": torch.tensor(2.0),
"covar_module.outputscale_prior.rate": torch.tensor(0.1500),
}
train_x = torch.linspace(0, 1, 10, device=self.device, dtype=dtype).unsqueeze(
-1
)
train_x = torch.linspace(
0.0, 1.0, 10, device=self.device, dtype=dtype
).unsqueeze(-1)
# Taking the sin of the *transformed* input to make the test equivalent
# to when there are no input transforms
train_y = torch.sin(train_x * (2 * math.pi))
# Now transform the input to be passed into SingleTaskGP constructor
train_x = train_x * (hi_x - low_x) + low_x
noise = torch.tensor(NEI_NOISE, device=self.device, dtype=dtype)
train_y += noise
train_yvar = torch.full_like(train_y, 0.25**2)
model = SingleTaskGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar)
model.load_state_dict(state_dict)
model = SingleTaskGP(
train_X=train_x,
train_Y=train_y,
train_Yvar=train_yvar,
outcome_transform=outcome_transform,
input_transform=input_transform,
covar_module=covar_module,
)
model.load_state_dict(state_dict, strict=False)
model.to(train_x)
model.eval()
return model

def test_noisy_expected_improvement(self):
def test_noisy_expected_improvement(self) -> None:
model = self._get_model(dtype=torch.float64)
X_observed = model.train_inputs[0]
nfan = 5
Expand All @@ -865,14 +894,75 @@ def test_noisy_expected_improvement(self):
):
NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)

for dtype in (torch.float, torch.double):
# Same as the default Matern kernel
# botorch.models.utils.gpytorch_modules.get_matern_kernel_with_gamma_prior,
# except RBFKernel is used instead of MaternKernel.
# For some reason, RBF gives numerical problems with torch.float but
# Matern does not. Therefore, we'll skip the test for RBF when dtype is
# torch.float.
covar_module_2 = ScaleKernel(
base_kernel=RBFKernel(
ard_num_dims=1,
batch_shape=torch.Size(),
lengthscale_prior=GammaPrior(3.0, 6.0),
),
batch_shape=torch.Size(),
outputscale_prior=GammaPrior(2.0, 0.15),
)
for dtype, use_octf, use_intf, bounds, covar_module in itertools.product(
(torch.float, torch.double),
(False, True),
(False, True),
(torch.tensor([[-3.4], [0.8]]), torch.tensor([[0.0], [1.0]])),
(None, covar_module_2),
):
with catch_warnings():
simplefilter("ignore", category=NumericsWarning)
self._test_noisy_expected_imrpovement(dtype)
self._test_noisy_expected_improvement(
dtype=dtype,
use_octf=use_octf,
use_intf=use_intf,
bounds=bounds,
covar_module=covar_module,
)

def _test_noisy_expected_improvement(
self,
dtype: torch.dtype,
use_octf: bool,
use_intf: bool,
bounds: torch.Tensor,
covar_module: Module,
) -> None:
if covar_module is not None and dtype == torch.float:
# Skip this test because RBF runs into numerical problems with float
# precision
return
octf = (
ChainedOutcomeTransform(standardize=Standardize(m=1)) if use_octf else None
)
intf = (
Normalize(
d=1,
bounds=bounds.to(device=self.device, dtype=dtype),
transform_on_train=True,
)
if use_intf
else None
)
low_x = bounds[0].item() if use_intf else 0.0
hi_x = bounds[1].item() if use_intf else 1.0
model = self._get_model(
dtype=dtype,
outcome_transform=octf,
input_transform=intf,
low_x=low_x,
hi_x=hi_x,
covar_module=covar_module,
)
# Make sure to get the non-transformed training inputs.
X_observed = get_train_inputs(model, transformed=False)[0]

def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
model = self._get_model(dtype=dtype)
X_observed = model.train_inputs[0]
nfan = 5
nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
LogNEI = LogNoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
Expand All @@ -881,6 +971,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
self.assertTrue(hasattr(LogNEI, "best_f"))
self.assertIsInstance(LogNEI.model, SingleTaskGP)
self.assertIsInstance(LogNEI.model.likelihood, FixedNoiseGaussianLikelihood)
# Make sure _get_noiseless_fantasy_model gives them
# the same state_dict
self.assertEqual(LogNEI.model.state_dict(), model.state_dict())

LogNEI.model = nEI.model # let the two share their values and fantasies
LogNEI.best_f = nEI.best_f

Expand All @@ -892,9 +986,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
X_test_log = X_test.clone()
X_test.requires_grad = True
X_test_log.requires_grad = True
val = nEI(X_test)

val = nEI(X_test * (hi_x - low_x) + low_x)
# testing logNEI yields the same result (also checks dtype)
log_val = LogNEI(X_test_log)
log_val = LogNEI(X_test_log * (hi_x - low_x) + low_x)
exp_log_val = log_val.exp()
# notably, val[1] is usually zero in this test, which is precisely what
# gives rise to problems during optimization, and what logNEI avoids
Expand All @@ -916,7 +1011,7 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
# testing gradient through exp of log computation
exp_log_val.sum().backward()
# testing that first gradient element coincides. The second is in the
# regime where the naive implementation looses accuracy.
# regime where the naive implementation loses accuracy.
atol = 2e-5 if dtype == torch.float32 else 1e-12
rtol = atol
self.assertAllClose(X_test.grad[0], X_test_log.grad[0], atol=atol, rtol=rtol)
Expand Down Expand Up @@ -945,9 +1040,27 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
acqf = constructor(model, X_observed, num_fantasies=5)
self.assertTrue(acqf.best_f.requires_grad)

def test_check_noisy_ei_model(self) -> None:
    """Check that `_check_noisy_ei_model` rejects each unsupported model kind.

    Three cases are exercised in order: a two-output `SingleTaskGP`, a
    `ModelListGP` wrapper, and a `SingleTaskGP` whose likelihood has been
    replaced by a plain `GaussianLikelihood`.
    """
    tkwargs = {"dtype": torch.double, "device": self.device}

    # Case 1: a fixed-noise SingleTaskGP with two outputs passes the type and
    # likelihood checks but must fail the output-count check.
    two_output_gp = SingleTaskGP(
        train_X=torch.rand(5, 2, **tkwargs),
        train_Y=torch.rand(5, 2, **tkwargs),
        train_Yvar=torch.rand(5, 2, **tkwargs),
    )
    with self.assertRaisesRegex(UnsupportedError, "Model has 2 outputs"):
        _check_noisy_ei_model(model=two_output_gp)

    # Case 2: anything that is not a SingleTaskGP is rejected outright.
    with self.assertRaisesRegex(UnsupportedError, "Model is not"):
        _check_noisy_ei_model(model=ModelListGP(two_output_gp))

    # Case 3: swap in a likelihood that is not FixedNoiseGaussianLikelihood;
    # this check fires before the output-count check.
    two_output_gp.likelihood = GaussianLikelihood()
    with self.assertRaisesRegex(UnsupportedError, "Model likelihood is not"):
        _check_noisy_ei_model(model=two_output_gp)


class TestScalarizedPosteriorMean(BotorchTestCase):
def test_scalarized_posterior_mean(self):
def test_scalarized_posterior_mean(self) -> None:
for dtype in (torch.float, torch.double):
mean = torch.tensor([[0.25], [0.5]], device=self.device, dtype=dtype)
mm = MockModel(MockPosterior(mean=mean))
Expand All @@ -959,7 +1072,7 @@ def test_scalarized_posterior_mean(self):
torch.allclose(pm, (mean.squeeze(-1) * module.weights).sum(dim=-1))
)

def test_scalarized_posterior_mean_batch(self):
def test_scalarized_posterior_mean_batch(self) -> None:
for dtype in (torch.float, torch.double):
mean = torch.tensor(
[[-0.5, 1.0], [0.0, 1.0], [0.5, 1.0]], device=self.device, dtype=dtype
Expand Down
Loading