-
Hi all, I'm doing multi-objective, multi-fidelity Bayesian optimization in BoTorch. I'm using qNEHVI as the acquisition function and MultiTaskGP as the surrogate model. I'm tuning six parameters and have two objectives. My high-fidelity dataset has four observations; my low-fidelity dataset has around 500. I'm generating four new candidates per batch. My problem is that the four candidates being generated are quite variable from run to run. Is this due to the sparse high-fidelity data? How can I diagnose the problem, and what would be potential solutions? Below you can find my code.
-
Hi Evan,
With only 4 observations you are going to have a good deal of uncertainty in your model, and if you are doing multi-objective optimization in a 6-d space, there are likely to be many possible solutions that increase your hypervolume, so I wouldn't expect the optimizer to consistently pick the same set of promising points.
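One way to see this concretely (a sketch, not from the original thread, reusing acq_func, optimize_acqf, and the constants defined in the script quoted below) is to regenerate candidates under a few different seeds and compare the acquisition values of each batch. Very different candidate sets with nearly identical acquisition values point to a flat or multi-modal acquisition surface, i.e. many near-equivalent batches, rather than an optimizer bug:

    # Sketch: re-run candidate generation under different seeds and compare
    # the resulting acquisition values (all names come from the quoted script).
    for seed in range(3):
        torch.manual_seed(seed)  # seeds the random restart initialization
        cands, acq_val = optimize_acqf(
            acq_function=acq_func,
            bounds=qnehviBounds,
            q=BATCH_SIZE,
            num_restarts=NUM_RESTARTS,
            raw_samples=RAW_SAMPLES,
            options={"batch_limit": 5, "maxiter": 200},
            sequential=True,
            fixed_features={6: 1.0},
            inequality_constraints=inequality_constraint,
        )
        # similar acq_val across seeds + dissimilar cands = many comparable optima
        print(f"seed={seed}: acq value={acq_val}")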
Section 6 of https://jmlr.org/papers/volume20/18-225/18-225.pdf provides some intuitions about how the MTGP (ICM) model will behave, depending on the correlation between the tasks.
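You can also read the learned inter-task correlation directly off the fitted model. A minimal sketch, assuming (as in current BoTorch/GPyTorch versions) that each MultiTaskGP exposes its ICM task kernel as task_covar_module, an IndexKernel with an _eval_covar_matrix() method:

    # Sketch: extract the learned 2x2 task covariance from each per-objective
    # MultiTaskGP in the ModelListGP and normalize it to a correlation matrix.
    for i, m in enumerate(model.models):
        # IndexKernel: B = covar_factor @ covar_factor.T + diag(var)
        B = m.task_covar_module._eval_covar_matrix().detach()
        d = B.diagonal().sqrt()
        corr = B / d.outer(d)
        print(f"objective {i}: learned LF/HF task correlation matrix:\n{corr}")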
Have you tried plotting the values at high fidelity against the values at low fidelity? You might need a few more than 4 data points to get a sense of how correlated your high-fidelity and low-fidelity tasks are, but if you are able to perform those evaluations, you may quickly find out whether the AF is identifying something useful.
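For example, something like the following (a sketch; it assumes the column layout from your script and, importantly, that the first rows of lowFdata were evaluated at the same inputs as the rows of highFdata):

    # Sketch: scatter HF vs. LF objective values at shared inputs and report
    # the Pearson correlation per objective (plt/np as in the quoted script).
    hf = highFdata.iloc[:, 7:9].values            # the two HF objectives
    lf = lowFdata.iloc[:, 6:8].values[: len(hf)]  # matching LF rows (assumption!)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for k, ax in enumerate(axes):
        ax.scatter(lf[:, k], hf[:, k])
        r = np.corrcoef(lf[:, k], hf[:, k])[0, 1]
        ax.set_xlabel("low fidelity")
        ax.set_ylabel("high fidelity")
        ax.set_title(f"objective {k}: r = {r:.2f}")
    plt.tight_layout()
    plt.show()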
On Thu, Apr 24, 2025 at 1:08 PM EvanClaes wrote:
 Below you can find my code:
 import pandas as pd
 import numpy as np
 import torch
 import os
 import matplotlib.pyplot as plt
 from botorch import fit_gpytorch_mll
 from botorch.utils.transforms import unnormalize, normalize
 from botorch.optim import optimize_acqf
 from botorch.sampling.normal import SobolQMCNormalSampler
 from botorch.models.transforms.outcome import Standardize
 from botorch.models.model_list_gp_regression import ModelListGP
 from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
 from botorch.models import MultiTaskGP
 from botorch.utils.multi_objective.box_decompositions.dominated import DominatedPartitioning
 from botorch.acquisition.multi_objective.logei import qLogNoisyExpectedHypervolumeImprovement
 def initialize_model(train_x, train_obj, bounds, train_noise):
     # normalize inputs; the last column is the task feature with bounds [0, 1]
     train_x_norm = normalize(train_x, bounds)
     models = []
     for i in range(train_obj.shape[-1]):
         train_y = train_obj[..., i : i + 1]
         train_yvar = train_noise[..., i : i + 1]
         models.append(
             MultiTaskGP(
                 train_x_norm,
                 train_y,
                 train_Yvar=train_yvar,  # known observation noise (variance)
                 task_feature=-1,
                 output_tasks=[1],  # only predict the high-fidelity task
                 outcome_transform=Standardize(m=1),
             )
         )
     model = ModelListGP(*models)
     mll = SumMarginalLogLikelihood(model.likelihood, model)
     return mll, model
 tkwargs = {
     "dtype": torch.double,
     "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
 }
 SMOKE_TEST = os.environ.get("SMOKE_TEST")
 BATCH_SIZE = 4
 NUM_RESTARTS = 50
 RAW_SAMPLES = 4096
 MC_SAMPLES = 128
 # import data
 highFdata = pd.read_excel('ANT-0434 - DMEM 766.xlsx')
 lowFdata = pd.read_excel('LFdata.xlsx')
 # define bounds, noise levels and reference point; the 7th column is the task feature
 actualBoundsMT = torch.tensor([[10, 24, 0.5, 1, 700, 200, 0], [60, 192, 6, 6, 2500, 500, 1]], dtype=torch.float64)
 qnehviBounds = torch.tensor([[0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1]], dtype=torch.float64)
 NOISE_SE_highF = torch.tensor([0.56 / np.sqrt(6), 2.86 / np.sqrt(6)], **tkwargs)
 NOISE_SE_lowF = torch.tensor([0, 0], **tkwargs)  # treat low-fidelity data as noiseless
 refPoint = torch.tensor([0, 0], dtype=torch.float64)
 # define a linear constraint on two of the input variables
 # (800, 200) as bottom-left point, instead of (700, 200)
 inequality_constraint = [
     (
         torch.tensor([4, 5], dtype=torch.long),            # parameter indices
         torch.tensor([18 / 5, -1.0], dtype=torch.double),  # coefficients
         1 / 5,                                             # right-hand side (>= by default)
     )
 ]
 # make the training data; append a task column (1 = high fidelity, 0 = low fidelity)
 train_x1 = torch.tensor(highFdata.iloc[:, 1:7].values)
 train_obj1 = torch.tensor(highFdata.iloc[:, 7:9].values)
 train_x2 = torch.tensor(lowFdata.iloc[:, 0:6].values)
 train_obj2 = torch.tensor(lowFdata.iloc[:, 6:8].values)
 train_x1 = torch.cat([train_x1, torch.ones(train_x1.shape[0], 1)], dim=1)
 train_x2 = torch.cat([train_x2, torch.zeros(train_x2.shape[0], 1)], dim=1)
 train_x = torch.cat([train_x1, train_x2], dim=0)
 train_obj = torch.cat([train_obj1, train_obj2])
 # observation noise variances (squared standard errors), per fidelity
 train_noise = torch.cat([
     NOISE_SE_highF.repeat(highFdata.shape[0], 1) ** 2,
     NOISE_SE_lowF.repeat(lowFdata.shape[0], 1) ** 2,
 ])
 # fit model
 mll, model = initialize_model(train_x, train_obj, actualBoundsMT, train_noise)
 fit_gpytorch_mll(mll)
 # compute the hypervolume dominated by the current observations
 # (note: this mixes both fidelities; use Y=train_obj1 for high fidelity only)
 bd = DominatedPartitioning(ref_point=refPoint, Y=train_obj)
 volume = bd.compute_hypervolume().item()
 qnehvi_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([128]))
 # set up the acquisition function; qLogNEHVI partitions the non-dominated
 # space into disjoint rectangles internally, conditioned on the HF baseline
 acq_func = qLogNoisyExpectedHypervolumeImprovement(
     model=model,
     ref_point=refPoint,
     X_baseline=normalize(train_x1, actualBoundsMT),
     prune_baseline=True,
     sampler=qnehvi_sampler,
 )
 # optimize the acquisition function over the normalized input space
 candidates, _ = optimize_acqf(
     acq_function=acq_func,
     bounds=qnehviBounds,
     q=BATCH_SIZE,
     num_restarts=NUM_RESTARTS,
     raw_samples=RAW_SAMPLES,
     options={"batch_limit": 5, "maxiter": 200},
     sequential=True,
     fixed_features={6: 1.0},  # generate candidates at the high-fidelity task
     inequality_constraints=inequality_constraint,
 )
 # candidates live in [0, 1]^7; map back to the original parameter ranges
 print(unnormalize(candidates, actualBoundsMT))
 HFdata.xlsx: https://github.com/user-attachments/files/19896203/HFdata.xlsx
 LFdata.xlsx: https://github.com/user-attachments/files/19896204/LFdata.xlsx
-
Yes, I would recommend proceeding. With only 4 points it's hard to tell whether the correlation is really 0.99, but so far the correlation looks quite strong.
If the tasks are highly correlated, then any point on the Pareto frontier at the low fidelity will also be on the Pareto frontier at the high fidelity, so you can also consider just fitting a single-task GP to the low-fidelity data and generating candidates by applying BO to that (see the sketch below).
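A minimal sketch of that single-task alternative, reusing the names from your script (not a tested drop-in; the column slicing just strips the task feature):

    # Sketch: fit one multi-output SingleTaskGP to the ~500 LF points only
    # and run qLogNEHVI on it; no task feature, so bounds/constraints are 6-d.
    from botorch.models import SingleTaskGP
    from gpytorch.mlls import ExactMarginalLogLikelihood

    lf_x_norm = normalize(train_x2[:, :6], actualBoundsMT[:, :6])  # drop task column
    lf_model = SingleTaskGP(lf_x_norm, train_obj2, outcome_transform=Standardize(m=2))
    lf_mll = ExactMarginalLogLikelihood(lf_model.likelihood, lf_model)
    fit_gpytorch_mll(lf_mll)

    lf_acqf = qLogNoisyExpectedHypervolumeImprovement(
        model=lf_model,
        ref_point=refPoint,
        X_baseline=lf_x_norm,
        prune_baseline=True,  # important with ~500 baseline points
    )
    lf_candidates, _ = optimize_acqf(
        acq_function=lf_acqf,
        bounds=qnehviBounds[:, :6],  # 6 parameters, no task feature
        q=BATCH_SIZE,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES,
        options={"batch_limit": 5, "maxiter": 200},
        sequential=True,
        inequality_constraints=inequality_constraint,
    )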
Looking briefly at your code, it looks like you aren't normalizing your y's, which could be a big problem. You may also wish to set the reference point to some minimum values you care about (perhaps there is some minimum cost or yield).
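For instance (hypothetical numbers, purely illustrative, assuming both objectives are maximized):

    # Hypothetical floors: hypervolume only counts improvement beyond the
    # worst objective values you would still consider acceptable.
    refPoint = torch.tensor([0.5, 10.0], dtype=torch.double)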
If you are more of an end-user of BO than a BO researcher, you may want to consider using Ax instead. We don't have any tutorials for MF-MOBO in Ax quite yet, but there is this tutorial, which could be pretty handy for your setup: https://ax.dev/docs/tutorials/multi_task/
On Fri, Apr 25, 2025 at 11:07 AM EvanClaes wrote:
 Hello Eytan,
 Thanks for the feedback. The Pearson correlations between the 4 high-fidelity observations and the 4 low-fidelity observations (with the same input values) are very high, around 0.99. Your results seem to indicate that a multi-fidelity approach should be beneficial in this case. I do have to add that, in the objective space, the points are not very well spaced: I have two towards the lower end, and two towards the higher end.
 If my approach is correct and valid, then I guess we should just proceed with one particular set of generated candidates. Do you think it makes sense here to further increase the number of low-fidelity samples, or some of the optimizer parameters, before we generate these?
 Enjoy your weekend!

-
Yes. Unfortunately that's really the only way to know for sure whether the LF data can serve as an effective proxy.
-
Indeed.