Skip to content

Commit

Permalink
[tune] Fix hyperopt points to evaluate for nested lists (#18113)
Browse files Browse the repository at this point in the history
  • Loading branch information
krfricke authored Aug 26, 2021
1 parent 8acb469 commit 34cf5db
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 11 deletions.
5 changes: 5 additions & 0 deletions python/ray/tune/suggest/hyperopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE, \
UNDEFINED_METRIC_MODE, UNDEFINED_SEARCH_SPACE
from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
from ray.tune.utils import flatten_dict

try:
hyperopt_logger = logging.getLogger("hyperopt")
Expand Down Expand Up @@ -284,6 +285,10 @@ def suggest(self, trial_id: str) -> Optional[Dict]:

# Taken from HyperOpt.base.evaluate
config = hpo.base.spec_from_misc(new_trial["misc"])

# We have to flatten nested spaces here so parameter names match
config = flatten_dict(config, flatten_list=True)

ctrl = hpo.base.Ctrl(self._hpopt_trials, current_trial=new_trial)
memo = self.domain.memo_from_config(config)
hpo.utils.use_obj_for_literal_in_memo(self.domain.expr, ctrl,
Expand Down
25 changes: 25 additions & 0 deletions python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,31 @@ def testPointsToEvaluateHyperOpt(self):
from ray.tune.suggest.hyperopt import HyperOptSearch
return self._testPointsToEvaluate(HyperOptSearch, config)

def testPointsToEvaluateHyperOptNested(self):
space = {
"nested": [
tune.sample.Integer(0, 10),
tune.sample.Integer(0, 10),
],
"nosample": [4, 8]
}

points_to_evaluate = [{"nested": [2, 4], "nosample": [4, 8]}]

from ray.tune.suggest.hyperopt import HyperOptSearch
searcher = HyperOptSearch(
space=space,
metric="_",
mode="max",
points_to_evaluate=points_to_evaluate)
config = searcher.suggest(trial_id="0")

self.assertSequenceEqual(config["nested"],
points_to_evaluate[0]["nested"])

self.assertSequenceEqual(config["nosample"],
points_to_evaluate[0]["nosample"])

def testPointsToEvaluateNevergrad(self):
config = {
"metric": tune.sample.Categorical([1, 2, 3, 4]).uniform(),
Expand Down
35 changes: 26 additions & 9 deletions python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,32 +283,49 @@ def deep_update(original,
return original


def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
def flatten_dict(dt: Dict,
delimiter: str = "/",
prevent_delimiter: bool = False,
flatten_list: bool = False):
"""Flatten dict.
Output and input are of the same dict type.
Input dict remains the same after the operation.
"""

def _raise_delimiter_exception():
raise ValueError(
f"Found delimiter `{delimiter}` in key when trying to flatten "
f"array. Please avoid using the delimiter in your specification.")

dt = copy.copy(dt)
if prevent_delimiter and any(delimiter in key for key in dt):
# Raise if delimiter is any of the keys
raise ValueError(
"Found delimiter `{}` in key when trying to flatten array."
"Please avoid using the delimiter in your specification.")
while any(isinstance(v, dict) for v in dt.values()):
_raise_delimiter_exception()

while_check = (dict, list) if flatten_list else dict

while any(isinstance(v, while_check) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
raise ValueError(
"Found delimiter `{}` in key when trying to "
"flatten array. Please avoid using the delimiter "
"in your specification.")
_raise_delimiter_exception()

add[delimiter.join([key, str(subkey)])] = v
remove.append(key)
elif flatten_list and isinstance(value, list):
for i, v in enumerate(value):
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
_raise_delimiter_exception()

add[delimiter.join([key, str(i)])] = v
remove.append(key)

dt.update(add)
for k in remove:
del dt[k]
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/tf/fcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(self, obs_space: gym.spaces.Space,
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)

hiddens = model_config.get("fcnet_hiddens", []) + \
model_config.get("post_fcnet_hiddens", [])
hiddens = list(model_config.get("fcnet_hiddens", [])) + \
list(model_config.get("post_fcnet_hiddens", []))
activation = model_config.get("fcnet_activation")
if not model_config.get("fcnet_hiddens", []):
activation = model_config.get("post_fcnet_activation")
Expand Down

0 comments on commit 34cf5db

Please sign in to comment.