Skip to content

[Bug]: Drastically different behavior when toggling sequential #2750

@AdrianSosic

Description

@AdrianSosic

What happened?

I've recently been testing a few things for multi-output modeling and stumbled over some very weird (unexpected?) behavior regarding the sequential flag of optimize_acqf:

  • Even for my very simple toy problem below, I get significantly different results when toggling the flag.
  • The runtime difference is tremendous! For my example, sequential=True takes roughly 6s whereas sequential=False runs for about 130s.

Here are the corresponding plots:

sequential=True

Image

sequential=False

Image

What is interesting to note here:

The achieved acquisition values of the batches (shown in the legend) are roughly identical for both settings. The two optimization strategies therefore seem to have ended up in two different but equivalent (in terms of acquisition value) local optima. From a pure acqf perspective, this suggests that both solutions are equally good, even though sequential=True clearly gives the better qualitative result. Perhaps you can comment on this?

Also, I have no good explanation for the runtime difference. Is this expected? If so, is there a reason why sequential=False is the default?

Please provide a minimal, reproducible example of the unexpected behavior.

from time import perf_counter

import gpytorch
import numpy as np
import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.acquisition.multi_objective import qLogNoisyExpectedHypervolumeImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.models.utils.gpytorch_modules import (
    get_gaussian_likelihood_with_gamma_prior,
    get_matern_kernel_with_gamma_prior,
)
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from matplotlib import pyplot as plt

# Seed the global RNG and switch the default dtype to float64 *before* any tensor
# below is created, so every tensor in this script is double precision.
torch.manual_seed(1337)
torch.set_default_dtype(torch.float64)

########################################################################################
# Toggle to compare greedy (True) vs. joint (False) batch optimization in optimize_acqf.
SEQUENTIAL = False  # <-- switch this to True
########################################################################################

# Experiment sizes.
BATCH_SIZE = 10  # number of candidates q per recommendation
N_TRAINING_DATA = 100  # random training points drawn from [-1, 1]^2
N_GRID_POINTS = 100  # grid resolution per axis for the contour plots

# Locations of the optima of the two objectives.
CENTER_Y0 = torch.full((2,), -0.5)
CENTER_Y1 = torch.full((2,), 0.5)


