[tune] Add MXNet Gluon example on CIFAR-10 (ray-project#4683)
cgraywang authored and richardliaw committed May 9, 2019
1 parent 481bfbd commit 584adb4
Showing 2 changed files with 226 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/ray/tune/examples/README.rst
@@ -60,3 +60,5 @@ Contributed Examples
A contributed example of tuning a Keras model on CIFAR10 with the PopulationBasedTraining scheduler.
- `genetic_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/genetic_example.py>`__:
Optimizing the michalewicz function using the contributed GeneticSearch search algorithm with AsyncHyperBandScheduler.
- `tune_cifar10_gluon <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/tune_cifar10_gluon.py>`__:
  An MXNet Gluon example that uses Tune's function-based API to tune a model on the CIFAR-10 dataset.
224 changes: 224 additions & 0 deletions python/ray/tune/examples/tune_cifar10_gluon.py
@@ -0,0 +1,224 @@
from __future__ import print_function

import argparse
import random

import mxnet as mx
import numpy as np

from mxnet import gluon, init
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms
from gluoncv.model_zoo import get_model
from gluoncv.data import transforms as gcv_transforms

# Training settings
parser = argparse.ArgumentParser(description="CIFAR-10 Example")
parser.add_argument(
    "--model",
    type=str,
    default="resnet50_v1b",
    help="name of the pretrained model from the gluoncv model zoo "
    "(default: resnet50_v1b).")
parser.add_argument(
"--batch_size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)")
parser.add_argument(
"--epochs",
type=int,
default=1,
metavar="N",
help="number of epochs to train (default: 1)")
parser.add_argument(
"--num_gpus",
default=0,
type=int,
help="number of gpus to use, 0 indicates cpu only (default: 0)")
parser.add_argument(
"--num_workers",
default=4,
type=int,
help="number of preprocessing workers (default: 4)")
parser.add_argument(
"--classes",
type=int,
default=10,
metavar="N",
help="number of outputs (default: 10)")
parser.add_argument(
"--lr",
default=0.001,
type=float,
help="initial learning rate (default: 0.001)")
parser.add_argument(
"--momentum",
default=0.9,
type=float,
help="initial momentum (default: 0.9)")
parser.add_argument(
"--wd", default=1e-4, type=float, help="weight decay (default: 1e-4)")
parser.add_argument(
"--expname", type=str, default="cifar10exp", help="experiments location")
parser.add_argument(
"--num_samples",
type=int,
default=20,
metavar="N",
help="number of samples (default: 20)")
parser.add_argument(
    "--scheduler",
    type=str,
    default="fifo",
    help='scheduler to use: "fifo" or "asynchyperband" (default: fifo)')
parser.add_argument(
"--seed",
type=int,
default=1,
metavar="S",
help="random seed (default: 1)")
parser.add_argument(
"--smoke_test", action="store_true", help="Finish quickly for testing")


def train_cifar10(args, config, reporter):
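    # Merge Tune's sampled hyperparameters over the CLI defaults.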
vars(args).update(config)
np.random.seed(args.seed)
random.seed(args.seed)
mx.random.seed(args.seed)

# Set Hyper-params
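    # The effective batch size scales with the number of devices; each
    # device processes args.batch_size samples per step.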
batch_size = args.batch_size * max(args.num_gpus, 1)
ctx = [mx.gpu(i)
for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]

# Define DataLoader
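    # Standard CIFAR-10 augmentation: pad-and-crop and random horizontal
    # flips for training, plus per-channel mean/std normalization.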
transform_train = transforms.Compose([
gcv_transforms.RandomCrop(32, pad=4),
transforms.RandomFlipLeftRight(),
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
])

transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.4914, 0.4822, 0.4465],
[0.2023, 0.1994, 0.2010])
])

train_data = gluon.data.DataLoader(
gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
batch_size=batch_size,
shuffle=True,
last_batch="discard",
num_workers=args.num_workers)

test_data = gluon.data.DataLoader(
gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
batch_size=batch_size,
shuffle=False,
num_workers=args.num_workers)

    # Load the pretrained model and reinitialize its output layer for CIFAR-10
finetune_net = get_model(args.model, pretrained=True)
with finetune_net.name_scope():
finetune_net.fc = nn.Dense(args.classes)
finetune_net.fc.initialize(init.Xavier(), ctx=ctx)
finetune_net.collect_params().reset_ctx(ctx)
finetune_net.hybridize()

# Define trainer
trainer = gluon.Trainer(finetune_net.collect_params(), "sgd", {
"learning_rate": args.lr,
"momentum": args.momentum,
"wd": args.wd
})
L = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

def train(epoch):
for i, batch in enumerate(train_data):
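            # Partition the batch across devices for data-parallel training.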
data = gluon.utils.split_and_load(
batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
label = gluon.utils.split_and_load(
batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
with ag.record():
outputs = [finetune_net(X) for X in data]
loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
for l in loss:
l.backward()

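            # Gluon's Trainer rescales the accumulated gradients by
            # 1 / batch_size before the update.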
trainer.step(batch_size)
mx.nd.waitall()

    def test():
        # Reset the accuracy metric so each epoch reports fresh numbers.
        metric.reset()
        test_loss = 0
for i, batch in enumerate(test_data):
data = gluon.utils.split_and_load(
batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
label = gluon.utils.split_and_load(
batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
outputs = [finetune_net(X) for X in data]
loss = [L(yhat, y) for yhat, y in zip(outputs, label)]

test_loss += sum(l.mean().asscalar() for l in loss) / len(loss)
metric.update(label, outputs)

_, test_acc = metric.get()
test_loss /= len(test_data)
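        # Report epoch metrics to Tune; schedulers and stop criteria use them.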
reporter(mean_loss=test_loss, mean_accuracy=test_acc)

for epoch in range(1, args.epochs + 1):
train(epoch)
test()


if __name__ == "__main__":
args = parser.parse_args()

import ray
from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler, FIFOScheduler

ray.init()
if args.scheduler == "fifo":
sched = FIFOScheduler()
elif args.scheduler == "asynchyperband":
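        # AsyncHyperBand stops underperforming trials early; reward_attr
        # "neg_mean_loss" tells it to maximize the negated reported loss.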
sched = AsyncHyperBandScheduler(
time_attr="training_iteration",
reward_attr="neg_mean_loss",
max_t=400,
grace_period=60)
else:
        raise NotImplementedError(
            "unsupported scheduler: {}".format(args.scheduler))
tune.register_trainable(
"TRAIN_FN",
lambda config, reporter: train_cifar10(args, config, reporter))
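    # The lambda closes over the parsed CLI args; Tune supplies config and
    # reporter for each trial.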
    tune.run(
        "TRAIN_FN",
        name=args.expname,
        verbose=2,
        scheduler=sched,
        stop={
            "mean_accuracy": 0.98,
            "training_iteration": 1 if args.smoke_test else args.epochs
        },
        resources_per_trial={
            "cpu": int(args.num_workers),
            "gpu": int(args.num_gpus)
        },
        num_samples=1 if args.smoke_test else args.num_samples,
        config={
            # Sample the learning rate log-uniformly from [1e-4, 1e-1].
            "lr": tune.sample_from(
                lambda spec: np.power(10.0, np.random.uniform(-4, -1))),
            "momentum": tune.sample_from(
                lambda spec: np.random.uniform(0.85, 0.95)),
        })
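
A minimal way to try the example locally (a sketch; it assumes Ray, MXNet, and
GluonCV are installed, and uses only flags defined by the parser above):

    # Quick CPU smoke test with the default FIFO scheduler:
    python tune_cifar10_gluon.py --model resnet50_v1b --smoke_test

    # Hyperparameter search with AsyncHyperBand on one GPU:
    python tune_cifar10_gluon.py --model resnet50_v1b \
        --scheduler asynchyperband --num_gpus 1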
