[tune] extend PTL template (GPU, typing fixes, tensorboard) #9451

Merged: 8 commits, Jul 15, 2020
62 changes: 53 additions & 9 deletions doc/source/tune/_tutorials/tune-pytorch-lightning.rst
@@ -102,14 +102,22 @@ The callback just reports some metrics back to Tune after each validation epoch:
:start-after: __tune_callback_begin__
:end-before: __tune_callback_end__

Note that we have to explicitly convert the metrics from a tensor to a Python value.

Adding the Tune training function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Then we specify our training function. Note that we added the ``data_dir`` as a config
parameter here, even though it should not be tuned. We just need to specify it to avoid
Then we specify our training function. Note that we added ``data_dir`` as a
parameter here so that each training run does not have to download
the full MNIST dataset. Instead, we want to access
a shared data location.

We are also able to specify the number of epochs to train each model, and the number
of GPUs we want to use for training. We also create a TensorBoard logger that writes
logfiles directly into Tune's root trial directory. If we didn't do that, PyTorch
Lightning would create subdirectories, and each trial would be shown twice in
TensorBoard: once for Tune's logs and once for PyTorch Lightning's logs.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
:language: python
:start-after: __tune_train_begin__
@@ -134,7 +142,7 @@ We also delete this data after training to avoid filling up our disk or memory s
:language: python
:start-after: __tune_asha_begin__
:end-before: __tune_asha_end__
:lines: 27
:lines: 36
:dedent: 4
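
The shared-directory pattern used here (download once, let every trial read from the same location, delete the data afterwards) can be sketched with the standard library alone. This is only a sketch under assumptions, not the tutorial's code: ``run_with_shared_data_dir``, ``fake_trials`` and ``dataset.bin`` are hypothetical names, and writing 16 bytes stands in for the real download. Only ``mkdtemp`` and ``shutil.rmtree`` are taken from the example script.

```python
import shutil
import tempfile
from pathlib import Path


def run_with_shared_data_dir(run_trials):
    """Create one temporary data directory, share it, and always clean it up."""
    data_dir = tempfile.mkdtemp(prefix="mnist_data_")
    try:
        # Hypothetical stand-in for downloading the dataset exactly once.
        (Path(data_dir) / "dataset.bin").write_bytes(b"\x00" * 16)
        return run_trials(data_dir), data_dir
    finally:
        # Runs even if a trial raises, so the directory is never left behind.
        shutil.rmtree(data_dir, ignore_errors=True)


def fake_trials(data_dir):
    # Every "trial" reads from the same location instead of downloading again.
    return [(Path(data_dir) / "dataset.bin").stat().st_size for _ in range(3)]


sizes, used_dir = run_with_shared_data_dir(fake_trials)
print(sizes)
```

Wrapping the cleanup in ``try``/``finally`` is a design choice of this sketch: the example script calls ``shutil.rmtree`` after ``tune.run`` returns, which cleans up on the normal path only.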

Configuring the search space
@@ -150,7 +158,7 @@ we are able to also sample small values.
:language: python
:start-after: __tune_asha_begin__
:end-before: __tune_asha_end__
:lines: 4-10
:lines: 5-10
:dedent: 4
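
To illustrate why a log-uniform distribution also yields small values for a parameter such as the learning rate, here is a standard-library sketch. It assumes the usual definition (the logarithm of the value is uniformly distributed) and is not Tune's implementation; ``sample_loguniform`` is a hypothetical helper, and the bounds mirror the ``1e-4`` to ``1e-1`` range used for ``lr`` in the example script.

```python
import math
import random


def sample_loguniform(low, high, rng=random):
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_loguniform(1e-4, 1e-1, rng) for _ in range(10_000)]

# Each decade ([1e-4, 1e-3), [1e-3, 1e-2), [1e-2, 1e-1]) receives roughly
# one third of the draws, so small learning rates are well represented.
share_smallest_decade = sum(s < 1e-3 for s in samples) / len(samples)
print(round(share_smallest_decade, 2))
```

With a plain uniform distribution over the same range, only about 1% of the draws would fall below ``1e-3``.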

Selecting a scheduler
@@ -165,25 +173,61 @@ configurations.
:language: python
:start-after: __tune_asha_begin__
:end-before: __tune_asha_end__
:lines: 11-16
:lines: 12-17
:dedent: 4


Changing the CLI output
~~~~~~~~~~~~~~~~~~~~~~~

We instantiate a ``CLIReporter`` to specify which metrics we would like to see in our
output tables in the command line. If we didn't specify this, Tune would print all
hyperparameters by default, but since ``data_dir`` is not a real hyperparameter, we
can avoid printing it by omitting it in the ``parameter_columns`` parameter.
output tables in the command line. This is optional, but can be used to make sure our
output tables only include information we would like to see.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
:language: python
:start-after: __tune_asha_begin__
:end-before: __tune_asha_end__
:lines: 19-21
:dedent: 4

Passing constants to the train function
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``data_dir``, ``num_epochs`` and ``num_gpus`` we pass to the training function
are constants. To avoid including them as non-configurable parameters in the ``config``
specification, we can use ``functools.partial`` to bind them to the training function.

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
:language: python
:start-after: __tune_asha_begin__
:end-before: __tune_asha_end__
:lines: 17-19
:lines: 24-28
:dedent: 8
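
The binding of constants with ``functools.partial`` can be tried out without Ray installed. In this minimal sketch, ``train_fn`` is a hypothetical stand-in that has the same signature as the tutorial's training function but only echoes its arguments, and the path ``/tmp/mnist_data`` is an arbitrary placeholder.

```python
from functools import partial


def train_fn(config, data_dir=None, num_epochs=10, num_gpus=0):
    """Hypothetical stand-in for a Tune training function."""
    return {
        "lr": config["lr"],
        "data_dir": data_dir,
        "num_epochs": num_epochs,
        "num_gpus": num_gpus,
    }


# Bind the constants once; only ``config`` is left for the caller to supply.
bound_fn = partial(train_fn, data_dir="/tmp/mnist_data", num_epochs=3, num_gpus=0)

# The caller (Tune, in the tutorial) passes nothing but the sampled config.
config = {"lr": 1e-3}
result = bound_fn(config)
print(result["num_epochs"], result["data_dir"])
```

The ``config`` dictionary stays free of entries that are not hyperparameters, which is the reason the tutorial gives for this approach.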

Training with GPUs
~~~~~~~~~~~~~~~~~~
We can specify how many resources Tune should request for each trial.
This also includes GPUs.

PyTorch Lightning takes care of moving the training to the GPUs. We
already made sure that our code is compatible with that, so there's
nothing more to do here other than to specify the number of GPUs
we would like to use:

.. literalinclude:: /../../python/ray/tune/examples/mnist_pytorch_lightning.py
:language: python
:start-after: __tune_asha_begin__
:end-before: __tune_asha_end__
:lines: 29
:dedent: 4

Please note that in the current state of PyTorch Lightning, training
on :doc:`fractional GPUs </using-ray-with-gpus>` or
multiple GPUs requires some workarounds. We will address these in a
separate tutorial. For now, this example works with either no GPU or
exactly one GPU.

Putting it together
~~~~~~~~~~~~~~~~~~~

86 changes: 60 additions & 26 deletions python/ray/tune/examples/mnist_pytorch_lightning.py
@@ -13,8 +13,10 @@

# __import_tune_begin__
import shutil
from functools import partial
from tempfile import mkdtemp
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from ray import tune
from ray.tune import CLIReporter
@@ -74,7 +76,7 @@ def training_step(self, train_batch, batch_idx):
loss = self.cross_entropy_loss(logits, y)
accuracy = self.accuracy(logits, y)

logs = {"train_loss": loss, "train_accuracy": accuracy}
logs = {"ptl/train_loss": loss, "ptl/train_accuracy": accuracy}
return {"loss": loss, "log": logs}

def validation_step(self, val_batch, batch_idx):
@@ -88,12 +90,12 @@ def validation_step(self, val_batch, batch_idx):
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
tensorboard_logs = {"val_loss": avg_loss, "val_accuracy": avg_acc}
logs = {"ptl/val_loss": avg_loss, "ptl/val_accuracy": avg_acc}

return {
"avg_val_loss": avg_loss,
"avg_val_accuracy": avg_acc,
"log": tensorboard_logs
"log": logs
}

@staticmethod
@@ -133,16 +135,19 @@ def train_mnist(config):
class TuneReportCallback(Callback):
def on_validation_end(self, trainer, pl_module):
tune.report(
loss=trainer.callback_metrics["avg_val_loss"],
mean_accuracy=trainer.callback_metrics["avg_val_accuracy"])
loss=trainer.callback_metrics["avg_val_loss"].item(),
mean_accuracy=trainer.callback_metrics["avg_val_accuracy"].item())
# __tune_callback_end__


# __tune_train_begin__
def train_mnist_tune(config):
model = LightningMNISTClassifier(config, config["data_dir"])
def train_mnist_tune(config, data_dir=None, num_epochs=10, num_gpus=0):
model = LightningMNISTClassifier(config, data_dir)
trainer = pl.Trainer(
max_epochs=10,
max_epochs=num_epochs,
gpus=num_gpus,
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
callbacks=[TuneReportCallback()])

@@ -160,9 +165,17 @@ def on_validation_end(self, trainer, pl_module):


# __tune_train_checkpoint_begin__
def train_mnist_tune_checkpoint(config, checkpoint=None):
def train_mnist_tune_checkpoint(
config,
checkpoint=None,
data_dir=None,
num_epochs=10,
num_gpus=0):
trainer = pl.Trainer(
max_epochs=10,
max_epochs=num_epochs,
gpus=num_gpus,
logger=TensorBoardLogger(
save_dir=tune.get_trial_dir(), name="", version="."),
progress_bar_refresh_rate=0,
callbacks=[CheckpointCallback(),
TuneReportCallback()])
@@ -178,54 +191,64 @@ def train_mnist_tune_checkpoint(config, checkpoint=None):
trainer.current_epoch = ckpt["epoch"]
else:
model = LightningMNISTClassifier(
config=config, data_dir=config["data_dir"])
config=config, data_dir=data_dir)

trainer.fit(model)
# __tune_train_checkpoint_end__


# __tune_asha_begin__
def tune_mnist_asha(num_samples=10, max_num_epochs=10):
def tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0):
data_dir = mkdtemp(prefix="mnist_data_")
LightningMNISTClassifier.download_data(data_dir)

config = {
"layer_1_size": tune.choice([32, 64, 128]),
"layer_2_size": tune.choice([64, 128, 256]),
"lr": tune.loguniform(1e-4, 1e-1),
"batch_size": tune.choice([32, 64, 128]),
"data_dir": data_dir
}

scheduler = ASHAScheduler(
metric="loss",
mode="min",
max_t=max_num_epochs,
max_t=num_epochs,
grace_period=1,
reduction_factor=2)

reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])

tune.run(
train_mnist_tune,
resources_per_trial={"cpu": 1},
partial(
train_mnist_tune,
data_dir=data_dir,
num_epochs=num_epochs,
num_gpus=gpus_per_trial),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter)
progress_reporter=reporter,
name="tune_mnist_asha")

shutil.rmtree(data_dir)
# __tune_asha_end__


# __tune_pbt_begin__
def tune_mnist_pbt():
def tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0):
data_dir = mkdtemp(prefix="mnist_data_")
LightningMNISTClassifier.download_data(data_dir)

config = {
"layer_1_size": tune.choice([32, 64, 128]),
"layer_2_size": tune.choice([64, 128, 256]),
"lr": 1e-3,
"batch_size": 64,
"data_dir": data_dir
}

scheduler = PopulationBasedTraining(
time_attr="training_iteration",
metric="loss",
@@ -235,16 +258,24 @@ def tune_mnist_pbt():
"lr": lambda: tune.loguniform(1e-4, 1e-1).func(None),
"batch_size": [32, 64, 128]
})

reporter = CLIReporter(
parameter_columns=["layer_1_size", "layer_2_size", "lr", "batch_size"],
metric_columns=["loss", "mean_accuracy", "training_iteration"])

tune.run(
train_mnist_tune_checkpoint,
resources_per_trial={"cpu": 1},
partial(
train_mnist_tune_checkpoint,
data_dir=data_dir,
num_epochs=num_epochs,
num_gpus=gpus_per_trial),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
num_samples=10,
num_samples=num_samples,
scheduler=scheduler,
progress_reporter=reporter)
progress_reporter=reporter,
name="tune_mnist_pbt")

shutil.rmtree(data_dir)
# __tune_pbt_end__

@@ -258,7 +289,10 @@ def tune_mnist_pbt():
args, _ = parser.parse_known_args()

if args.smoke_test:
tune_mnist_asha(1, 1)
tune_mnist_asha(num_samples=1, num_epochs=1, gpus_per_trial=0)
tune_mnist_pbt(num_samples=1, num_epochs=1, gpus_per_trial=0)
else:
tune_mnist_asha() # ASHA scheduler
tune_mnist_pbt() # population based training
# ASHA scheduler
tune_mnist_asha(num_samples=10, num_epochs=10, gpus_per_trial=0)
# Population based training
tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0)
10 changes: 6 additions & 4 deletions python/ray/tune/sample.py
@@ -1,4 +1,6 @@
import logging
import random

import numpy as np

logger = logging.getLogger(__name__)
@@ -56,13 +58,13 @@ def apply_log(_):


def choice(*args, **kwargs):
"""Wraps tune.sample_from around ``np.random.choice``.
"""Wraps tune.sample_from around ``random.choice``.

``tune.choice(10)`` is equivalent to
``tune.sample_from(lambda _: np.random.choice(10))``
``tune.choice([1, 2])`` is equivalent to
``tune.sample_from(lambda _: random.choice([1, 2]))``

"""
return sample_from(lambda _: np.random.choice(*args, **kwargs))
return sample_from(lambda _: random.choice(*args, **kwargs))
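
The switch from ``np.random.choice`` to ``random.choice`` changes what callers may pass and what they get back. The standard-library sketch below shows two such differences; it uses a seeded ``random.Random`` instance only to keep the run reproducible, which is an assumption of this example and not something ``tune.choice`` does. Whether these differences motivated the change is not stated in the diff.

```python
import random

rng = random.Random(0)
options = [32, 64, 128]

picked = rng.choice(options)
# The element is returned unchanged, so it keeps its built-in Python type
# (``np.random.choice`` would return a NumPy scalar for this list instead).
print(type(picked).__name__)

# ``np.random.choice(10)`` samples from ``range(10)``, but the standard
# library version needs a sequence and rejects a bare integer.
try:
    rng.choice(10)
    int_accepted = True
except TypeError:
    int_accepted = False
print(int_accepted)
```

This matches the updated docstring, which now uses ``tune.choice([1, 2])`` as its example instead of ``tune.choice(10)``.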


def randint(*args, **kwargs):