def fun(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the two-objective toy problem.

    Each objective is the negative squared Euclidean distance to its center.
    Objective 0 peaks at ``CENTER_Y0`` and objective 1 at ``CENTER_Y1``, so the
    straight segment between the two centers is the Pareto set.

    Args:
        x: ``... x 2`` tensor of input points (any number of leading batch dims,
            including none for a single point).

    Returns:
        ``... x 2`` tensor whose last dimension holds the two objective values.
    """
    # Reduce and stack over the last dim (instead of dim=1) so that batched or
    # single-point inputs also work; for n x 2 inputs the result is unchanged.
    y0 = -(x - CENTER_Y0).pow(2).sum(dim=-1)
    y1 = -(x - CENTER_Y1).pow(2).sum(dim=-1)
    return torch.stack([y0, y1], dim=-1)


def recommend(train_X, train_Y, *, sequential=SEQUENTIAL, batch_size=BATCH_SIZE):
    """Fit a GP to the training data and recommend a batch of candidates.

    The acquisition function depends on the number of outputs in ``train_Y``:

    * One output: ``qLogExpectedImprovement`` with the best observed value as
      incumbent.
    * Several outputs: ``qLogNoisyExpectedHypervolumeImprovement``, using the
      per-output minimum of ``train_Y`` as reference point.

    Args:
        train_X: ``n x d`` tensor of training inputs; candidates are searched in
            the box [-1, 1]^d.
        train_Y: ``n x m`` tensor of training targets.
        sequential: Forwarded to ``optimize_acqf``.
            True: candidates are chosen greedily, one at a time.
            False: the whole batch is optimized jointly.
            NOTE(review): the joint cost of qLogNEHVI presumably grows quickly
            with ``batch_size``, which would explain slow ``sequential=False``
            runs. Confirm against the BoTorch docs.
        batch_size: Number of candidates ``q`` to recommend.

    Returns:
        Tuple of the ``batch_size x d`` candidate tensor and the acquisition
        value of the complete batch as a Python float.
    """
    n_dims = train_X.shape[-1]
    n_outputs = train_Y.shape[-1]
    # With several outputs SingleTaskGP fits a batch of per-output GPs. Giving the
    # custom modules that batch shape keeps one set of hyperparameters per output
    # (as SingleTaskGP's default modules do) instead of sharing them across outputs.
    batch_shape = torch.Size([n_outputs]) if n_outputs > 1 else torch.Size([])
    mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
    covar_module = get_matern_kernel_with_gamma_prior(n_dims, batch_shape=batch_shape)
    likelihood = get_gaussian_likelihood_with_gamma_prior(batch_shape=batch_shape)
    model = SingleTaskGP(
        train_X,
        train_Y,
        input_transform=Normalize(d=n_dims),
        outcome_transform=Standardize(m=n_outputs),
        mean_module=mean_module,
        covar_module=covar_module,
        likelihood=likelihood,
    )
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    fit_gpytorch_mll(mll)
    if n_outputs == 1:
        acqf = qLogExpectedImprovement(model, train_Y.max())
    else:
        acqf = qLogNoisyExpectedHypervolumeImprovement(
            model, ref_point=train_Y.min(dim=0).values, X_baseline=train_X
        )
    # Search box [-1, 1]^d in BoTorch's 2 x d layout (row 0: lower, row 1: upper).
    bounds = torch.stack([-torch.ones(n_dims), torch.ones(n_dims)])
    rec, _ = optimize_acqf(
        acqf,
        bounds,
        batch_size,
        num_restarts=5,
        raw_samples=20,
        sequential=sequential,
    )
    # Score the complete batch in one call, so greedy and joint runs report
    # comparable values. No gradients are needed just to report the value.
    with torch.no_grad():
        value = acqf(rec).item()
    return rec, value


# Training inputs drawn uniformly from [-1, 1]^2, with both objectives observed.
train_X = 2 * torch.rand([N_TRAINING_DATA, 2]) - 1
train_Y = fun(train_X)

# Time three recommendation runs: objective 0 alone, objective 1 alone, and both
# objectives together (hypervolume-based).
start = perf_counter()
rec_y0, val_y0 = recommend(train_X, train_Y[:, :1])
rec_y1, val_y1 = recommend(train_X, train_Y[:, 1:])
rec_p, val_p = recommend(train_X, train_Y)
elapsed = perf_counter() - start
print(elapsed)

# Objective values reached by each recommended batch (for the objective-space plot).
out_y0 = fun(rec_y0)
out_y1 = fun(rec_y1)
out_p = fun(rec_p)


# Evaluate both objectives on a regular N_GRID_POINTS x N_GRID_POINTS grid over
# [-1, 1]^2 for the contour plots. indexing="ij" is passed explicitly: it matches
# the grid produced when the argument is omitted, and it avoids the UserWarning
# torch.meshgrid emits when no indexing is given.
grid = torch.linspace(-1.0, 1.0, N_GRID_POINTS)
x0_mesh, x1_mesh = torch.meshgrid(grid, grid, indexing="ij")
y = fun(torch.stack([x0_mesh.ravel(), x1_mesh.ravel()], dim=1))
y0_mesh = torch.reshape(y[:, 0], x0_mesh.shape)
y1_mesh = torch.reshape(y[:, 1], x1_mesh.shape)


fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Left panel (input space): objective contours, the Pareto set between the two
# centers, the training inputs and each recommended batch. The legend reports
# each batch's acquisition value.
plt.sca(axs[0])
for z_mesh, contour_color in ((y0_mesh, "tab:red"), (y1_mesh, "tab:blue")):
    plt.contour(x0_mesh, x1_mesh, z_mesh, colors=contour_color, alpha=0.2)
plt.plot(*np.c_[CENTER_Y0, CENTER_Y1], "k", label="frontier")
plt.plot(train_X[:, 0], train_X[:, 1], "o", color="0.7", markersize=2, label="training")
input_batches = (
    (rec_y0, "tab:red", f"single_y0: {val_y0:.3f}"),
    (rec_y1, "tab:blue", f"single_y1: {val_y1:.3f}"),
    (rec_p, "tab:purple", f"pareto: {val_p:.3f}"),
)
for candidates, batch_color, batch_label in input_batches:
    plt.plot(
        candidates[:, 0], candidates[:, 1], "o", color=batch_color, label=batch_label
    )
plt.legend(loc="upper left")
plt.axis("square")
plt.axis([-1, 1, -1, 1])


# Right panel (objective space): the Pareto front, the training targets and the
# objective values reached by each recommended batch.
plt.sca(axs[1])
frontier = fun(torch.from_numpy(np.linspace(CENTER_Y0, CENTER_Y1)))
plt.plot(*frontier.T, "k", label="frontier")
plt.plot(train_Y[:, 0], train_Y[:, 1], "o", color="0.7", markersize=2, label="training")
output_batches = (
    (out_y0, "tab:red", "single_y0"),
    (out_y1, "tab:blue", "single_y1"),
    (out_p, "tab:purple", "pareto"),
)
for outcomes, batch_color, batch_label in output_batches:
    plt.plot(outcomes[:, 0], outcomes[:, 1], "o", color=batch_color, label=batch_label)
plt.legend(loc="lower left")
plt.axis("square")

plt.tight_layout()
plt.show()

Please paste any relevant traceback/logs produced by the example provided.

BoTorch Version

0.13.0

Python Version

3.10

Operating System

macOS

Code of Conduct

  • I agree to follow BoTorch's Code of Conduct

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions