Closed
Description
🐛 Bug
When the `qMultiFidelityLowerBoundMaxValueEntropy` acquisition function is constructed with a number of points that are pending evaluation (`X_pending`) and with the `SingleTaskMultiFidelityGP` model, the `optimize_acqf` method fails. This does not happen when the `qMultiFidelityMaxValueEntropy` acquisition function is used.
To reproduce
**Code snippet to reproduce**
import torch
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.max_value_entropy_search import (
qMultiFidelityLowerBoundMaxValueEntropy,
qMultiFidelityMaxValueEntropy
)
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.fit import fit_gpytorch_model
from botorch.models.cost import AffineFidelityCostModel
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.transforms.outcome import Standardize
from botorch.optim.optimize import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
# create training data
# 20 random 4-dimensional inputs; the target is sin(x0) + cos(x1),
# reshaped into a single-column tensor.
train_x = torch.rand(20, 4)
sin_term = torch.sin(train_x[:, 0])
cos_term = torch.cos(train_x[:, 1])
train_y = (sin_term + cos_term).unsqueeze(-1)

# create model
# Input indices 2 and 3 are passed as the data fidelities; the single
# outcome is standardized by the outcome transform.
gp_options = {
    "train_X": train_x,
    "train_Y": train_y,
    "nu": 2.5,
    "linear_truncated": True,
    "outcome_transform": Standardize(m=1),
    "data_fidelities": [2, 3],
}
model = SingleTaskMultiFidelityGP(**gp_options)

# fit the model
# The marginal log likelihood is built from the model's own likelihood.
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_model(mll)
# define function for projecting to the highest fidelity
def project(X):
    """Project ``X`` to the target fidelity.

    Forwards ``X`` to ``project_to_target_fidelity`` with the fidelity
    indices 2 and 3 both targeted at 1.0, and returns its result.
    """
    return project_to_target_fidelity(X=X, target_fidelities={2: 1.0, 3: 1.0})
# create cost aware utility
# The affine cost model puts a weight of 1.0 on each of the fidelity
# indices 2 and 3; it is then wrapped in the inverse-cost-weighted utility.
fidelity_cost_model = AffineFidelityCostModel(fidelity_weights={2: 1.0, 3: 1.0})
cost_utility = InverseCostWeightedUtility(cost_model=fidelity_cost_model)
# Optimization bounds: lower row all 0.0, upper row all 1.0, one column per
# input dimension. The value never changes, so it is built once up front.
bounds = torch.tensor([[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]])

# Exercise both acquisition functions, each once without pending points and
# once with a single random pending point, so that all four combinations are
# optimized and their results printed.
for mve_af in [qMultiFidelityMaxValueEntropy, qMultiFidelityLowerBoundMaxValueEntropy]:
    for pending_points in [None, torch.rand(1, 4)]:
        af = mve_af(
            model=model,
            candidate_set=torch.rand(20, 4),
            cost_aware_utility=cost_utility,
            project=project,
            X_pending=pending_points,
        )

        candidate, acq_value = optimize_acqf(
            acq_function=af,
            bounds=bounds,
            q=1,
            num_restarts=20,
            raw_samples=100,
        )

        print(candidate)
        print(acq_value)
**Stack trace/error message**
Traceback (most recent call last):
File "/home/alexander/workspace/ingenious-framework/IngeniousFrame/ParameterTuning/MFBO/test2.py", line 57, in <module>
candidate, acq_value = optimize_acqf(
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
return _optimize_acqf(opt_acqf_inputs)
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
return _optimize_acqf_batch(opt_inputs=opt_inputs)
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
batch_initial_conditions = opt_inputs.get_ic_generator()(
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
Y_rnd_curr = acq_function(
File "/home/alexander/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/alexander/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/utils/transforms.py", line 259, in decorated
output = method(acqf, X, *args, **kwargs)
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/acquisition/max_value_entropy_search.py", line 808, in forward
ig = self._compute_information_gain(
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/acquisition/max_value_entropy_search.py", line 860, in _compute_information_gain
return qLowerBoundMaxValueEntropy._compute_information_gain(
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/acquisition/max_value_entropy_search.py", line 564, in _compute_information_gain
posterior_m = self.model.posterior(
File "/home/alexander/.local/lib/python3.10/site-packages/botorch/models/gpytorch.py", line 383, in posterior
mvn = self(X)
File "/home/alexander/.local/lib/python3.10/site-packages/gpytorch/models/exact_gp.py", line 310, in __call__
batch_shape = torch.broadcast_shapes(batch_shape, input.shape[:-2])
File "/home/alexander/.local/lib/python3.10/site-packages/torch/functional.py", line 124, in broadcast_shapes
raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
RuntimeError: Shape mismatch: objects cannot be broadcast to a single shape
Expected Behavior
Considering that `optimize_acqf` works and correctly returns a candidate and its acquisition value for the `qMultiFidelityMaxValueEntropy` acquisition function, I expected the same would be true for the `qMultiFidelityLowerBoundMaxValueEntropy` acquisition function.
System information
Please complete the following information:
- BoTorch Version
0.9.5
- GPyTorch Version
1.11
- PyTorch Version
2.1.2+cu121
- Ubuntu 22.04.3 LTS