- 
                Notifications
    You must be signed in to change notification settings 
- Fork 450
Description
What happened?
I've recently been testing a few things for multi-output modeling and stumbled upon some very weird (unexpected?) behavior regarding the `sequential` flag of `optimize_acqf`:
- Even for my very simple toy problem below, I get significantly different results when toggling the flag.
- The runtime difference is tremendous! For my example, `sequential=True` takes roughly 6 s, whereas `sequential=False` runs for about 130 s.
Here are the corresponding plots:
sequential=True
 
sequential=False
 
What is interesting to note here
The achieved acquisition values of the batches (shown in the legend) are roughly identical for both settings, so the two optimization strategies seem to have ended up in two different but equivalent (in terms of acquisition value) local optima. From a pure acquisition-function perspective, this suggests that both solutions are equally good, even though `sequential=True` clearly gives the better qualitative result. Perhaps you can comment on this?
Also, I have no good explanation for the runtime difference. Is this expected? If so, is there a reason why `sequential=False` is the default?
Please provide a minimal, reproducible example of the unexpected behavior.
from time import perf_counter
import gpytorch
import numpy as np
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.acquisition.multi_objective import qLogNoisyExpectedHypervolumeImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.utils.gpytorch_modules import (
    get_gaussian_likelihood_with_gamma_prior,
    get_matern_kernel_with_gamma_prior,
)
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from matplotlib import pyplot as plt
# Global setup: 64-bit default floating-point dtype and a fixed RNG seed so runs
# are reproducible. (The order of these two calls is immaterial.)
torch.set_default_dtype(torch.float64)
torch.manual_seed(1337)

# -------------------------------------------------------------------------------------
# Value forwarded as ``sequential=`` to ``optimize_acqf`` inside ``recommend``.
# Flip it to True to compare the two optimization modes.
SEQUENTIAL = False
# -------------------------------------------------------------------------------------

BATCH_SIZE = 10  # number of candidates (q) requested from each optimize_acqf call
N_TRAINING_DATA = 100  # number of random training points
N_GRID_POINTS = 100  # contour-plot resolution along each input axis

