Skip to content

Commit

Permalink
[tune] horovod trainable (ray-project#10304)
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw authored Sep 3, 2020
1 parent 7068c63 commit 43a7a64
Show file tree
Hide file tree
Showing 11 changed files with 485 additions and 2 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ matrix:
- os: linux
env:
- PYTHON=3.7 TUNE_TESTING=1
- INSTALL_HOROVOD=1
- TF_VERSION=2.1.0
- TFP_VERSION=0.8
- TORCH_VERSION=1.5
Expand Down
7 changes: 7 additions & 0 deletions ci/travis/install-dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,12 @@ install_dependencies() {
pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
fi

# Additional Tune dependency for Horovod.
if [ "${INSTALL_HOROVOD-}" = 1 ]; then
# TODO: eventually pin this to master.
HOROVOD_WITH_GLOO=1 HOROVOD_WITHOUT_MPI=1 HOROVOD_WITHOUT_MXNET=1 pip install -U git+https://github.com/horovod/horovod.git
fi

# Additional RaySGD test dependencies.
if [ "${SGD_TESTING-}" = 1 ]; then
pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
Expand All @@ -303,6 +309,7 @@ install_dependencies() {
pip install -r "${WORKSPACE_DIR}"/python/requirements_tune.txt
fi


# If CI has deemed that a different version of Tensorflow or Torch
# should be installed, then upgrade/downgrade to that specific version.
if [ -n "${TORCH_VERSION-}" ] || [ -n "${TFP_VERSION-}" ] || [ -n "${TF_VERSION-}" ]; then
Expand Down
2 changes: 2 additions & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __getattr__(cls, name):
"blist",
"gym",
"gym.spaces",
"horovod",
"horovod.ray",
"kubernetes",
"psutil",
"ray._raylet",
Expand Down
8 changes: 8 additions & 0 deletions doc/source/tune/api_docs/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ Torch (tune.integration.torch)

.. autofunction:: ray.tune.integration.torch.is_distributed_trainable


.. _tune-integration-horovod:

Horovod (tune.integration.horovod)
----------------------------------

.. autofunction:: ray.tune.integration.horovod.DistributedTrainableCreator

.. _tune-integration-wandb:

Weights and Biases (tune.integration.wandb)
Expand Down
6 changes: 6 additions & 0 deletions doc/source/tune/examples/horovod_simple.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
:orphan:

horovod_simple
~~~~~~~~~~~~~~

.. literalinclude:: /../../python/ray/tune/examples/horovod_simple.py
4 changes: 4 additions & 0 deletions doc/source/tune/examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ Tensorflow/Keras Examples
- :doc:`/tune/examples/pbt_memnn_example`: Example of training a Memory NN on bAbI with Keras using PBT.
- :doc:`/tune/examples/tf_mnist_example`: Converts the Advanced TF2.0 MNIST example to use Tune with the Trainable. This uses `tf.function`. Original code from tensorflow: https://www.tensorflow.org/tutorials/quickstart/advanced

Horovod Example
~~~~~~~~~~~~~~~
- :doc:`/tune/examples/horovod_simple`: Leverages the :ref:`Horovod-Tune <tune-integration-horovod>` integration to launch a distributed training + tuning job.


PyTorch Examples
~~~~~~~~~~~~~~~~
Expand Down
17 changes: 17 additions & 0 deletions python/ray/tune/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,14 @@ py_test(
deps = [":tune_lib"],
)

py_test(
name = "test_horovod",
size = "medium",
srcs = ["tests/test_horovod.py"],
tags = ["exclusive", "example", "py37"],
deps = [":tune_lib"],
)

py_test(
name = "ddp_mnist_torch",
size = "small",
Expand Down Expand Up @@ -364,6 +372,15 @@ py_test(
args = ["--smoke-test"]
)

py_test(
name = "horovod_simple",
size = "medium",
srcs = ["examples/horovod_simple.py"],
tags = ["exclusive", "example", "py37"],
deps = [":tune_lib"],
args = ["--smoke-test"]
)

py_test(
name = "hyperband_example",
size = "medium",
Expand Down
115 changes: 115 additions & 0 deletions python/ray/tune/examples/horovod_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import torch
import numpy as np
from ray import tune
from ray.tune.integration.horovod import DistributedTrainableCreator
import time


def sq(x):
    """Evaluate the target quadratic ``x**2 - 20*x + 50`` at ``x``.

    Works on anything supporting ``*`` and ``+`` with floats (plain
    numbers as well as tensors).
    """
    quadratic, linear, constant = 1., -20., 50.
    return quadratic * x * x + linear * x + constant


def qu(x):
    """Evaluate the target cubic ``10*x**3 + 5*x**2 - 20*x - 5`` at ``x``.

    Works on anything supporting ``*`` and ``+`` with floats (plain
    numbers as well as tensors).
    """
    cubic, quadratic, linear, constant = 10., 5., -20., -5.
    return cubic * x * x * x + quadratic * x * x + linear * x + constant


class Net(torch.nn.Module):
    """Tiny model whose parameters are the low-order polynomial coefficients.

    In square mode the model computes ``x^2 + p0*x + p1`` (two learnable
    coefficients); in any other mode it computes
    ``10*x^3 + p0*x^2 + p1*x + p2`` (three learnable coefficients).

    Args:
        mode: ``"square"`` (or the shorthand ``"sq"``, which is the
            default) selects the quadratic model; any other value selects
            the cubic model.
    """

    def __init__(self, mode="sq"):
        super(Net, self).__init__()

        # The default "sq" previously never matched the "square" check, so
        # Net() silently built the cubic parameter set. Accept both names.
        if mode in ("sq", "square"):
            self.mode = 0
            self.param = torch.nn.Parameter(torch.FloatTensor([1., -1.]))
        else:
            self.mode = 1
            self.param = torch.nn.Parameter(torch.FloatTensor([1., -1., 1.]))

    def forward(self, x):
        """Evaluate the polynomial selected at construction time at ``x``."""
        # Compare explicitly: ``~self.mode`` is a bitwise NOT on an int
        # (~0 == -1, ~1 == -2), which is truthy for both modes and made the
        # cubic branch unreachable.
        if self.mode == 0:
            return x * x + self.param[0] * x + self.param[1]

        return_val = 10 * x * x * x
        return_val += self.param[0] * x * x
        return_val += self.param[1] * x + self.param[2]
        return return_val


def train(config):
    """Run a short Horovod-wrapped SGD loop and report each step's loss to Tune.

    Reads the learning rate from ``config["lr"]``. The polynomial mode and
    the sampling range come from the module-level ``args`` namespace, which
    is only populated when this file is executed as a script.

    Args:
        config: Dict of hyperparameters; must contain the key ``"lr"``.
    """
    import torch
    import horovod.torch as hvd

    hvd.init()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    net = Net(args.mode).to(device)
    optimizer = hvd.DistributedOptimizer(
        torch.optim.SGD(net.parameters(), lr=config["lr"]))

    num_steps = 5
    print(hvd.size())
    # NumPy is seeded with a rank-dependent value, torch with a fixed one.
    np.random.seed(1 + hvd.rank())
    torch.manual_seed(1234)
    # To ensure consistent initialization across slots, broadcast the model
    # and optimizer state from rank 0.
    hvd.broadcast_parameters(net.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # Pick the target polynomial once; the mode does not change per step.
    target_fn = sq if args.mode == "square" else qu
    x_max = args.x_max

    start = time.time()
    for _ in range(num_steps):
        # One sample drawn from [-x_max, x_max).
        sample = np.random.rand(1) * 2 * x_max - x_max
        features = torch.Tensor(sample).to(device)
        labels = target_fn(features)

        optimizer.zero_grad()
        outputs = net(features)
        loss = torch.nn.MSELoss()(outputs, labels)
        loss.backward()
        optimizer.step()

        time.sleep(0.1)
        tune.report(loss=loss.item())
    total = time.time() - start
    print(f"Took {total:0.3f} s. Avg: {total / num_steps:0.3f} s.")


if __name__ == "__main__":
    import argparse

    # Command-line options. ``args`` stays a module-level name because
    # ``train`` reads ``args.mode`` and ``args.x_max`` from it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--mode",
        type=str,
        default="square",
        choices=["square", "cubic"],
    )
    cli.add_argument(
        "--learning_rate",
        type=float,
        default=0.1,
        dest="learning_rate",
    )
    cli.add_argument("--x_max", type=float, default=1., dest="x_max")
    cli.add_argument("--gpu", action="store_true")
    cli.add_argument(
        "--smoke-test",
        action="store_true",
        help="Finish quickly for testing.",
    )
    cli.add_argument("--hosts-per-trial", type=int, default=1)
    cli.add_argument("--slots-per-host", type=int, default=2)
    args = cli.parse_args()

    # To attach to an already-running cluster instead, enable:
    # import ray
    # ray.init(address="auto")  # assumes ray is started with ray up

    # Wrap ``train`` so each trial spans the requested hosts and slots.
    trainable = DistributedTrainableCreator(
        train,
        use_gpu=args.gpu,
        num_hosts=args.hosts_per_trial,
        num_slots=args.slots_per_host,
        replicate_pem=False,
    )

    # The smoke test samples fewer learning rates so it finishes quickly.
    sample_count = 2 if args.smoke_test else 10
    analysis = tune.run(
        trainable,
        config={"lr": tune.uniform(0.1, 1)},
        num_samples=sample_count,
        fail_fast=True,
    )
    config = analysis.get_best_config(metric="loss", mode="min")
Loading

0 comments on commit 43a7a64

Please sign in to comment.