diff --git a/.travis.yml b/.travis.yml
index af31662b650fa..4361f1cb4c680 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -376,6 +376,7 @@ matrix:
   - os: linux
     env:
       - PYTHON=3.7 TUNE_TESTING=1
+      - INSTALL_HOROVOD=1
       - TF_VERSION=2.1.0
       - TFP_VERSION=0.8
       - TORCH_VERSION=1.5
diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh
index 248e756f33b0c..b0d10be6c7a0b 100755
--- a/ci/travis/install-dependencies.sh
+++ b/ci/travis/install-dependencies.sh
@@ -293,6 +293,12 @@ install_dependencies() {
     pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
   fi

+  # Additional Tune dependency for Horovod.
+  if [ "${INSTALL_HOROVOD-}" = 1 ]; then
+    # TODO: eventually pin this to master.
+    HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git
+  fi
+
   # Additional RaySGD test dependencies.
   if [ "${SGD_TESTING-}" = 1 ]; then
     pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
@@ -303,6 +309,7 @@ install_dependencies() {
     pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
   fi

+
   # If CI has deemed that a different version of Tensorflow or Torch
   # should be installed, then upgrade/downgrade to that specific version.
   if [ -n "${TORCH_VERSION-}" ] || [ -n "${TFP_VERSION-}" ] || [ -n "${TF_VERSION-}" ]; then
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 9c8567f0cce90..a350a281013da 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -36,6 +36,8 @@ def __getattr__(cls, name):
     "blist",
     "gym",
     "gym.spaces",
+    "horovod",
+    "horovod.ray",
     "kubernetes",
     "psutil",
     "ray._raylet",
diff --git a/doc/source/tune/api_docs/integration.rst b/doc/source/tune/api_docs/integration.rst
index d03cba0cc8606..ef29d89b36516 100644
--- a/doc/source/tune/api_docs/integration.rst
+++ b/doc/source/tune/api_docs/integration.rst
@@ -43,6 +43,14 @@ Torch (tune.integration.torch)
 ------------------------------

 .. autofunction:: ray.tune.integration.torch.is_distributed_trainable
+
+.. _tune-integration-horovod:
+
+Horovod (tune.integration.horovod)
+----------------------------------
+
+.. autofunction:: ray.tune.integration.horovod.DistributedTrainableCreator
+
 .. _tune-integration-wandb:

 Weights and Biases (tune.integration.wandb)
diff --git a/doc/source/tune/examples/horovod_simple.rst b/doc/source/tune/examples/horovod_simple.rst
new file mode 100644
index 0000000000000..31d6fc66c7f38
--- /dev/null
+++ b/doc/source/tune/examples/horovod_simple.rst
@@ -0,0 +1,6 @@
+:orphan:
+
+horovod_simple
+~~~~~~~~~~~~~~
+
+.. literalinclude:: /../../python/ray/tune/examples/horovod_simple.py
\ No newline at end of file
diff --git a/doc/source/tune/examples/index.rst b/doc/source/tune/examples/index.rst
index 5c4f44c84b379..176c02686ca0f 100644
--- a/doc/source/tune/examples/index.rst
+++ b/doc/source/tune/examples/index.rst
@@ -35,6 +35,10 @@ Tensorflow/Keras Examples
 - :doc:`/tune/examples/pbt_memnn_example`: Example of training a Memory NN on bAbI with Keras using PBT.
 - :doc:`/tune/examples/tf_mnist_example`: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses `tf.function`. Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced

+Horovod Example
+~~~~~~~~~~~~~~~
+- :doc:`/tune/examples/horovod_simple`: Leverages the :ref:`Horovod-Tune <tune-integration-horovod>` integration to launch a distributed training + tuning job.
+
 PyTorch Examples
 ~~~~~~~~~~~~~~~~

diff --git a/python/ray/tune/BUILD b/python/ray/tune/BUILD
index 820ffd21baa50..af17d7c426cdf 100644
--- a/python/ray/tune/BUILD
+++ b/python/ray/tune/BUILD
@@ -327,6 +327,14 @@ py_test(
     deps = [":tune_lib"],
 )

+py_test(
+    name = "test_horovod",
+    size = "medium",
+    srcs = ["tests/test_horovod.py"],
+    tags = ["exclusive", "example", "py37"],
+    deps = [":tune_lib"],
+)
+
 py_test(
     name = "ddp_mnist_torch",
     size = "small",
@@ -364,6 +372,15 @@ py_test(
     args = ["--smoke-test"]
 )

+py_test(
+    name = "horovod_simple",
+    size = "medium",
+    srcs = ["examples/horovod_simple.py"],
+    tags = ["exclusive", "example", "py37"],
+    deps = [":tune_lib"],
+    args = ["--smoke-test"]
+)
+
 py_test(
     name = "hyperband_example",
     size = "medium",
diff --git a/python/ray/tune/examples/horovod_simple.py b/python/ray/tune/examples/horovod_simple.py
new file mode 100644
index 0000000000000..e566edf0f07c8
--- /dev/null
+++ b/python/ray/tune/examples/horovod_simple.py
@@ -0,0 +1,115 @@
+import torch
+import numpy as np
+from ray import tune
+from ray.tune.integration.horovod import DistributedTrainableCreator
+import time
+
+
+def sq(x):
+    m2 = 1.
+    m1 = -20.
+    m0 = 50.
+    return m2 * x * x + m1 * x + m0
+
+
+def qu(x):
+    m3 = 10.
+    m2 = 5.
+    m1 = -20.
+    m0 = -5.
+    return m3 * x * x * x + m2 * x * x + m1 * x + m0
+
+
+class Net(torch.nn.Module):
+    def __init__(self, mode="sq"):
+        super(Net, self).__init__()
+
+        if mode == "square":
+            self.mode = 0
+            self.param = torch.nn.Parameter(torch.FloatTensor([1., -1.]))
+        else:
+            self.mode = 1
+            self.param = torch.nn.Parameter(torch.FloatTensor([1., -1., 1.]))
+
+    def forward(self, x):
+        if self.mode == 0:  # "square" mode
+            return x * x + self.param[0] * x + self.param[1]
+        else:
+            return_val = 10 * x * x * x
+            return_val += self.param[0] * x * x
+            return_val += self.param[1] * x + self.param[2]
+            return return_val
+
+
+def train(config):
+    import torch
+    import horovod.torch as hvd
+    hvd.init()
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net = Net(args.mode).to(device)
+    optimizer = torch.optim.SGD(
+        net.parameters(),
+        lr=config["lr"],
+    )
+    optimizer = hvd.DistributedOptimizer(optimizer)
+
+    num_steps = 5
+    print(hvd.size())
+    np.random.seed(1 + hvd.rank())
+    torch.manual_seed(1234)
+    # To ensure consistent initialization across slots, broadcast from rank 0.
+    hvd.broadcast_parameters(net.state_dict(), root_rank=0)
+    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
+
+    start = time.time()
+    for step in range(1, num_steps + 1):
+        features = torch.Tensor(
+            np.random.rand(1) * 2 * args.x_max - args.x_max).to(device)
+        if args.mode == "square":
+            labels = sq(features)
+        else:
+            labels = qu(features)
+        optimizer.zero_grad()
+        outputs = net(features)
+        loss = torch.nn.MSELoss()(outputs, labels)
+        loss.backward()

+        optimizer.step()
+        time.sleep(0.1)
+        tune.report(loss=loss.item())
+    total = time.time() - start
+    print(f"Took {total:0.3f} s. Avg: {total / num_steps:0.3f} s.")
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--mode", type=str, default="square", choices=["square", "cubic"])
+    parser.add_argument(
+        "--learning_rate", type=float, default=0.1, dest="learning_rate")
+    parser.add_argument("--x_max", type=float, default=1., dest="x_max")
+    parser.add_argument("--gpu", action="store_true")
+    parser.add_argument(
+        "--smoke-test",
+        action="store_true",
+        help=("Finish quickly for testing."))
+    parser.add_argument("--hosts-per-trial", type=int, default=1)
+    parser.add_argument("--slots-per-host", type=int, default=2)
+    args = parser.parse_args()
+
+    # import ray
+    # ray.init(address="auto")  # assumes ray is started with ray up
+
+    horovod_trainable = DistributedTrainableCreator(
+        train,
+        use_gpu=args.gpu,
+        num_hosts=args.hosts_per_trial,
+        num_slots=args.slots_per_host,
+        replicate_pem=False)
+    analysis = tune.run(
+        horovod_trainable,
+        config={"lr": tune.uniform(0.1, 1)},
+        num_samples=2 if args.smoke_test else 10,
+        fail_fast=True)
+    config = analysis.get_best_config(metric="loss", mode="min")
diff --git a/python/ray/tune/integration/horovod.py b/python/ray/tune/integration/horovod.py
new file mode 100644
index 0000000000000..b718e4ff736c4
--- /dev/null
+++ b/python/ray/tune/integration/horovod.py
@@ -0,0 +1,226 @@
+import os
+import logging
+from filelock import FileLock
+
+import ray
+from ray import tune
+from ray.tune.resources import Resources
+from ray.tune.trainable import TrainableUtil
+from ray.tune.result import RESULT_DUPLICATE
+from ray.tune.logger import NoopLogger
+
+from ray.tune.function_runner import wrap_function
+from horovod.ray import RayExecutor
+
+logger = logging.getLogger(__name__)
+
+
+def get_rank():
+    return os.environ["HOROVOD_RANK"]
+
+
+def logger_creator(log_config, logdir):
+    """Simple NOOP logger for worker trainables."""
+    index = get_rank()
+    worker_dir = os.path.join(logdir, "worker_{}".format(index))
+    os.makedirs(worker_dir, exist_ok=True)
+    return NoopLogger(log_config, worker_dir)
+
+
+class _HorovodTrainable(tune.Trainable):
+    """Abstract Trainable class for Horovod."""
+    # Callable function for training.
+    _function = None
+    # Number of hosts (nodes) to allocate per trial.
+    _num_hosts: int = 1
+    # Number of workers (slots) to place on each host.
+    _num_slots: int = 1
+    # Number of CPU resources to reserve for each worker.
+    _num_cpus_per_slot: int = 1
+    # Whether to reserve and pass GPU resources through.
+    _use_gpu: bool = False
+    # Whether the function has completed training.
+    _finished: bool = False
+
+    # Horovod settings
+    _ssh_str: str = None
+    _ssh_identity_file: str = None
+    _timeout_s: int = 30
+
+    @property
+    def num_workers(self):
+        return self._num_hosts * self._num_slots
+
+    def setup(self, config):
+        trainable = wrap_function(self.__class__._function)
+        # We use a filelock here to ensure that the file-writing
+        # process is safe across different trainables.
+        if self._ssh_identity_file:
+            with FileLock(self._ssh_identity_file + ".lock"):
+                settings = RayExecutor.create_settings(
+                    self._timeout_s, self._ssh_identity_file, self._ssh_str)
+        else:
+            settings = RayExecutor.create_settings(
+                self._timeout_s, self._ssh_identity_file, self._ssh_str)
+
+        self.executor = RayExecutor(
+            settings,
+            cpus_per_slot=self._num_cpus_per_slot,
+            use_gpu=self._use_gpu,
+            num_hosts=self._num_hosts,
+            num_slots=self._num_slots)
+
+        # We can't put `self` in the lambda closure, so we
+        # resolve the variable ahead of time.
+        logdir_ = str(self.logdir)
+
+        # Start the workers as specified by the resources above.
+        self.executor.start(
+            executable_cls=trainable,
+            executable_kwargs={
+                "config": config,
+                "logger_creator": lambda cfg: logger_creator(cfg, logdir_)
+            })
+
+    def step(self):
+        if self._finished:
+            raise RuntimeError("Training has already finished.")
+        result = self.executor.execute(lambda w: w.step())[0]
+        if RESULT_DUPLICATE in result:
+            self._finished = True
+        return result
+
+    def save_checkpoint(self, checkpoint_dir):
+        # TODO: optimize if colocated
+        save_obj = self.executor.execute_single(lambda w: w.save_to_object())
+        checkpoint_path = TrainableUtil.create_from_pickle(
+            save_obj, checkpoint_dir)
+        return checkpoint_path
+
+    def load_checkpoint(self, checkpoint_dir):
+        checkpoint_obj = TrainableUtil.checkpoint_to_object(checkpoint_dir)
+        x_id = ray.put(checkpoint_obj)
+        return self.executor.execute(lambda w: w.restore_from_object(x_id))
+
+    def stop(self):
+        self.executor.execute(lambda w: w.stop())
+        self.executor.shutdown()
+
+
+def DistributedTrainableCreator(func,
+                                use_gpu=False,
+                                num_hosts=1,
+                                num_slots=1,
+                                num_cpus_per_slot=1,
+                                timeout_s=30,
+                                replicate_pem=False):
+    """Converts Horovod functions to be executable by Tune.
+
+    Requires Horovod > 0.19 to work.
+
+    This function wraps and sets the resources for a given Horovod
+    function to be used with Tune. It generates a Horovod Trainable (trial)
+    which can itself be a distributed training job. One basic assumption of
+    this implementation is that all sub-workers of a trial will be placed
+    evenly across different machines.
+
+    If `num_hosts` per trial is greater than 1, it is recommended to set
+    `num_slots` to the number of workers (or GPUs) of a single host.
+    If `num_hosts == 1`, then `num_slots` can be any value up to
+    the number of workers (or GPUs) of a single host.
+
+    The above assumption can be relaxed; please file a feature request
+    on GitHub to inform the maintainers.
+
+    Note that this API requires Gloo as the underlying
+    communication primitive. You will need to install Horovod with
+    `HOROVOD_WITH_GLOO` enabled.
+
+    *Fault Tolerance:* The trial workers themselves are not fault tolerant.
+    When a host of a trial fails, all workers of a trial are expected to
+    die, and the trial is expected to restart. This currently does not
+    support function checkpointing.
+
+    Args:
+        func (Callable[[dict], None]): A training function that takes in
+            a config dict for hyperparameters and should initialize
+            Horovod via `hvd.init()`.
+        use_gpu (bool): Whether to allocate a GPU per worker.
+        num_cpus_per_slot (int): Number of CPUs to request
+            from Ray per worker.
+        num_hosts (int): Number of hosts that each trial is expected
+            to use.
+        num_slots (int): Number of slots (workers) to start on each host.
+        timeout_s (int): Seconds for Horovod rendezvous to timeout.
+        replicate_pem (bool): THIS MAY BE INSECURE. If True, this will
+            replicate the underlying Ray cluster ssh key across all hosts.
+            This may be useful if using the Ray Autoscaler.
+
+
+    Returns:
+        Trainable class that can be passed into `tune.run`.
+
+    Example:
+
+    .. code-block:: python
+
+        def train(config):
+            import horovod.torch as hvd
+            hvd.init()
+
+        from ray.tune.integration.horovod import DistributedTrainableCreator
+        trainable_cls = DistributedTrainableCreator(
+            train, num_hosts=1, num_slots=2, use_gpu=True)
+
+        tune.run(trainable_cls)
+
+    .. versionadded:: 1.0.0
+    """
+    ssh_identity_file = None
+    sshkeystr = None
+
+    if replicate_pem:
+        from ray.tune.cluster_info import get_ssh_key
+        ssh_identity_file = get_ssh_key()
+        if os.path.exists(ssh_identity_file):
+            # For now, we assume that you're on a Ray cluster.
+            with open(ssh_identity_file) as f:
+                sshkeystr = f.read()
+
+    class WrappedHorovodTrainable(_HorovodTrainable):
+        _function = func
+        _num_hosts = num_hosts
+        _num_slots = num_slots
+        _num_cpus_per_slot = num_cpus_per_slot
+        _use_gpu = use_gpu
+        _ssh_identity_file = ssh_identity_file
+        _ssh_str = sshkeystr
+        _timeout_s = timeout_s
+
+        @classmethod
+        def default_resource_request(cls, config):
+            extra_gpu = int(num_hosts * num_slots) * int(use_gpu)
+            extra_cpu = int(num_hosts * num_slots * num_cpus_per_slot)
+
+            return Resources(
+                cpu=0,
+                gpu=0,
+                extra_cpu=extra_cpu,
+                extra_gpu=extra_gpu,
+            )
+
+    return WrappedHorovodTrainable
+
+
+# pytest presents a bunch of serialization problems
+# that force us to include mocks as part of the module.
+
+
+def _train_simple(config):
+    import horovod.torch as hvd
+    hvd.init()
+    from ray import tune
+    for i in range(config.get("epochs", 2)):
+        import time
+        time.sleep(1)
+        tune.report(test=1, rank=hvd.rank())
diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py
index 230929ff11a3b..d9d233a047bbe 100644
--- a/python/ray/tune/integration/pytorch_lightning.py
+++ b/python/ray/tune/integration/pytorch_lightning.py
@@ -235,8 +235,8 @@ class TuneReportCheckpointCallback(TuneCallback):
     .. code-block:: python

         import pytorch_lightning as pl
-        from ray.tune.integration.pytorch_lightning import \
-            TuneReportCheckpointCallback
+        from ray.tune.integration.pytorch_lightning import (
+            TuneReportCheckpointCallback)

         # Save checkpoint after each training batch and after each
         # validation epoch.
diff --git a/python/ray/tune/tests/test_horovod.py b/python/ray/tune/tests/test_horovod.py
new file mode 100644
index 0000000000000..fc71fc399c073
--- /dev/null
+++ b/python/ray/tune/tests/test_horovod.py
@@ -0,0 +1,97 @@
+import pytest
+
+import ray
+from ray import tune
+
+pytest.importorskip("horovod")
+
+try:
+    from ray.tune.integration.horovod import (DistributedTrainableCreator,
+                                               _train_simple)
+except ImportError:
+    pass  # This shouldn't be reached - the test should be skipped.
+
+
+@pytest.fixture
+def ray_start_2_cpus():
+    address_info = ray.init(num_cpus=2)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+@pytest.fixture
+def ray_start_4_cpus():
+    address_info = ray.init(num_cpus=4)
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+@pytest.fixture
+def ray_connect_cluster():
+    try:
+        address_info = ray.init(address="auto")
+    except Exception as e:
+        pytest.skip(str(e))
+    yield address_info
+    # The code after the yield will run as teardown code.
+    ray.shutdown()
+
+
+def test_single_step(ray_start_2_cpus):
+    trainable_cls = DistributedTrainableCreator(
+        _train_simple, num_hosts=1, num_slots=2)
+    trainer = trainable_cls()
+    trainer.train()
+    trainer.stop()
+
+
+def test_step_after_completion(ray_start_2_cpus):
+    trainable_cls = DistributedTrainableCreator(
+        _train_simple, num_hosts=1, num_slots=2)
+    trainer = trainable_cls(config={"epochs": 1})
+    with pytest.raises(RuntimeError):
+        for i in range(10):
+            trainer.train()
+
+
+def test_validation(ray_start_2_cpus):
+    def bad_func(a, b, c):
+        return 1
+
+    t_cls = DistributedTrainableCreator(bad_func, num_slots=2)
+    with pytest.raises(ValueError):
+        t_cls()
+
+
+def test_set_global(ray_start_2_cpus):
+    trainable_cls = DistributedTrainableCreator(_train_simple, num_slots=2)
+    trainable = trainable_cls()
+    result = trainable.train()
+    trainable.stop()
+    assert result["rank"] == 0
+
+
+def test_simple_tune(ray_start_4_cpus):
+    trainable_cls = DistributedTrainableCreator(_train_simple, num_slots=2)
+    analysis = tune.run(
+        trainable_cls, num_samples=2, stop={"training_iteration": 2})
+    assert analysis.trials[0].last_result["training_iteration"] == 2
+
+
+@pytest.mark.parametrize("use_gpu", [True, False])
+def test_resource_tune(ray_connect_cluster, use_gpu):
+    if use_gpu and ray.cluster_resources().get("GPU", 0) == 0:
+        pytest.skip("No GPU available.")
+    trainable_cls = DistributedTrainableCreator(
+        _train_simple, num_slots=2, use_gpu=use_gpu)
+    analysis = tune.run(
+        trainable_cls, num_samples=2, stop={"training_iteration": 2})
+    assert analysis.trials[0].last_result["training_iteration"] == 2
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))
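
For reference, here is a minimal end-to-end usage sketch of the integration introduced by this diff (not part of the patch itself). It assumes Horovod >= 0.19 built with Gloo support (`HOROVOD_WITH_GLOO=1`) and the `ray.tune.integration.horovod` module added above; the linear model, learning-rate range, and sample counts are illustrative only and follow the same pattern as `_train_simple` and `examples/horovod_simple.py`.

.. code-block:: python

    import torch
    import horovod.torch as hvd

    from ray import tune
    from ray.tune.integration.horovod import DistributedTrainableCreator


    def train(config):
        # Each Tune trial launches num_hosts * num_slots Horovod workers;
        # every worker runs this function and reports back to the same trial.
        hvd.init()
        model = torch.nn.Linear(1, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
        optimizer = hvd.DistributedOptimizer(optimizer)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)

        for _ in range(5):
            features = torch.randn(8, 1)
            labels = 3 * features
            optimizer.zero_grad()
            loss = torch.nn.functional.mse_loss(model(features), labels)
            loss.backward()
            optimizer.step()
            tune.report(loss=loss.item())


    if __name__ == "__main__":
        # Two workers (slots) on a single host; pass use_gpu=True to pin GPUs.
        trainable_cls = DistributedTrainableCreator(
            train, num_hosts=1, num_slots=2)
        analysis = tune.run(
            trainable_cls,
            config={"lr": tune.uniform(0.01, 0.1)},
            num_samples=4)
        print("Best config:",
              analysis.get_best_config(metric="loss", mode="min"))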