Skip to content

Commit

Permalink
Support generic search spaces in Eagle Designer using ProblemAndTrial…
Browse files Browse the repository at this point in the history
…sScaler

PiperOrigin-RevId: 512639997
  • Loading branch information
belenkil authored and copybara-github committed Feb 27, 2023
1 parent fa5a41c commit 2888c0f
Show file tree
Hide file tree
Showing 8 changed files with 243 additions and 84 deletions.
86 changes: 50 additions & 36 deletions vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
from vizier._src.algorithms.designers.eagle_strategy import serialization
from vizier._src.algorithms.random import random_sample
from vizier.interfaces import serializable
from vizier.pyvizier import converters


EagleStrategyUtils = eagle_strategy_utils.EagleStrategyUtils
FireflyAlgorithmConfig = eagle_strategy_utils.FireflyAlgorithmConfig
Expand Down Expand Up @@ -103,6 +105,20 @@ def __init__(
Exception: if the problem statement includes conditional search space,
multi-objectives or safety metrics.
"""
# Problem statement validation.
if problem_statement.search_space.is_conditional:
raise ValueError(
"Eagle Strategy designer doesn't support conditional parameters."
)
if not problem_statement.is_single_objective:
raise ValueError(
"Eagle Strategy designer doesn't support multi-objectives."
)
if problem_statement.is_safety_metric:
raise ValueError(
"Eagle Strategy designer doesn't support safety metrics."
)

if seed is None:
# When a key is not provided it will be set based on the current time to
# ensure non-repeated behavior.
Expand All @@ -112,33 +128,22 @@ def __init__(
'Setting the seed to %s'),
str(seed),
)

self._scaler = converters.ProblemAndTrialsScaler(problem_statement)
self._problem = self._scaler.problem_statement
self._rng = np.random.default_rng(seed=seed)
self.problem = problem_statement
self.config = config or FireflyAlgorithmConfig()
self._utils = EagleStrategyUtils(problem_statement, self.config, self._rng)
self._config = config or FireflyAlgorithmConfig()
self._utils = EagleStrategyUtils(self._problem, self._config, self._rng)
self._firefly_pool = FireflyPool(
utils=self._utils, capacity=self._utils.compute_pool_capacity())

if problem_statement.search_space.is_conditional:
raise ValueError(
"Eagle Strategy designer doesn't support conditional parameters.")
if not problem_statement.is_single_objective:
raise ValueError(
"Eagle Strategy designer doesn't support multi-objectives.")
if problem_statement.is_safety_metric:
raise ValueError(
"Eagle Strategy designer doesn't support safety metrics.")
if not self._utils.is_linear_scale:
raise ValueError(
"Eagle Strategy designer doesn't support non-linear scales.")

logging.info(
('Eagle Strategy designer initialized. Pool capacity: %s. '
'Eagle config:\n%s\nProblem statement:\n%s'),
(
'Eagle Strategy designer initialized. Pool capacity: %s. '
'Eagle config:\n%s\nProblem statement:\n%s'
),
self._utils.compute_pool_capacity(),
json.dumps(attr.asdict(self.config), indent=2),
self.problem,
json.dumps(attr.asdict(self._config), indent=2),
self._problem,
)

def dump(self) -> vz.Metadata:
Expand Down Expand Up @@ -197,10 +202,11 @@ def load(self, metadata: vz.Metadata) -> None:
self._firefly_pool.size,
)

def suggest(self,
count: Optional[int] = None) -> Sequence[vz.TrialSuggestion]:
def suggest(self, count: int = 1) -> Sequence[vz.TrialSuggestion]:
"""Suggests trials."""
return [self._suggest_one() for _ in range(max(count, 1))]
scaled_suggestions = [self._suggest_one() for _ in range(count)]
# Unscale suggestion parameters to the original search space.
return self._scaler.unmap(scaled_suggestions)

def _suggest_one(self) -> vz.TrialSuggestion:
"""Generates a single suggestion based on the current pool of flies.
Expand All @@ -218,7 +224,8 @@ def _suggest_one(self) -> vz.TrialSuggestion:
# Pool is underpopulated. Generate random trial parameters.
# (b/243518714): Use random policy/designer to generate parameters.
suggested_parameters = random_sample.sample_parameters(
self._rng, self.problem.search_space)
self._rng, self._problem.search_space
)
# Create a new parent fly id and assign it to the trial, this will be
# used during Update to match the trial to its parent fly in the pool.
parent_fly_id = self._firefly_pool.generate_new_fly_id()
Expand Down Expand Up @@ -262,15 +269,16 @@ def _mutate_fly(self, moving_fly: Firefly) -> None:
pull_weights = self._utils.compute_pull_weight_by_type(
other_fly.trial.parameters, mutated_parameters, is_other_fly_better)
# Apply the pulls from 'other_fly' on the moving fly's parameters.
for param_config in self.problem.search_space.parameters:
for param_config in self._problem.search_space.parameters:
pull_weight = pull_weights[param_config.type]
# Accentuate 'other_fly' pull using 'exploration_rate'.
if pull_weight > 0.5:
explore_pull_weight = (
self.config.explore_rate * pull_weight +
(1 - self.config.explore_rate) * 1.0)
self._config.explore_rate * pull_weight
+ (1 - self._config.explore_rate) * 1.0
)
else:
explore_pull_weight = self.config.explore_rate * pull_weight
explore_pull_weight = self._config.explore_rate * pull_weight
# Update the parameters using 'other_fly' and 'explore_pull_rate'.
mutated_parameters[param_config.name] = (
self._utils.combine_two_parameters(
Expand All @@ -291,7 +299,7 @@ def _perturb_fly(self, moving_fly: Firefly) -> None:
"""
suggested_parameters = moving_fly.trial.parameters
perturbations = self._utils.create_perturbations(moving_fly.perturbation)
for i, param_config in enumerate(self.problem.search_space.parameters):
for i, param_config in enumerate(self._problem.search_space.parameters):
perturbed_value = self._utils.perturb_parameter(
param_config,
suggested_parameters[param_config.name].value,
Expand All @@ -309,12 +317,17 @@ def update(
parent fly id. For trials that were added to the study externally we assign
a new parent fly id.
Args:
completed:
Trials passed to 'update' are in the unscaled/original search space, and
will be converted to the scaled search space, so that all other methods
in the designer deal with scaled trial values.
Arguments:
completed: Trials in the original search space.
all_active:
"""
del all_active
for trial in completed.trials:
trials = self._scaler.map(completed.trials)
for trial in trials:
# Replaces trial metric name with a canonical metric name, which makes the
# serialization and deserialization simpler.
trial = self._utils.standardize_trial_metric_name(trial)
Expand Down Expand Up @@ -416,8 +429,9 @@ def _penalize_parent_fly(self, parent_fly: Firefly, trial: vz.Trial) -> None:
if trial.parameters == parent_fly.trial.parameters:
# If the new trial is identical to the parent trial, it means that the
# fly is stuck, and so we increase its perturbation.
parent_fly.perturbation = min(parent_fly.perturbation * 10,
self.config.max_perturbation)
parent_fly.perturbation = min(
parent_fly.perturbation * 10, self._config.max_perturbation
)
logging.info(
('Penalize Parent Id: %s. Parameters are stuck. '
'New perturbation factor: %s'),
Expand All @@ -433,7 +447,7 @@ def _penalize_parent_fly(self, parent_fly: Firefly, trial: vz.Trial) -> None:
parent_fly.id_,
parent_fly.perturbation,
)
if parent_fly.perturbation < self.config.perturbation_lower_bound:
if parent_fly.perturbation < self._config.perturbation_lower_bound:
# If the perturbation factor is too low we attempt to eliminate the
# unsuccessful parent fly from the pool.
if self._firefly_pool.size == self._firefly_pool.capacity:
Expand Down
111 changes: 96 additions & 15 deletions vizier/_src/algorithms/designers/eagle_strategy/eagle_strategy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

"""Tests for eagle_strategy."""

import numpy as np
from vizier import algorithms as vza
from vizier import pyvizier as vz
from vizier._src.algorithms.designers.eagle_strategy import eagle_strategy
from vizier._src.algorithms.designers.eagle_strategy import testing

from absl.testing import absltest
from absl.testing import parameterized


EagleStrategyDesiger = eagle_strategy.EagleStrategyDesigner


Expand Down Expand Up @@ -53,7 +56,67 @@ def test_suggest_one(self):
trial_suggestion = eagle_designer._suggest_one()
self.assertIsInstance(trial_suggestion, vz.TrialSuggestion)
self.assertIsNotNone(
trial_suggestion.metadata.ns('eagle').get('parent_fly_id'))
trial_suggestion.metadata.ns('eagle').get('parent_fly_id')
)

def test_embedding(self):
  """Checks that the designer operates internally on the scaled search space.

  Covers three things: the designer's internal problem statement has (0, 1)
  bounds, internal suggestions stay within [0, 1], and trials passed to
  `update` are stored in the pool with scaled parameter values.
  """
  eagle_designer = testing.create_fake_populated_eagle_designer()
  # Check that the problem was converted.
  self.assertEqual(
      eagle_designer._problem.search_space.parameters[0].bounds, (0.0, 1.0)
  )
  # Check that internal (not yet unmapped) suggestions are in normalized range.
  for _ in range(10):
    trial_suggestion = eagle_designer._suggest_one()
    self.assertBetween(trial_suggestion.parameters['x'].value, 0.0, 1.0)
  # Check that update maps the trials correctly to the normalized space.
  # NOTE(review): presumably 10.0 is the upper bound of 'x' in the fake
  # problem, which is why it is expected to map to 1.0 -- confirm against
  # the `testing` helpers.
  eagle_designer = testing.create_fake_empty_eagle_designer()
  trial = vz.Trial({'x': 10.0})
  trial = trial.complete(
      vz.Measurement(metrics={'objective': np.random.uniform()})
  )
  complete_trials = vza.CompletedTrials([trial])
  eagle_designer.update(complete_trials, vza.ActiveTrials())
  self.assertEqual(
      eagle_designer._firefly_pool._pool[0].trial.parameters['x'].value, 1.0
  )

@parameterized.parameters(1e-4, 1.0)
def test_penalize_parent_fly_no_trial_change(self, perturbation):
  """Checks the perturbation increase when the trial equals its parent.

  When the new trial has the same parameters as the parent fly, the designer
  sets the parent's perturbation to
  `min(perturbation * 10, config.max_perturbation)`.

  Args:
    perturbation: The parent fly's perturbation before it is penalized.
  """
  eagle_designer = testing.create_fake_populated_eagle_designer(
      x_values=[1.0, 2.0, 3.0], obj_values=[1, 2, 3], parent_fly_ids=[1, 2, 3]
  )
  # Same x value as the parent fly with id 2, so the fly counts as "stuck".
  trial = testing.create_fake_trial(parent_fly_id=2, x_value=2.0, obj_value=2)
  parent_fly = eagle_designer._firefly_pool._pool[2]
  # Set the perturbation.
  parent_fly.perturbation = perturbation
  eagle_designer._penalize_parent_fly(parent_fly, trial)
  after_perturbation = parent_fly.perturbation
  max_perturbation = eagle_designer._config.max_perturbation
  if perturbation * 10 >= max_perturbation:
    # Perturbation is already high, so the result is capped by the maximum.
    # The expectation is the cap itself, not a multiple of the old value.
    self.assertEqual(after_perturbation, max_perturbation)
  else:
    # Not reaching the maximum yet, multiply by 10. Compare approximately to
    # avoid depending on the exact floating-point representation.
    self.assertAlmostEqual(after_perturbation, perturbation * 10)

def test_penalize_parent_fly(self):
  """Checks that penalizing a parent fly multiplies its perturbation by 0.9.

  The penalizing trial has different parameters (x=1.42) than the parent fly
  it is matched with (id 2).
  """
  # NOTE(review): the original comment says the capacitated pool holds 11
  # fireflies; presumably this follows from the fake single-parameter
  # problem -- confirm against the `testing` helpers.
  designer = testing.create_fake_populated_eagle_designer(
      x_values=[1.0, 2.0, 3.0], obj_values=[1, 2, 3], parent_fly_ids=[1, 2, 3]
  )
  penalizing_trial = testing.create_fake_trial(
      parent_fly_id=2, x_value=1.42, obj_value=0.5
  )
  fly = designer._firefly_pool._pool[2]
  expected_perturbation = fly.perturbation * 0.9
  designer._penalize_parent_fly(fly, penalizing_trial)
  self.assertEqual(fly.perturbation, expected_perturbation)

def test_suggest(self):
eagle_designer = testing.create_fake_populated_eagle_designer()
Expand Down Expand Up @@ -107,20 +170,38 @@ def test_update_empty_pool(self):
eagle_designer._update_one(trial)
self.assertIs(eagle_designer._firefly_pool._pool[0].trial, trial)

def test_linear_scale(self):
problem = vz.ProblemStatement(metric_information=[
vz.MetricInformation(name='obj', goal=vz.ObjectiveMetricGoal.MAXIMIZE)
])
problem.search_space.root.add_float_param('f1', 0.0, 10.0)
problem.search_space.root.add_float_param('f2', 0.0, 5.0)
problem.search_space.root.add_categorical_param('c1', ['a', 'b', 'c'])
problem.search_space.root.add_int_param('i1', 0, 10)
problem.search_space.root.add_discrete_param('d1', [1, 5, 10])
EagleStrategyDesiger(problem)
problem.search_space.root.add_float_param(
'f3', 0.0, 10.0, scale_type=vz.ScaleType.LOG)
with self.assertRaises(ValueError):
EagleStrategyDesiger(problem)
@parameterized.parameters(1, 3, 5)
def test_suggest_update(self, batch_size):
  """Smoke-tests suggest/update rounds on a mixed-type search space.

  Builds a problem with float, int, discrete and categorical parameters and
  runs three suggest/update rounds, completing every suggestion with a random
  objective value.

  Args:
    batch_size: Number of suggestions requested in each round.
  """
  problem = vz.ProblemStatement()
  root = problem.search_space.select_root()
  root.add_float_param('float', -5.0, 5.0)
  root.add_int_param('int', min_value=0, max_value=10)
  root.add_discrete_param('discrete', feasible_values=[0.0, 0.6])
  root.add_categorical_param('categorical', feasible_values=['a', 'b', 'c'])
  problem.metric_information.append(
      vz.MetricInformation(goal=vz.ObjectiveMetricGoal.MINIMIZE, name='')
  )
  eagle_designer = EagleStrategyDesiger(problem)

  next_trial_id = 1
  # Simulate three rounds; each round suggests one batch and completes it.
  for _ in range(3):
    suggestions = eagle_designer.suggest(batch_size)
    # Complete each suggestion under a unique, monotonically increasing id.
    completed = [
        suggestion.to_trial(trial_id).complete(
            vz.Measurement(metrics={'': np.random.uniform()})
        )
        for trial_id, suggestion in enumerate(suggestions, start=next_trial_id)
    ]
    next_trial_id += len(completed)
    eagle_designer.update(vza.CompletedTrials(completed), vza.ActiveTrials())


if __name__ == '__main__':
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ class FireflyAlgorithmConfig:
gravity: float = 1.0
negative_gravity: float = 0.02
# Visibility
visibility: float = 1.0
visibility: float = 3.0
categorical_visibility: float = 0.2
discrete_visibility: float = 1.0
# Perturbation
perturbation: float = 1e-2
categorical_perturbation_factor: float = 25
pure_categorical_perturbation: float = 0.1
discrete_perturbation_factor: float = 10.0
perturbation: float = 1e-1
perturbation_lower_bound: float = 1e-3
categorical_perturbation_factor: float = 25.0
discrete_perturbation_factor: float = 10.0
pure_categorical_perturbation: float = 0.1
max_perturbation: float = 0.5
# Pool size
pool_size_factor: float = 1.2
Expand Down Expand Up @@ -100,7 +100,9 @@ def __attrs_post_init__(self):
self._search_space = self.problem_statement.search_space
self._n_parameters = len(self._search_space.parameters)
self._cache_degrees_of_freedom()
self._original_metric_name = self.problem_statement.single_objective_metric_name
self._original_metric_name = (
self.problem_statement.single_objective_metric_name
)
self._goal = self.problem_statement.metric_information.item().goal
logging.info('EagleStrategyUtils was created.\n%s', str(self))

Expand Down Expand Up @@ -204,8 +206,8 @@ def compute_cononical_distance(

def compute_pool_capacity(self) -> int:
"""Computes the pool capacity."""
return 10 + int(0.5 * self._n_parameters +
self._n_parameters**self.config.pool_size_factor)
df = self._n_parameters
return 10 + round((df**self.config.pool_size_factor + df) * 0.5)

def combine_two_parameters(
self,
Expand Down Expand Up @@ -363,15 +365,6 @@ def is_pure_categorical(self) -> bool:
for p in self._search_space.parameters
])

@property
def is_linear_scale(self) -> bool:
"""Returns whether all decimal parameters in search space has linear scale.
"""
return all([
p.scale_type is None or p.scale_type == vz.ScaleType.LINEAR
for p in self._search_space.parameters
])

def standardize_trial_metric_name(self, trial: vz.Trial) -> vz.Trial:
"""Creates a new trial with canonical metric name."""
value = trial.final_measurement.metrics[self._original_metric_name].value
Expand All @@ -387,7 +380,9 @@ def display_trial(self, trial: vz.Trial) -> str:
for k, v in trial.parameters.as_dict().items()
}
if trial.final_measurement:
obj_value = f'{list(trial.final_measurement.metrics.values())[0].value:.5f}'
obj_value = (
f'{list(trial.final_measurement.metrics.values())[0].value:.5f}'
)
return f'Value: {obj_value}, Parameters: {parameters}'
else:
return f'Parameters: {parameters}'
Expand Down Expand Up @@ -431,7 +426,7 @@ def get_shuffled_flies(self, rng: np.random.Generator) -> list[Firefly]:
return random_sample.shuffle_list(rng, list(self._pool.values()))

def generate_new_fly_id(self) -> int:
"""Generates a unique fly id to identify a fly in the pool."""
"""Generates a unique fly id (starts from 0) to identify a fly in the pool."""
self._max_fly_id += 1
logging.info('New fly id generated (%s).', self._max_fly_id - 1)
return self._max_fly_id - 1
Expand Down
Loading

0 comments on commit 2888c0f

Please sign in to comment.