64 changes: 51 additions & 13 deletions botorch/acquisition/analytic.py
@@ -40,6 +40,7 @@
from botorch.utils.safe_math import log1mexp, logmeanexp
from botorch.utils.transforms import convert_to_target_pre_hook, t_batch_mode_transform
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood

from torch import Tensor
from torch.nn.functional import pad

@@ -617,14 +618,16 @@ def __init__(
r"""Single-outcome Noisy Log Expected Improvement (via fantasies).

Args:
model: A fitted single-outcome model.
model: A fitted single-outcome model. Only SingleTaskGP models with
known observation noise are currently supported.
X_observed: A `n x d` Tensor of observed points that are likely to
be the best observed points so far.
num_fantasies: The number of fantasies to generate. The higher this
number the more accurate the model (at the expense of model
complexity and performance).
maximize: If True, consider the problem a maximization problem.
"""
_check_noisyei_model(model)
# sample fantasies
from botorch.sampling.normal import SobolQMCNormalSampler

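For reference, a minimal usage sketch of the fantasy-based acquisition functions on a model that satisfies the new check (illustrative data, not part of this diff; model fitting is omitted for brevity):

```python
import torch

from botorch.acquisition.analytic import LogNoisyExpectedImprovement
from botorch.models import SingleTaskGP

# A SingleTaskGP constructed with train_Yvar has a FixedNoiseGaussianLikelihood,
# i.e. known observation noise, and so passes _check_noisyei_model.
train_X = torch.rand(10, 1, dtype=torch.double)
train_Y = torch.sin(2 * torch.pi * train_X)
train_Yvar = torch.full_like(train_Y, 0.25**2)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)

log_nei = LogNoisyExpectedImprovement(
    model=model, X_observed=train_X, num_fantasies=20
)
candidates = torch.rand(5, 1, 1, dtype=torch.double)  # b x q=1 x d
values = log_nei(candidates)  # shape: torch.Size([5])
```

NoisyExpectedImprovement takes the same arguments; the log variant is the numerically stable replacement, which is why the legacy class emits `legacy_ei_numerics_warning` below.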
@@ -699,14 +702,16 @@ def __init__(
r"""Single-outcome Noisy Expected Improvement (via fantasies).

Args:
model: A fitted single-outcome model.
model: A fitted single-outcome model. Only SingleTaskGP models with
known observation noise are currently supported.
X_observed: A `n x d` Tensor of observed points that are likely to
be the best observed points so far.
num_fantasies: The number of fantasies to generate. The higher this
number the more accurate the model (at the expense of model
complexity and performance).
maximize: If True, consider the problem a maximization problem.
"""
_check_noisyei_model(model)
legacy_ei_numerics_warning(legacy_name=type(self).__name__)
# sample fantasies
from botorch.sampling.normal import SobolQMCNormalSampler
@@ -1055,6 +1060,21 @@ def logerfcx(x: Tensor) -> Tensor:
return torch.log(torch.special.erfcx(a * u) * u.abs()) + b


def _check_noisyei_model(model: GPyTorchModel) -> None:
message = (
"Only single-output SingleTaskGP models with known observation noise "
"are currently supported for fantasy-based NEI & LogNEI."
)
if not isinstance(model, SingleTaskGP):
raise UnsupportedError(f"{message} Model is not a SingleTaskGP.")
if not isinstance(model.likelihood, FixedNoiseGaussianLikelihood):
raise UnsupportedError(
f"{message} Model likelihood is not a FixedNoiseGaussianLikelihood."
)
if model.num_outputs != 1:
raise UnsupportedError(f"{message} Model has {model.num_outputs} outputs.")
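As a quick illustration of the guard's behavior, a sketch (illustrative data, not part of this diff; `_check_noisyei_model` is module-private, so the call assumes we are inside `botorch.acquisition.analytic` or have imported it explicitly):

```python
import torch

from botorch.exceptions import UnsupportedError
from botorch.models import SingleTaskGP

train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

# No train_Yvar, so the likelihood is an inferred-noise GaussianLikelihood
# rather than a FixedNoiseGaussianLikelihood, and the check rejects the model.
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
try:
    _check_noisyei_model(model)
except UnsupportedError as e:
    print(e)  # "... Model likelihood is not a FixedNoiseGaussianLikelihood."
```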


def _get_noiseless_fantasy_model(
model: SingleTaskGP, batch_X_observed: Tensor, Y_fantasized: Tensor
) -> SingleTaskGP:
@@ -1073,31 +1093,49 @@ def _get_noiseless_fantasy_model(
Returns:
The fantasy model.
"""
if not isinstance(model, SingleTaskGP) or not isinstance(
model.likelihood, FixedNoiseGaussianLikelihood
):
raise UnsupportedError(
"Only SingleTaskGP models with known observation noise "
"are currently supported for fantasy-based NEI & LogNEI."
)
# initialize a copy of SingleTaskGP on the original training inputs
# this makes SingleTaskGP a non-batch GP, so that the same hyperparameters
# are used across all batches (by default, a GP with batched training data
# uses independent hyperparameters for each batch).

# Don't apply `outcome_transform` and `input_transform` here,
# since the data being passed has already been transformed.
# So we will instead set them afterwards.
fantasy_model = SingleTaskGP(
train_X=model.train_inputs[0],
train_Y=model.train_targets.unsqueeze(-1),
train_Yvar=model.likelihood.noise_covar.noise.unsqueeze(-1),
covar_module=deepcopy(model.covar_module),
mean_module=deepcopy(model.mean_module),
)

Contributor:
This change would allow for supporting more model types. However, @saitcakmak had a good question: Why do we need _get_noiseless_fantasy_model at all? Can we use the fantasize method on the model instead? I'm a bit afraid of the change I'm suggesting, since this instantiation logic won't be right for every model.

Suggested change:
fantasy_model = SingleTaskGP(
fantasy_model = cls(model)(

Contributor:
Yeah, it'd be a great simplification (& removal of duplicate logic) if we can simply use model.fantasize(...) rather than defining a custom _get_noiseless_fantasy_model.

Contributor Author (@71c, Jul 16, 2024):
I think you mean to write model.__class__(...) or type(model)(...); cls(model) is not a thing in Python. But anyway, how about just not making that change because the model is currently assumed to be a SingleTaskGP anyway, and if we can find a way to use fantasize for a wider variety of models later, then just do that.

Yvar = torch.full_like(Y_fantasized, 1e-7)

# Set the outcome and input transforms of the fantasy model.
# The transforms should already be in eval mode, but call eval() again just to be sure.
outcome_transform = getattr(model, "outcome_transform", None)
if outcome_transform is not None:
outcome_transform = deepcopy(outcome_transform).eval()
fantasy_model.outcome_transform = outcome_transform
# Need to transform the outcome just as in the SingleTaskGP constructor.
# Need to unsqueeze for BoTorch and then squeeze again for GPyTorch.
# Not transforming Yvar because 1e-7 is already close to 0 and it is a
# relative, not absolute, value.
Y_fantasized, _ = outcome_transform(
Y_fantasized.unsqueeze(-1), Yvar.unsqueeze(-1)
)
Y_fantasized = Y_fantasized.squeeze(-1)
input_transform = getattr(model, "input_transform", None)
if input_transform is not None:
fantasy_model.input_transform = deepcopy(input_transform).eval()

# update training inputs/targets to be batch mode fantasies
fantasy_model.set_train_data(
inputs=batch_X_observed, targets=Y_fantasized, strict=False
)
# use noiseless fantasies
fantasy_model.likelihood.noise_covar.noise = torch.full_like(Y_fantasized, 1e-7)
# load hyperparameters from original model
state_dict = deepcopy(model.state_dict())
fantasy_model.load_state_dict(state_dict)
fantasy_model.likelihood.noise_covar.noise = Yvar

return fantasy_model
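For comparison with the review discussion above, a rough sketch of what a fantasize-based path might look like. This is an assumption, not something this PR implements; note that fantasize handles observation noise differently (it does not condition on the near-noiseless fantasies used above), so it would not be a drop-in replacement without further changes:

```python
import torch

from botorch.models import SingleTaskGP
from botorch.sampling.normal import SobolQMCNormalSampler

# Hypothetical alternative discussed in the review thread; not part of this PR.
train_X = torch.rand(10, 1, dtype=torch.double)
train_Y = torch.sin(2 * torch.pi * train_X)
train_Yvar = torch.full_like(train_Y, 0.25**2)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)

num_fantasies = 20
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
# fantasize() draws num_fantasies fantasy observations at the given points and
# conditions the model on them, returning a batched fantasy model that shares
# the original hyperparameter values across the fantasy batch.
fantasy_model = model.fantasize(train_X, sampler=sampler)
```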


119 changes: 105 additions & 14 deletions test/acquisition/test_analytic.py
@@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import math
from warnings import catch_warnings, simplefilter

@@ -34,10 +35,15 @@
from botorch.exceptions import UnsupportedError
from botorch.exceptions.warnings import NumericsWarning
from botorch.models import SingleTaskGP
from botorch.models.transforms import ChainedOutcomeTransform, Normalize, Standardize
from botorch.posteriors import GPyTorchPosterior
from botorch.sampling.pathwise.utils import get_train_inputs
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood
from gpytorch.module import Module
from gpytorch.priors.torch_priors import GammaPrior


NEI_NOISE = [
@@ -831,7 +837,15 @@ def _test_constrained_expected_improvement_batch(self, dtype: torch.dtype) -> None:


class TestNoisyExpectedImprovement(BotorchTestCase):
def _get_model(self, dtype=torch.float):
def _get_model(
self,
dtype=torch.float,
outcome_transform=None,
input_transform=None,
low_x=0.0,
hi_x=1.0,
covar_module=None,
):
state_dict = {
"mean_module.raw_constant": torch.tensor([-0.0066]),
"covar_module.raw_outputscale": torch.tensor(1.0143),
@@ -843,15 +857,26 @@ def _get_model(self, dtype=torch.float):
"covar_module.outputscale_prior.concentration": torch.tensor(2.0),
"covar_module.outputscale_prior.rate": torch.tensor(0.1500),
}
train_x = torch.linspace(0, 1, 10, device=self.device, dtype=dtype).unsqueeze(
-1
)
train_x = torch.linspace(
0.0, 1.0, 10, device=self.device, dtype=dtype
).unsqueeze(-1)
# Taking the sin of the *transformed* input to make the test equivalent
# to when there are no input transforms
train_y = torch.sin(train_x * (2 * math.pi))
# Now transform the input to be passed into SingleTaskGP constructor
train_x = train_x * (hi_x - low_x) + low_x
noise = torch.tensor(NEI_NOISE, device=self.device, dtype=dtype)
train_y += noise
train_yvar = torch.full_like(train_y, 0.25**2)
model = SingleTaskGP(train_X=train_x, train_Y=train_y, train_Yvar=train_yvar)
model.load_state_dict(state_dict)
model = SingleTaskGP(
train_X=train_x,
train_Y=train_y,
train_Yvar=train_yvar,
outcome_transform=outcome_transform,
input_transform=input_transform,
covar_module=covar_module,
)
model.load_state_dict(state_dict, strict=False)
model.to(train_x)
model.eval()
return model
@@ -865,14 +890,75 @@ def test_noisy_expected_improvement(self):
):
NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)

for dtype in (torch.float, torch.double):
# Same as the default Matern kernel from
# botorch.models.utils.gpytorch_modules.get_matern_kernel_with_gamma_prior,
# except that RBFKernel is used instead of MaternKernel.
# For some reason, RBF gives numerical problems with torch.float but
# Matern does not. Therefore, we'll skip the test for RBF when dtype is
# torch.float.
covar_module_2 = ScaleKernel(
base_kernel=RBFKernel(
ard_num_dims=1,
batch_shape=torch.Size(),
lengthscale_prior=GammaPrior(3.0, 6.0),
),
batch_shape=torch.Size(),
outputscale_prior=GammaPrior(2.0, 0.15),
)
for dtype, use_octf, use_intf, bounds, covar_module in itertools.product(
(torch.float, torch.double),
(False, True),
(False, True),
(torch.tensor([[-3.4], [0.8]]), torch.tensor([[0.0], [1.0]])),
(None, covar_module_2),
):
with catch_warnings():
simplefilter("ignore", category=NumericsWarning)
self._test_noisy_expected_imrpovement(dtype)
self._test_noisy_expected_improvement(
dtype=dtype,
use_octf=use_octf,
use_intf=use_intf,
bounds=bounds,
covar_module=covar_module,
)

def _test_noisy_expected_improvement(
self,
dtype: torch.dtype,
use_octf: bool,
use_intf: bool,
bounds: torch.Tensor,
covar_module: Module,
) -> None:
if covar_module is not None and dtype == torch.float:
# Skip this test because RBF runs into numerical problems with float
# precision
return
octf = (
ChainedOutcomeTransform(standardize=Standardize(m=1)) if use_octf else None
)
intf = (
Normalize(
d=1,
bounds=bounds.to(device=self.device, dtype=dtype),
transform_on_train=True,
)
if use_intf
else None
)
low_x = bounds[0].item() if use_intf else 0.0
hi_x = bounds[1].item() if use_intf else 1.0
model = self._get_model(
dtype=dtype,
outcome_transform=octf,
input_transform=intf,
low_x=low_x,
hi_x=hi_x,
covar_module=covar_module,
)
# Make sure to get the non-transformed training inputs.
X_observed = get_train_inputs(model, transformed=False)[0]

def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
model = self._get_model(dtype=dtype)
X_observed = model.train_inputs[0]
nfan = 5
nEI = NoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
LogNEI = LogNoisyExpectedImprovement(model, X_observed, num_fantasies=nfan)
@@ -881,6 +967,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
self.assertTrue(hasattr(LogNEI, "best_f"))
self.assertIsInstance(LogNEI.model, SingleTaskGP)
self.assertIsInstance(LogNEI.model.likelihood, FixedNoiseGaussianLikelihood)
# Make sure _get_noiseless_fantasy_model gives the fantasy model
# the same state_dict as the original model.
self.assertEqual(LogNEI.model.state_dict(), model.state_dict())

LogNEI.model = nEI.model # let the two share their values and fantasies
LogNEI.best_f = nEI.best_f

@@ -892,9 +982,10 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
X_test_log = X_test.clone()
X_test.requires_grad = True
X_test_log.requires_grad = True
val = nEI(X_test)

val = nEI(X_test * (hi_x - low_x) + low_x)
# testing logNEI yields the same result (also checks dtype)
log_val = LogNEI(X_test_log)
log_val = LogNEI(X_test_log * (hi_x - low_x) + low_x)
exp_log_val = log_val.exp()
# notably, val[1] is usually zero in this test, which is precisely what
# gives rise to problems during optimization, and what logNEI avoids
Expand All @@ -916,7 +1007,7 @@ def _test_noisy_expected_imrpovement(self, dtype: torch.dtype) -> None:
# testing gradient through exp of log computation
exp_log_val.sum().backward()
# testing that first gradient element coincides. The second is in the
# regime where the naive implementation looses accuracy.
# regime where the naive implementation loses accuracy.
atol = 2e-5 if dtype == torch.float32 else 1e-12
rtol = atol
self.assertAllClose(X_test.grad[0], X_test_log.grad[0], atol=atol, rtol=rtol)