[Bug] qMultiFidelityLowerBoundMaxValueEntropy fails when X_pending is not None #2183

Closed
@AlexanderMouton

Description

🐛 Bug

When the qMultiFidelityLowerBoundMaxValueEntropy acquisition function is constructed with pending points (X_pending is not None) on a SingleTaskMultiFidelityGP model, optimize_acqf fails with a shape-mismatch RuntimeError. The same setup works when the qMultiFidelityMaxValueEntropy acquisition function is used instead.

To reproduce

** Code snippet to reproduce **

import torch
from botorch.acquisition.cost_aware import InverseCostWeightedUtility

from botorch.acquisition.max_value_entropy_search import (
    qMultiFidelityLowerBoundMaxValueEntropy,
    qMultiFidelityMaxValueEntropy
)
from botorch.acquisition.utils import project_to_target_fidelity

from botorch.fit import fit_gpytorch_model
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.transforms.outcome import Standardize
from botorch.optim.optimize import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# create training data
train_x = torch.rand(20, 4)
train_y = (torch.sin(train_x[:, 0]) + torch.cos(train_x[:, 1])).unsqueeze(-1)

# create model
model = SingleTaskMultiFidelityGP(
    train_X=train_x,
    train_Y=train_y,
    nu=2.5,
    linear_truncated=True,
    outcome_transform=Standardize(m=1),
    data_fidelities=[2, 3],
)
# fit the model
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)

# define function for projecting to the highest fidelity
def project(X):
    return project_to_target_fidelity(X=X, target_fidelities={2: 1.0, 3: 1.0})

# create cost aware utility
cost_utility = InverseCostWeightedUtility(
    cost_model=AffineFidelityCostModel(
        fidelity_weights={2: 1.0, 3: 1.0},
    )
)

for mve_af in [qMultiFidelityMaxValueEntropy, qMultiFidelityLowerBoundMaxValueEntropy]:
    for pending_points in [None, torch.rand(1, 4)]:
        af = mve_af(
            model=model,
            candidate_set=torch.rand(20, 4),
            cost_aware_utility=cost_utility,
            project=project,
            X_pending=pending_points,
        )

        bounds = torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])

        candidate, acq_value = optimize_acqf(
            acq_function=af,
            bounds=bounds,
            q=1,
            num_restarts=20,
            raw_samples=100,
        )

        print(candidate)
        print(acq_value)

** Stack trace/error message **

Traceback (most recent call last):
  File "/home/alexander/workspace/ingenious-framework/IngeniousFrame/ParameterTuning/MFBO/test2.py", line 57, in <module>
    candidate, acq_value = optimize_acqf(
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
    batch_initial_conditions = opt_inputs.get_ic_generator()(
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
  File "/home/alexander/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alexander/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/acquisition/max_value_entropy_search.py", line 808, in forward
    ig = self._compute_information_gain(
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/acquisition/max_value_entropy_search.py", line 860, in _compute_information_gain
    return qLowerBoundMaxValueEntropy._compute_information_gain(
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/acquisition/max_value_entropy_search.py", line 564, in _compute_information_gain
    posterior_m = self.model.posterior(
  File "/home/alexander/.local/lib/python3.10/site-packages/botorch/models/gpytorch.py", line 383, in posterior
    mvn = self(X)
  File "/home/alexander/.local/lib/python3.10/site-packages/gpytorch/models/exact_gp.py", line 310, in __call__
    batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
  File "/home/alexander/.local/lib/python3.10/site-packages/torch/functional.py", line 124, in broadcast_shapes
    raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
RuntimeError: Shape mismatch: objects cannot be broadcast to a single shape
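
For readers unfamiliar with the final error, the trace ends in torch.broadcast_shapes, which raises when two batch shapes cannot be aligned. The following is a minimal pure-Python sketch of that broadcasting rule, not PyTorch's implementation; the function name broadcast_shapes_sketch and the example shapes are hypothetical and are not the shapes that occur inside the acquisition function.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Pure-Python sketch of the NumPy/PyTorch shape-broadcasting rule."""
    result = []
    # Align shapes at their trailing dimension; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        # Sizes other than 1 must all agree for the dimension to broadcast.
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise RuntimeError(
                "Shape mismatch: objects cannot be broadcast to a single shape"
            )
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Hypothetical batch shapes, chosen only to illustrate the rule.
print(broadcast_shapes_sketch((20, 1), (20,)))  # compatible shapes
try:
    broadcast_shapes_sketch((20,), (21,))  # 20 vs 21 cannot be reconciled
except RuntimeError as err:
    print(err)
```

In the trace above, the two shapes being compared are the model's batch shape and input.shape[:-2], so the failure suggests that the input passed to the model posterior carries an unexpected batch dimension when X_pending is set.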

Expected Behavior

Since optimize_acqf works and correctly returns a candidate and its acquisition value for the qMultiFidelityMaxValueEntropy acquisition function, I expected the same to be true for the qMultiFidelityLowerBoundMaxValueEntropy acquisition function.

System information

Please complete the following information:

  • BoTorch Version 0.9.5
  • GPyTorch Version 1.11
  • PyTorch Version 2.1.2+cu121
  • Ubuntu 22.04.3 LTS
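
The version details above can be gathered programmatically. This is a small sketch using only the Python standard library; the helper name collect_versions is hypothetical, and it assumes the packages were installed under the distribution names botorch, gpytorch, and torch.

```python
import platform
from importlib import metadata


def collect_versions(packages=("botorch", "gpytorch", "torch")):
    """Return {package: version string}, or 'not installed' if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


# Print the information requested by the bug report template.
for name, version in collect_versions().items():
    print(f"{name}: {version}")
print(f"Python: {platform.python_version()}")
print(f"OS: {platform.platform()}")
```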
