Move the tune driver into a remote task (ray-project#13778)
ericl authored Feb 3, 2021
1 parent b4684cf commit d335ce2
Showing 7 changed files with 197 additions and 20 deletions.
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
@@ -163,6 +163,14 @@ py_test(
tags = ["exclusive"],
)

py_test(
name = "test_remote",
size = "medium",
srcs = ["tests/test_remote.py"],
deps = [":tune_lib"],
tags = ["exclusive"],
)

py_test(
name = "test_sample",
size = "medium",
13 changes: 0 additions & 13 deletions python/ray/tune/ray_trial_executor.py
@@ -154,15 +154,7 @@ class RayTrialExecutor(TrialExecutor):
def __init__(self,
queue_trials: bool = False,
reuse_actors: bool = False,
ray_auto_init: Optional[bool] = None,
refresh_period: Optional[float] = None):
if ray_auto_init is None:
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
ray_auto_init = False
else:
ray_auto_init = True

super(RayTrialExecutor, self).__init__(queue_trials)
# Check whether we are launching a trial without resources, in order to
# kick off the autoscaler.
@@ -193,11 +185,6 @@ def __init__(self,
self._last_ip_refresh = float("-inf")
self._last_ip_addresses = set()
self._last_nontrivial_wait = time.time()
if not ray.is_initialized() and ray_auto_init:
logger.info("Initializing Ray automatically."
"For cluster usage or custom Ray initialization, "
"call `ray.init(...)` before `tune.run`.")
ray.init()

if ray.is_initialized():
self._update_avail_resources()
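
With auto-init removed from the executor, the TUNE_DISABLE_AUTO_INIT escape hatch still works, but it is now honored by the new `_ray_auto_init` helper at the bottom of this diff rather than by `RayTrialExecutor.__init__`. A minimal sketch of opting out, assuming the env var semantics stay the same:

import os
import ray
from ray import tune

# Opt out of Tune's automatic ray.init(); the caller then owns initialization.
os.environ["TUNE_DISABLE_AUTO_INIT"] = "1"
ray.init(num_cpus=2)  # explicit init before tune.run, as the log message advises

def train(config, reporter):
    reporter(timesteps_total=1)

tune.run(train)
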
77 changes: 77 additions & 0 deletions python/ray/tune/tests/test_remote.py
@@ -0,0 +1,77 @@
import unittest

import ray
from ray.tune import register_trainable, run_experiments, run
from ray.tune.result import TIMESTEPS_TOTAL
from ray.tune.experiment import Experiment
from ray.tune.trial import Trial
from ray.util.client.ray_client_helpers import ray_start_client_server


class RemoteTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()

def testRemoteRunExperiments(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)

register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
})
[trial] = run_experiments(exp1, _remote=True)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)

def testRemoteRun(self):
def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)

analysis = run(train, _remote=True)
[trial] = analysis.trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)

def testRemoteRunExperimentsInClient(self):
ray.init()
assert not ray.util.client.ray.is_connected()
with ray_start_client_server():
assert ray.util.client.ray.is_connected()

def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)

register_trainable("f1", train)
exp1 = Experiment(**{
"name": "foo",
"run": "f1",
})
[trial] = run_experiments(exp1)
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)

def testRemoteRunInClient(self):
ray.init()
assert not ray.util.client.ray.is_connected()
with ray_start_client_server():
assert ray.util.client.ray.is_connected()

def train(config, reporter):
for i in range(100):
reporter(timesteps_total=i)

analysis = run(train)
[trial] = analysis.trials
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)


if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
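
The four tests cover both dispatch triggers: passing `_remote=True` explicitly, and relying on the Ray-client default. A small sketch for running only the client-mode cases (the `-k` filter and the file path are assumptions, not part of this diff):

import sys
import pytest

# Select the two *InClient tests by name; path assumed relative to the repo root.
sys.exit(pytest.main(["-v", "-k", "InClient", "python/ray/tune/tests/test_remote.py"]))
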
2 changes: 2 additions & 0 deletions python/ray/tune/tests/test_trial_runner_3.py
@@ -697,6 +697,8 @@ def num_checkpoints(trial):

@patch("ray.tune.syncer.CLOUD_SYNC_PERIOD", 0)
def testCheckpointAutoPeriod(self):
ray.init(num_cpus=3)

# This makes checkpointing take 2 seconds.
def sync_up(source, target):
time.sleep(2)
1 change: 1 addition & 0 deletions python/ray/tune/tests/test_trial_runner_callbacks.py
@@ -73,6 +73,7 @@ def get_next_failed_trial(self):

class TrialRunnerCallbacks(unittest.TestCase):
def setUp(self):
ray.init()
self.tmpdir = tempfile.mkdtemp()
self.callback = TestCallback()
self.executor = _MockTrialExecutor()
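
Both test files gain an explicit `ray.init()` because the executor no longer initializes Ray as a side effect. A hedged sketch of the resulting pattern for any Tune test (class name and CPU count are illustrative):

import unittest

import ray

class ExampleTuneTest(unittest.TestCase):
    def setUp(self):
        ray.init(num_cpus=2)  # explicit: RayTrialExecutor no longer auto-inits

    def tearDown(self):
        ray.shutdown()  # leave no lingering cluster state for the next test
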
13 changes: 7 additions & 6 deletions python/ray/tune/trial.py
@@ -166,6 +166,13 @@ class Trial:
"""

_nonjson_fields = [
"results",
"best_result",
"param_config",
"extra_arg",
]

PENDING = "PENDING"
RUNNING = "RUNNING"
PAUSED = "PAUSED"
@@ -289,12 +296,6 @@ def __init__(self,
self.param_config = None
self.extra_arg = None

self._nonjson_fields = [
"results",
"best_result",
"param_config",
"extra_arg",
]
if trial_name_creator:
self.custom_trial_name = trial_name_creator(self)

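
Hoisting `_nonjson_fields` from `__init__` into the class body creates the list once per class instead of once per Trial, and lets serialization code consult it without constructing an instance. A generic sketch of the pattern (illustrative only, not Trial's actual `__getstate__`):

class Sketch:
    # One shared list at class level instead of a copy built per instance.
    _nonjson_fields = ["results", "best_result", "param_config", "extra_arg"]

    def __init__(self):
        self.results = []
        self.best_result = None
        self.param_config = None
        self.extra_arg = None
        self.name = "trial"

    def json_safe_state(self):
        # Fields named in _nonjson_fields need special handling; here they are
        # simply omitted to illustrate the class-level lookup through self.
        return {k: v for k, v in self.__dict__.items()
                if k not in self._nonjson_fields}
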
103 changes: 102 additions & 1 deletion python/ray/tune/tune.py
@@ -8,6 +8,7 @@
import sys
import time

import ray
from ray.tune.analysis import ExperimentAnalysis
from ray.tune.callback import Callback
from ray.tune.error import TuneError
@@ -111,6 +112,7 @@ def run(
sync_to_cloud: Optional = None,
sync_to_driver: Optional = None,
sync_on_checkpoint: Optional = None,
_remote: Optional[bool] = None,
) -> ExperimentAnalysis:
"""Executes training.
@@ -270,13 +272,74 @@
``ray.tune.callback.Callback`` class. If not passed,
`LoggerCallback` and `SyncerCallback` callbacks are automatically
added.
_remote (bool): Whether to run the Tune driver in a remote function.
Disabled automatically when a custom trial executor is passed in;
enabled by default in Ray client mode.
Returns:
ExperimentAnalysis: Object for experiment analysis.
Raises:
TuneError: If any trials failed and `raise_on_failed_trial` is True.
"""

if _remote is None:
_remote = ray.util.client.ray.is_connected()

if _remote is True and trial_executor:
raise ValueError("cannot use custom trial executor")

if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
_ray_auto_init()

if _remote:
return ray.get(
ray.remote(num_cpus=0)(run).remote(
run_or_experiment,
name,
metric,
mode,
stop,
time_budget_s,
config,
resources_per_trial,
num_samples,
local_dir,
search_alg,
scheduler,
keep_checkpoints_num,
checkpoint_score_attr,
checkpoint_freq,
checkpoint_at_end,
verbose,
progress_reporter,
log_to_file,
trial_name_creator,
trial_dirname_creator,
sync_config,
export_formats,
max_failures,
fail_fast,
restore,
server_port,
resume,
queue_trials,
reuse_actors,
trial_executor,
raise_on_failed_trial,
callbacks,
# Deprecated args
loggers,
ray_auto_init,
run_errored_only,
global_checkpoint_period,
with_server,
upload_dir,
sync_to_cloud,
sync_to_driver,
sync_on_checkpoint,
_remote=False))

all_start = time.time()
if global_checkpoint_period:
raise ValueError("global_checkpoint_period is deprecated. Set env var "
@@ -509,7 +572,8 @@ def run_experiments(
trial_executor: Optional[RayTrialExecutor] = None,
raise_on_failed_trial: bool = True,
concurrent: bool = True,
callbacks: Optional[Sequence[Callback]] = None):
callbacks: Optional[Sequence[Callback]] = None,
_remote: Optional[bool] = None):
"""Runs and blocks until all trials finish.
Examples:
@@ -523,6 +587,32 @@
List of Trial objects, holding data for each executed trial.
"""
if _remote is None:
_remote = ray.util.client.ray.is_connected()

if _remote is True and trial_executor:
raise ValueError("cannot use custom trial executor")

if not trial_executor or isinstance(trial_executor, RayTrialExecutor):
_ray_auto_init()

if _remote:
return ray.get(
ray.remote(num_cpus=0)(run_experiments).remote(
experiments,
scheduler,
server_port,
verbose,
progress_reporter,
resume,
queue_trials,
reuse_actors,
trial_executor,
raise_on_failed_trial,
concurrent,
callbacks,
_remote=False))

# It is important to do this here, because it converts the experiments
# to a canonical form and performs the implicit trainable registration.
@@ -557,3 +647,14 @@ def run_experiments(
scheduler=scheduler,
callbacks=callbacks).trials
return trials


def _ray_auto_init():
"""Initialize Ray unless already configured."""
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
elif not ray.is_initialized():
logger.info("Initializing Ray automatically."
"For cluster usage or custom Ray initialization, "
"call `ray.init(...)` before `tune.run`.")
ray.init()
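
Taken together, these changes let the entire Tune driver loop execute on the cluster instead of in the calling process. A usage sketch, assuming a local Ray installation (the trainable mirrors the ones in test_remote.py above):

import ray
from ray import tune

def train(config, reporter):
    for i in range(100):
        reporter(timesteps_total=i)

ray.init()
# Explicit opt-in: the driver runs inside a zero-CPU remote task, and
# ray.get blocks until every trial finishes.
analysis = tune.run(train, _remote=True)
print(analysis.trials[0].last_result["timesteps_total"])  # -> 99

Under Ray client, `_remote` defaults to `ray.util.client.ray.is_connected()`, so the same call with no flag dispatches remotely; the inner invocation receives `_remote=False`, which prevents infinite recursion.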
