Skip to content

Commit

Permalink
Compatibility with optional dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
fjwillemsen committed Oct 26, 2024
1 parent 5ab70df commit e407a84
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 58 deletions.
21 changes: 13 additions & 8 deletions kernel_tuner/strategies/bayes_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def predict_list(self, lst: list) -> Tuple[list, list, list]:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
mu, std = self.__model.predict(lst, return_std=True)
return mu, std
return list(zip(mu, std)), mu, std

def fit_observations_to_model(self):
"""Update the model based on the current list of observations."""
Expand Down Expand Up @@ -540,7 +540,7 @@ def initial_sample(self):
if self.is_valid(observation):
collected_samples += 1
self.fit_observations_to_model()
_, std = self.predict_list(self.unvisited_cache)
_, _, std = self.predict_list(self.unvisited_cache)
self.initial_sample_mean = np.mean(self.__valid_observations)
# Alternatively:
# self.initial_sample_std = np.std(self.__valid_observations)
Expand Down Expand Up @@ -736,11 +736,11 @@ def __optimize_multi_advanced(self, max_fevals, increase_precision=False):
if self.__visited_num >= self.searchspace_size or self.fevals >= max_fevals:
break
if increase_precision is True:
predictions, _, std = self.predict_list(self.unvisited_cache)
predictions = self.predict_list(self.unvisited_cache)
hyperparam = self.contextual_variance(std)
list_of_acquisition_values = af(predictions, hyperparam)
best_af = self.argopt(list_of_acquisition_values)
del predictions[best_af] # to avoid going out of bounds
# del predictions[best_af] # to avoid going out of bounds
candidate_params = self.unvisited_cache[best_af]
candidate_index = self.find_param_config_index(candidate_params)
observation = self.evaluate_objective_function(candidate_params)
Expand Down Expand Up @@ -855,13 +855,12 @@ def af_random(self, predictions=None, hyperparam=None) -> list:
def af_probability_of_improvement(self, predictions=None, hyperparam=None) -> list:
"""Acquisition function Probability of Improvement (PI)."""
# prefetch required data
x_mu, x_std = predictions
if hyperparam is None:
hyperparam = self.af_params["explorationfactor"]
fplus = self.current_optimum - hyperparam

# precompute difference of improvement
list_diff_improvement = list(-((fplus - x_mu) / (x_std + 1e-9)) for (x_mu, x_std) in predictions)
list_diff_improvement = list(-((fplus - x_mu) / (x_std + 1e-9)) for x_mu, x_std in predictions[0])

# compute probability of improvement with CDF in bulk
list_prob_improvement = norm.cdf(list_diff_improvement)
Expand All @@ -870,10 +869,15 @@ def af_probability_of_improvement(self, predictions=None, hyperparam=None) -> li
def af_expected_improvement(self, predictions=None, hyperparam=None) -> list:
"""Acquisition function Expected Improvement (EI)."""
# prefetch required data
x_mu, x_std = predictions
if hyperparam is None:
hyperparam = self.af_params["explorationfactor"]
fplus = self.current_optimum - hyperparam
if len(predictions) == 3:
predictions, x_mu, x_std = predictions
elif len(predictions) == 2:
x_mu, x_std = predictions
else:
raise ValueError(f"Invalid predictions size {len(predictions)}")

# precompute difference of improvement, CDF and PDF in bulk
list_diff_improvement = list((fplus - x_mu) / (x_std + 1e-9) for (x_mu, x_std) in predictions)
Expand All @@ -892,6 +896,7 @@ def af_lower_confidence_bound(self, predictions=None, hyperparam=None) -> list:
if hyperparam is None:
hyperparam = self.af_params["explorationfactor"]
beta = hyperparam
_, x_mu, x_std = predictions

# compute LCB in bulk
list_lower_confidence_bound = (x_mu - beta * x_std)
Expand All @@ -900,7 +905,7 @@ def af_lower_confidence_bound(self, predictions=None, hyperparam=None) -> list:
def af_lower_confidence_bound_srinivas(self, predictions=None, hyperparam=None) -> list:
"""Acquisition function Lower Confidence Bound (UCB-S) after Srinivas, 2010 / Brochu, 2010."""
# prefetch required data
x_mu, x_std = predictions
_, x_mu, x_std = predictions
if hyperparam is None:
hyperparam = self.af_params["explorationfactor"]

Expand Down
Loading

0 comments on commit e407a84

Please sign in to comment.