Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tune] Fix hyperopt points to evaluate for nested lists #18113

Merged
merged 2 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/ray/tune/suggest/hyperopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE, \
UNDEFINED_METRIC_MODE, UNDEFINED_SEARCH_SPACE
from ray.tune.suggest.variant_generator import assign_value, parse_spec_vars
from ray.tune.utils import flatten_dict

try:
hyperopt_logger = logging.getLogger("hyperopt")
Expand Down Expand Up @@ -284,6 +285,10 @@ def suggest(self, trial_id: str) -> Optional[Dict]:

# Taken from HyperOpt.base.evaluate
config = hpo.base.spec_from_misc(new_trial["misc"])

# We have to flatten nested spaces here so parameter names match
config = flatten_dict(config, flatten_list=True)

ctrl = hpo.base.Ctrl(self._hpopt_trials, current_trial=new_trial)
memo = self.domain.memo_from_config(config)
hpo.utils.use_obj_for_literal_in_memo(self.domain.expr, ctrl,
Expand Down
25 changes: 25 additions & 0 deletions python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -1470,6 +1470,31 @@ def testPointsToEvaluateHyperOpt(self):
from ray.tune.suggest.hyperopt import HyperOptSearch
return self._testPointsToEvaluate(HyperOptSearch, config)

def testPointsToEvaluateHyperOptNested(self):
space = {
"nested": [
tune.sample.Integer(0, 10),
tune.sample.Integer(0, 10),
],
"nosample": [4, 8]
}

points_to_evaluate = [{"nested": [2, 4], "nosample": [4, 8]}]

from ray.tune.suggest.hyperopt import HyperOptSearch
searcher = HyperOptSearch(
space=space,
metric="_",
mode="max",
points_to_evaluate=points_to_evaluate)
config = searcher.suggest(trial_id="0")

self.assertSequenceEqual(config["nested"],
points_to_evaluate[0]["nested"])

self.assertSequenceEqual(config["nosample"],
points_to_evaluate[0]["nosample"])

def testPointsToEvaluateNevergrad(self):
config = {
"metric": tune.sample.Categorical([1, 2, 3, 4]).uniform(),
Expand Down
35 changes: 26 additions & 9 deletions python/ray/tune/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,32 +283,49 @@ def deep_update(original,
return original


def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
def flatten_dict(dt: Dict,
delimiter: str = "/",
prevent_delimiter: bool = False,
flatten_list: bool = False):
"""Flatten dict.

Output and input are of the same dict type.
Input dict remains the same after the operation.
"""

def _raise_delimiter_exception():
raise ValueError(
f"Found delimiter `{delimiter}` in key when trying to flatten "
f"array. Please avoid using the delimiter in your specification.")

dt = copy.copy(dt)
if prevent_delimiter and any(delimiter in key for key in dt):
# Raise if delimiter is any of the keys
raise ValueError(
"Found delimiter `{}` in key when trying to flatten array."
"Please avoid using the delimiter in your specification.")
while any(isinstance(v, dict) for v in dt.values()):
_raise_delimiter_exception()

while_check = (dict, list) if flatten_list else dict

while any(isinstance(v, while_check) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
raise ValueError(
"Found delimiter `{}` in key when trying to "
"flatten array. Please avoid using the delimiter "
"in your specification.")
_raise_delimiter_exception()

add[delimiter.join([key, str(subkey)])] = v
remove.append(key)
elif flatten_list and isinstance(value, list):
for i, v in enumerate(value):
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
_raise_delimiter_exception()

add[delimiter.join([key, str(i)])] = v
remove.append(key)

dt.update(add)
for k in remove:
del dt[k]
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/tf/fcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(self, obs_space: gym.spaces.Space,
super(FullyConnectedNetwork, self).__init__(
obs_space, action_space, num_outputs, model_config, name)

hiddens = model_config.get("fcnet_hiddens", []) + \
model_config.get("post_fcnet_hiddens", [])
hiddens = list(model_config.get("fcnet_hiddens", [])) + \
list(model_config.get("post_fcnet_hiddens", []))
activation = model_config.get("fcnet_activation")
if not model_config.get("fcnet_hiddens", []):
activation = model_config.get("post_fcnet_activation")
Expand Down