[tune] Add MXNet Gluon example on CIFAR-10 (ray-project#4683)
1 parent 481bfbd · commit 584adb4
Showing 2 changed files with 226 additions and 0 deletions.
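In brief: the new example fine-tunes a pretrained GluonCV model on CIFAR-10 and uses Ray Tune to search over the learning rate and momentum, under either a FIFO or an AsyncHyperBand scheduler.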
@@ -0,0 +1,224 @@
from __future__ import print_function

import argparse
import random

import mxnet as mx
import numpy as np

from mxnet import gluon, init
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms
from gluoncv.model_zoo import get_model
from gluoncv.data import transforms as gcv_transforms
# Training settings
parser = argparse.ArgumentParser(description="CIFAR-10 Example")
parser.add_argument(
    "--model",
    type=str,
    default="resnet50_v1b",
    help="name of the pretrained model from gluoncv model zoo "
    "(default: resnet50_v1b).")
parser.add_argument(
    "--batch_size",
    type=int,
    default=64,
    metavar="N",
    help="input batch size for training (default: 64)")
parser.add_argument(
    "--epochs",
    type=int,
    default=1,
    metavar="N",
    help="number of epochs to train (default: 1)")
parser.add_argument(
    "--num_gpus",
    default=0,
    type=int,
    help="number of gpus to use, 0 indicates cpu only (default: 0)")
parser.add_argument(
    "--num_workers",
    default=4,
    type=int,
    help="number of preprocessing workers (default: 4)")
parser.add_argument(
    "--classes",
    type=int,
    default=10,
    metavar="N",
    help="number of outputs (default: 10)")
parser.add_argument(
    "--lr",
    default=0.001,
    type=float,
    help="initial learning rate (default: 0.001)")
parser.add_argument(
    "--momentum",
    default=0.9,
    type=float,
    help="initial momentum (default: 0.9)")
parser.add_argument(
    "--wd", default=1e-4, type=float, help="weight decay (default: 1e-4)")
parser.add_argument(
    "--expname", type=str, default="cifar10exp", help="experiments location")
parser.add_argument(
    "--num_samples",
    type=int,
    default=20,
    metavar="N",
    help="number of samples (default: 20)")
parser.add_argument(
    "--scheduler",
    type=str,
    default="fifo",
    help="FIFO or AsyncHyperBandScheduler.")
parser.add_argument(
    "--seed",
    type=int,
    default=1,
    metavar="S",
    help="random seed (default: 1)")
parser.add_argument(
    "--smoke_test", action="store_true", help="Finish quickly for testing")
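# Note: the defaults for --lr and --momentum are only fallbacks; under Tune,
# each trial's sampled config overwrites them inside train_cifar10() via
# vars(args).update(config).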


def train_cifar10(args, config, reporter):
    vars(args).update(config)
    np.random.seed(args.seed)
    random.seed(args.seed)
    mx.random.seed(args.seed)

    # Set Hyper-params
    batch_size = args.batch_size * max(args.num_gpus, 1)
    ctx = [mx.gpu(i)
           for i in range(args.num_gpus)] if args.num_gpus > 0 else [mx.cpu()]
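    # One context per GPU (or a single CPU context); the effective batch size
    # scales with the device count, and each batch is split across these
    # contexts by gluon.utils.split_and_load below.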

    # Define DataLoader
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])
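    # The normalization constants are the commonly used per-channel mean/std
    # statistics of the CIFAR-10 training set.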

    train_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=True).transform_first(
            transform_train),
        batch_size=batch_size,
        shuffle=True,
        last_batch="discard",
        num_workers=args.num_workers)
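    # last_batch="discard" drops the final incomplete batch, so every
    # training step sees a full batch of batch_size samples.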

    test_data = gluon.data.DataLoader(
        gluon.data.vision.CIFAR10(train=False).transform_first(
            transform_test),
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.num_workers)

    # Load the model architecture and initialize the net from the pretrained
    # weights
    finetune_net = get_model(args.model, pretrained=True)
    with finetune_net.name_scope():
        finetune_net.fc = nn.Dense(args.classes)
    finetune_net.fc.initialize(init.Xavier(), ctx=ctx)
    finetune_net.collect_params().reset_ctx(ctx)
    finetune_net.hybridize()
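    # Fine-tuning setup: the pretrained backbone weights stay as loaded; only
    # the replacement classification head (fc) is Xavier-initialized.
    # hybridize() lets Gluon cache a static computation graph for speed.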

    # Define trainer
    trainer = gluon.Trainer(finetune_net.collect_params(), "sgd", {
        "learning_rate": args.lr,
        "momentum": args.momentum,
        "wd": args.wd
    })
    L = gluon.loss.SoftmaxCrossEntropyLoss()
    metric = mx.metric.Accuracy()
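    # The trainer updates all parameters (backbone and new head) with SGD;
    # L is used by both closures below, while metric accumulates the
    # test-set accuracy.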

    def train(epoch):
        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
            with ag.record():
                outputs = [finetune_net(X) for X in data]
                loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
            for l in loss:
                l.backward()

            trainer.step(batch_size)
        mx.nd.waitall()
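        # mx.nd.waitall() blocks until MXNet's asynchronous engine has
        # finished all pending work, so the epoch's updates are complete
        # before evaluation begins.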

    def test():
        test_loss = 0
        metric.reset()  # do not carry accuracy over from earlier epochs
        for i, batch in enumerate(test_data):
            data = gluon.utils.split_and_load(
                batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
            label = gluon.utils.split_and_load(
                batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
            outputs = [finetune_net(X) for X in data]
            loss = [L(yhat, y) for yhat, y in zip(outputs, label)]

            test_loss += sum(l.mean().asscalar() for l in loss) / len(loss)
            metric.update(label, outputs)

        _, test_acc = metric.get()
        test_loss /= len(test_data)
        reporter(mean_loss=test_loss, mean_accuracy=test_acc)
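        # reporter() hands the epoch's metrics back to Tune: mean_accuracy
        # drives the stop criterion below, and mean_loss (negated) serves as
        # the scheduler's reward.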

    for epoch in range(1, args.epochs + 1):
        train(epoch)
        test()


if __name__ == "__main__":
    args = parser.parse_args()

    import ray
    from ray import tune
    from ray.tune.schedulers import AsyncHyperBandScheduler, FIFOScheduler

    ray.init()
    if args.scheduler == "fifo":
        sched = FIFOScheduler()
    elif args.scheduler == "asynchyperband":
        sched = AsyncHyperBandScheduler(
            time_attr="training_iteration",
            reward_attr="neg_mean_loss",
            max_t=400,
            grace_period=60)
    else:
        raise NotImplementedError
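    # AsyncHyperBand measures time in training iterations, i.e. reporter()
    # calls (one per epoch here): a trial runs at least grace_period=60
    # iterations before it can be stopped early, and at most max_t=400.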
    tune.register_trainable(
        "TRAIN_FN",
        lambda config, reporter: train_cifar10(args, config, reporter))
    tune.run(
        "TRAIN_FN",
        name=args.expname,
        verbose=2,
        scheduler=sched,
        **{
            "stop": {
                "mean_accuracy": 0.98,
                "training_iteration": 1 if args.smoke_test else args.epochs
            },
            "resources_per_trial": {
                "cpu": int(args.num_workers),
                "gpu": int(args.num_gpus)
            },
            "num_samples": 1 if args.smoke_test else args.num_samples,
            "config": {
                "lr": tune.sample_from(
                    lambda spec: np.power(10.0, np.random.uniform(-4, -1))),
                "momentum": tune.sample_from(
                    lambda spec: np.random.uniform(0.85, 0.95)),
            }
        })
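To try the example end to end (a sketch: the script's file name is not shown on this page, so tune_cifar10_gluon.py is an assumed name), a CPU-only smoke test would be

    python tune_cifar10_gluon.py --model resnet18_v1b --smoke_test

This runs a single one-epoch trial. Dropping --smoke_test and passing --scheduler asynchyperband launches the full 20-sample search over lr and momentum.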