# Maximizers of objective 0 and objective 1, respectively (see ``fun``).
CENTER_Y0 = torch.tensor([-0.5, -0.5])
CENTER_Y1 = torch.tensor([0.5, 0.5])
def fun(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the two-objective toy problem at the points ``x``.

    Objective ``k`` is the negated squared Euclidean distance from each point to
    its center (``CENTER_Y0`` for column 0, ``CENTER_Y1`` for column 1), so each
    objective is maximized at its own center.

    Args:
        x: Input points, one per row (``n x 2`` as used in this script).

    Returns:
        Tensor with one row per input point and one column per objective.
    """
    columns = [-((x - center) ** 2).sum(dim=1) for center in (CENTER_Y0, CENTER_Y1)]
    return torch.stack(columns, dim=1)
def recommend(train_X, train_Y, *, sequential: bool | None = None):
    """Fit a GP to the training data and optimize a batch of candidates.

    A ``SingleTaskGP`` with a constant mean, the Matern-kernel / Gaussian-likelihood
    gamma-prior helpers, normalized inputs and standardized outcomes is fit to
    ``(train_X, train_Y)``. With a single output the batch is chosen by
    ``qLogExpectedImprovement`` (incumbent ``train_Y.max()``); with several outputs
    by ``qLogNoisyExpectedHypervolumeImprovement`` (reference point = per-output
    minimum of ``train_Y``, baseline points = ``train_X``).

    Args:
        train_X: ``n x d`` training inputs; the search space is ``[-1, 1]^d``.
        train_Y: ``n x m`` training targets.
        sequential: Forwarded to ``optimize_acqf``. ``None`` (the default) falls
            back to the module-level ``SEQUENTIAL`` switch, so existing calls are
            unchanged; pass ``True``/``False`` to compare both modes in one run.

    Returns:
        ``(candidates, value)``: the ``BATCH_SIZE x d`` batch returned by
        ``optimize_acqf`` and the acquisition value of that batch as a float.
    """
    if sequential is None:
        sequential = SEQUENTIAL
    d = train_X.shape[-1]
    m = train_Y.shape[-1]

    # For m > 1, SingleTaskGP models the outputs as a batch of m GPs. The custom
    # modules must carry that batch shape to get per-output hyperparameters (as
    # SingleTaskGP does for its own default modules); without it, one set of
    # mean/kernel/noise hyperparameters is broadcast and shared by all outputs.
    # For m == 1 the batch shape is empty, i.e. exactly the previous behavior.
    batch_shape = torch.Size([m]) if m > 1 else torch.Size()
    model = SingleTaskGP(
        train_X,
        train_Y,
        input_transform=Normalize(d=d),
        outcome_transform=Standardize(m=m),
        mean_module=gpytorch.means.ConstantMean(batch_shape=batch_shape),
        covar_module=get_matern_kernel_with_gamma_prior(d, batch_shape=batch_shape),
        likelihood=get_gaussian_likelihood_with_gamma_prior(batch_shape=batch_shape),
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)

    if m == 1:
        acqf = qLogExpectedImprovement(model, train_Y.max())
    else:
        acqf = qLogNoisyExpectedHypervolumeImprovement(
            model, ref_point=train_Y.min(dim=0).values, X_baseline=train_X
        )

    # 2 x d tensor: row 0 holds the lower bounds, row 1 the upper bounds.
    bounds = torch.stack([-torch.ones(d), torch.ones(d)])
    rec, _ = optimize_acqf(
        acqf, bounds, BATCH_SIZE, num_restarts=5, raw_samples=20, sequential=sequential
    )
    # Only the value is reported, so don't build an autograd graph for it.
    with torch.no_grad():
        value = acqf(rec).item()
    return rec, value
# Random training inputs in [-1, 1]^2 and their noise-free objective values.
train_X = torch.rand([N_TRAINING_DATA, 2]) * 2 - 1
train_Y = fun(train_X)

# Recommend a batch for each single objective and for both objectives jointly;
# the printed time covers all three fits/optimizations together.
t = perf_counter()
rec_y0, val_y0 = recommend(train_X, train_Y[:, :1])
rec_y1, val_y1 = recommend(train_X, train_Y[:, 1:])
rec_p, val_p = recommend(train_X, train_Y)
print(perf_counter() - t)

# Objective values at the recommended points (for the output-space plot).
out_y0 = fun(rec_y0)
out_y1 = fun(rec_y1)
out_p = fun(rec_p)

# Evaluate both objectives on a regular grid for the input-space contour plot.
# ``indexing="ij"`` is passed explicitly: calling torch.meshgrid without it emits
# a deprecation warning, and "ij" is its current default behavior.
x0_mesh, x1_mesh = torch.meshgrid(
    torch.linspace(-1.0, 1.0, N_GRID_POINTS),
    torch.linspace(-1.0, 1.0, N_GRID_POINTS),
    indexing="ij",
)
y = fun(torch.stack([x0_mesh.ravel(), x1_mesh.ravel()], dim=1))
y0_mesh = torch.reshape(y[:, 0], x0_mesh.shape)
y1_mesh = torch.reshape(y[:, 1], x1_mesh.shape)

fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel: input space (contours, training data, recommended batches).
plt.sca(axs[0])
plt.contour(x0_mesh, x1_mesh, y0_mesh, colors="tab:red", alpha=0.2)
plt.contour(x0_mesh, x1_mesh, y1_mesh, colors="tab:blue", alpha=0.2)
plt.plot(*np.c_[CENTER_Y0, CENTER_Y1], "k", label="frontier")
plt.plot(train_X[:, 0], train_X[:, 1], "o", color="0.7", markersize=2, label="training")
plt.plot(
    rec_y0[:, 0], rec_y0[:, 1], "o", color="tab:red", label=f"single_y0: {val_y0:.3f}"
)
plt.plot(
    rec_y1[:, 0], rec_y1[:, 1], "o", color="tab:blue", label=f"single_y1: {val_y1:.3f}"
)
plt.plot(
    rec_p[:, 0], rec_p[:, 1], "o", color="tab:purple", label=f"pareto: {val_p:.3f}"
)
plt.legend(loc="upper left")
plt.axis("square")
plt.axis([-1, 1, -1, 1])

# Right panel: output space (frontier image, training outputs, batch outputs).
plt.sca(axs[1])
frontier = fun(torch.from_numpy(np.linspace(CENTER_Y0, CENTER_Y1)))
plt.plot(*frontier.T, "k", label="frontier")
plt.plot(train_Y[:, 0], train_Y[:, 1], "o", color="0.7", markersize=2, label="training")
plt.plot(out_y0[:, 0], out_y0[:, 1], "o", color="tab:red", label="single_y0")
plt.plot(out_y1[:, 0], out_y1[:, 1], "o", color="tab:blue", label="single_y1")
plt.plot(out_p[:, 0], out_p[:, 1], "o", color="tab:purple", label="pareto")
plt.legend(loc="lower left")
plt.axis("square")
plt.tight_layout()
plt.show()
Please paste any relevant traceback/logs produced by the example provided.
BoTorch Version
0.13.0
Python Version
3.10
Operating System
macOS
Code of Conduct
- I agree to follow BoTorch's Code of Conduct