
[Bug] Equality constraint violated in optimization of ACQF #1227

@jduerholt

Description


🐛 Bug

In some cases the optimizer within optimize_acqf returns candidates that violate the equality constraints passed to optimize_acqf. I built an MWE for it, which you can find below. Reducing the number of data points passed to the acqf (training data and X_baseline; see the commented-out line in the MWE) increases the chance that the constraint is fulfilled. You may have to run the MWE several times to observe the behavior, because the constraint is sometimes fulfilled and sometimes not. I attached the example data (totally made up) as a CSV file.

To reproduce

```python
import torch
from torch import tensor
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from botorch.optim import optimize_acqf
from botorch.models import ModelListGP
from botorch.acquisition.multi_objective import qNoisyExpectedHypervolumeImprovement
import pandas as pd

torch.manual_seed(1)

tkwargs = {'dtype': torch.float64, 'device': 'cpu'}

# Load the attached (made-up) data set.
df_data = pd.read_csv("equality_constraint.csv")
# With fewer data points the constraint is violated less often:
# df_data = pd.read_csv("equality_constraint.csv").loc[:10]
train_X = torch.from_numpy(df_data[["a", "b", "c", "d"]].values).to(**tkwargs)
train_Y = torch.from_numpy(df_data[["alpha", "beta"]].values).to(**tkwargs)

# Box bounds for the inputs a, b, c, d.
lower = tensor([0.1, 0.3, 0.1, 30.])
upper = tensor([0.6, 0.7, 0.7, 70.])

bounds = torch.stack((lower, upper)).to(**tkwargs)

# One independent GP per objective.
models = []
for i, feat in enumerate(["alpha", "beta"]):
    gp = SingleTaskGP(
        train_X,
        train_Y[:, i].unsqueeze(-1),
        input_transform=Normalize(d=4),
        outcome_transform=Standardize(m=1),
    )
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_model(mll)
    models.append(gp)
model = ModelListGP(*models)

acqf = qNoisyExpectedHypervolumeImprovement(
    model=model,
    ref_point=df_data[["alpha", "beta"]].min().values.tolist(),
    X_baseline=train_X,
)

# Equality constraint on features 1, 2 and 0: b + c + a == 1.
candidate, acq_value = optimize_acqf(
    acqf,
    bounds=bounds,
    q=1,
    num_restarts=8,
    raw_samples=1024,
    equality_constraints=[(tensor([1, 2, 0]), tensor([1., 1., 1.]).to(**tkwargs), 1.0)],
)

print(candidate)
print(candidate[0, :3].sum())  # should be 1 if the constraint holds
```
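The constraint tuple in the MWE uses the (indices, coefficients, rhs) layout of optimize_acqf's equality_constraints, which encodes sum_i coefficients[i] * X[indices[i]] = rhs. Since indices 1, 2 and 0 map to the columns b, c and a, the tuple reduces to a + b + c = 1. The box bounds admit such points, so a feasible candidate always exists and the violation should not be caused by an infeasible problem:

```math
\sum_i \mathrm{coefficients}_i \, x_{\mathrm{indices}_i} = \mathrm{rhs}
\;\Longrightarrow\; x_b + x_c + x_a = 1,
\qquad
\underbrace{0.1 + 0.3 + 0.1}_{=\,0.5} \;\le\; 1 \;\le\; \underbrace{0.6 + 0.7 + 0.7}_{=\,2.0}
```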


Expected Behavior

I expect candidate[0,:3].sum() to equal 1 (up to numerical tolerance), but it often returns values larger than 1.
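To check candidates against constraints written in the same (indices, coefficients, rhs) form, a small tolerance-based helper can report the residual of every constraint. This is a minimal sketch: the names equality_residuals and satisfies_equalities and the default absolute tolerance of 1e-6 are hypothetical choices, not part of BoTorch. It is written in plain Python, so it accepts nested lists as well as the candidate tensor and constraint tuple from the MWE.

```python
from typing import Sequence, Tuple

# Hypothetical helper types: one constraint in optimize_acqf's
# (indices, coefficients, rhs) layout, meaning
# sum_i coefficients[i] * x[indices[i]] == rhs.
Constraint = Tuple[Sequence[int], Sequence[float], float]


def equality_residuals(candidates, constraints: Sequence[Constraint]):
    """Return |sum_i c_i * x[idx_i] - rhs| for each (candidate row, constraint) pair."""
    residuals = []
    for row in candidates:
        row_residuals = []
        for indices, coefficients, rhs in constraints:
            # int()/float() also unwrap 0-d torch tensors, so tensors work too.
            lhs = sum(float(c) * float(row[int(i)]) for i, c in zip(indices, coefficients))
            row_residuals.append(abs(lhs - float(rhs)))
        residuals.append(row_residuals)
    return residuals


def satisfies_equalities(candidates, constraints: Sequence[Constraint], atol: float = 1e-6) -> bool:
    """True if every candidate meets every equality constraint within atol (assumed tolerance)."""
    return all(r <= atol for row in equality_residuals(candidates, constraints) for r in row)
```

For the MWE, satisfies_equalities(candidate, [(tensor([1, 2, 0]), tensor([1., 1., 1.]), 1.0)]) returns False whenever candidate[0,:3].sum() differs from 1 by more than atol. An absolute tolerance is used because the right-hand side here is of order 1.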

Additional context

The CSV file with the data is attached:
equality_constraint.csv

Labels: bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions