nn fixes (apache#6792)
* nn fixes

* restore contrib/autograd

* update examples and revert container changes

* fix resnet version

* fix benchmark
szha authored and piiswrong committed Jul 12, 2017
1 parent 944a725 commit a887e11
Showing 6 changed files with 80 additions and 29 deletions.
example/autograd/actor_critic.py (1 addition, 1 deletion)
@@ -9,7 +9,7 @@
 import mxnet.ndarray as F
 from mxnet import foo
 from mxnet.foo import nn
-from mxnet.contrib import autograd
+from mxnet import autograd


 parser = argparse.ArgumentParser(description='MXNet actor-critic example')
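The same one-line move, from mxnet.contrib.autograd to the top-level mxnet.autograd, is applied to every example file in this commit. For readers following along, a minimal sketch of how the relocated module is used; the shapes are made up for illustration, and record()/backward() are assumed to behave as in the MXNet releases of this period:

import mxnet as mx
from mxnet import autograd   # new location; was: from mxnet.contrib import autograd

x = mx.nd.random_normal(shape=(4, 10))
w = mx.nd.random_normal(shape=(10, 1))
w.attach_grad()              # allocate storage for w's gradient

with autograd.record():      # record operations for differentiation
    y = mx.nd.dot(x, w).sum()
y.backward()                 # populates w.grad

print(w.grad.shape)          # expected: (10, 1)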
example/autograd/dcgan.py (1 addition, 1 deletion)
@@ -2,7 +2,7 @@
 import mxnet as mx
 from mxnet import foo
 from mxnet.foo import nn
-from mxnet.contrib import autograd
+from mxnet import autograd
 from data import cifar10_iterator


example/autograd/mnist.py (1 addition, 1 deletion)
@@ -7,7 +7,7 @@
 from mxnet.foo import nn
 import numpy as np
 import logging
-from mxnet.contrib import autograd as ag
+from mxnet import autograd as ag
 logging.basicConfig(level=logging.DEBUG)

 # define network
example/autograd/resnet.py (75 additions, 24 deletions)
@@ -1,12 +1,33 @@
 from __future__ import division

-import time
+import argparse, time
+import logging
+logging.basicConfig(level=logging.INFO)

 import mxnet as mx
 from mxnet import foo
 from mxnet.foo import nn
-from mxnet.contrib import autograd as ag
+from mxnet import autograd as ag

 from data import *

+# CLI
+parser = argparse.ArgumentParser(description='Train a resnet model for image classification.')
+parser.add_argument('--dataset', type=str, default='dummy', help='dataset to use. options are mnist, cifar10, and dummy.')
+parser.add_argument('--batch_size', type=int, default=32, help='training batch size per device (CPU/GPU).')
+parser.add_argument('--resnet_version', type=int, default=1, help='version of resnet to use. options are 1 and 2. default is 1.')
+parser.add_argument('--resnet_layers', type=int, default=50, help='layers of resnet to use. options are 18, 50. default is 50.')
+parser.add_argument('--gpus', type=int, default=0, help='number of gpus to use.')
+parser.add_argument('--epochs', type=int, default=3, help='number of training epochs.')
+parser.add_argument('--lr', type=float, default=0.01, help='learning Rate. default is 0.01.')
+parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123.')
+parser.add_argument('--thumbnail', action='store_true', default=False, help='use thumbnail or not. default is false.')
+parser.add_argument('--benchmark', action='store_true', default=True, help='whether to run benchmark.')
+parser.add_argument('--symbolic', action='store_true', default=False, help='whether to train in symbolic way with module.')
+opt = parser.parse_args()
+
+print(opt)
+
 def conv3x3(filters, stride, in_filters):
     return nn.Conv2D(filters, kernel_size=3, strides=stride, padding=1,
                      use_bias=False, in_filters=in_filters)
@@ -248,17 +269,47 @@ def hybrid_forward(self, F, x):

         return x

+# construct net
+resnet_spec = { 18: ('basic_block', [2, 2, 2], [16, 16, 32, 64]),
+                34: ('basic_block', [3, 4, 6, 3], [16, 16, 32, 64]),
+                50: ('bottle_neck', [3, 4, 6, 3], [64, 256, 512, 1024, 2048]),
+                101: ('bottle_neck', [3, 4, 23, 3], [64, 256, 512, 1024, 2048]),
+                152: ('bottle_neck', [3, 8, 36, 3], [64, 256, 512, 1024, 2048]) }
+
+resnet_net_versions = [ResnetV1, ResnetV2]
+resnet_block_versions = [{'basic_block': BasicBlockV1, 'bottle_neck': BottleneckV1},
+                         {'basic_block': BasicBlockV2, 'bottle_neck': BottleneckV2}]
+
+def get_resnet(version, num_layers, classes, use_thumbnail):
+    block_type, layers, filters = resnet_spec[num_layers]
+    resnet = resnet_net_versions[version]
+    block = resnet_block_versions[version][block_type]
+    return resnet(block, classes, layers, filters, use_thumbnail)
+
+dataset_classes = {'mnist': 10, 'cifar10': 10, 'imagenet': 1000, 'dummy': 1000}
+
+batch_size, dataset, classes = opt.batch_size, opt.dataset, dataset_classes[opt.dataset]
+
+gpus, version = opt.gpus, opt.resnet_version-1
+
+if opt.benchmark:
+    batch_size = 32
+    dataset = 'dummy'
+    classes = 1000
+    version = 0
+
+
+net = get_resnet(version, opt.resnet_layers, classes, opt.thumbnail)
+
-def resnet18v2_cifar(classes):
-    return ResnetV2(BasicBlockV2, classes, [2, 2, 2], [16, 16, 32, 64], True)
-def resnet50v1_imagenet(classes):
-    return ResnetV1(BottleneckV1, classes, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], False)
-def resnet50v2_imagenet(classes):
-    return ResnetV2(BottleneckV2, classes, [3, 4, 6, 3], [64, 256, 512, 1024, 2048], False)
+batch_size *= max(1, gpus)

-net = resnet18v2_cifar(10)
-batch_size = 32*8
-train_data, val_data = cifar10_iterator(batch_size, (3, 32, 32))
+# get dataset iterators
+if dataset == 'mnist':
+    train_data, val_data = mnist_iterator(batch_size, (1, 32, 32))
+elif dataset == 'cifar10':
+    train_data, val_data = cifar10_iterator(batch_size, (3, 32, 32))
+elif dataset == 'dummy':
+    train_data, val_data = dummy_iterator(batch_size, (3, 224, 224))

 def test(ctx):
     metric = mx.metric.Accuracy()
@@ -270,7 +321,7 @@ def test(ctx):
         for x in data:
             outputs.append(net(x))
         metric.update(label, outputs)
-    print('validation acc: %s=%f'%metric.get())
+    logging.info('validation acc: %s=%f'%metric.get())


 def train(epoch, ctx):
@@ -299,24 +350,24 @@ def train(epoch, ctx):
                 loss.backward()
             trainer.step(batch.data[0].shape[0])
             metric.update(label, outputs)
-            print('speed: {} samples/s'.format(batch.data[0].shape[0]/(time.time()-btic)))
+            logging.info('speed: {} samples/s'.format(batch_size/(time.time()-btic)))
             btic = time.time()

         name, acc = metric.get()
         metric.reset()
-        print('training acc at epoch %d: %s=%f'%(i, name, acc))
-        print('time: %f'%(time.time()-tic))
+        logging.info('training acc at epoch %d: %s=%f'%(i, name, acc))
+        logging.info('time: %f'%(time.time()-tic))
         test(ctx)

     net.all_params().save('mnist.params')

 if __name__ == '__main__':
-    net.hybridize()
-    train(200, [mx.gpu(i) for i in range(2)])
-    import logging
-    logging.basicConfig(level=logging.DEBUG)
-    data = mx.sym.var('data')
-    out = net(data)
-    softmax = mx.sym.SoftmaxOutput(out, name='softmax')
-    mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(1)])
-    mod.fit(train_data, num_epoch=100, batch_end_callback = mx.callback.Speedometer(batch_size, 10))
+    if opt.symbolic:
+        data = mx.sym.var('data')
+        out = net(data)
+        softmax = mx.sym.SoftmaxOutput(out, name='softmax')
+        mod = mx.mod.Module(softmax, context=[mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()])
+        mod.fit(train_data, num_epoch=opt.epochs, batch_end_callback = mx.callback.Speedometer(batch_size, 1))
+    else:
+        net.hybridize()
+        train(opt.epochs, [mx.gpu(i) for i in range(gpus)] if gpus > 0 else [mx.cpu()])
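Two details of the new driver code above are easy to miss. First, --benchmark is declared with action='store_true' but default=True, so the benchmark overrides (dummy data, batch size 32, 1000 classes, ResNet v1) apply on every run as written. Second, --resnet_version is 1-based while the lookup lists are 0-based, hence opt.resnet_version-1, and --batch_size is per device and gets scaled by the GPU count. Setting the benchmark override aside, a hedged sketch of the option-to-network mapping, with values chosen purely for illustration:

# Illustrative only: what the script would compute for
#   --resnet_version 2 --resnet_layers 18 --gpus 2 --batch_size 32
version = 2 - 1                       # 0-based index -> ResnetV2
net = get_resnet(version, 18, classes=1000, use_thumbnail=False)
batch_size = 32 * max(1, 2)           # 64 samples per iteration across 2 GPUs
ctx = [mx.gpu(i) for i in range(2)]   # the script uses [mx.cpu()] when gpus == 0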
example/autograd/super_resolution.py (1 addition, 1 deletion)
@@ -7,7 +7,7 @@
 import mxnet.ndarray as F
 from mxnet import foo
 from mxnet.foo import nn
-from mxnet.contrib import autograd as ag
+from mxnet import autograd as ag
 from mxnet.test_utils import download
 from mxnet.image import CenterCropAug, ResizeAug
 from mxnet.io import PrefetchingIter
python/mxnet/foo/nn/layer.py (1 addition, 1 deletion)
@@ -238,7 +238,7 @@ def register_child(self, layer):

     def hybridize(self, active=True):
         super(HybridLayer, self).hybridize(active)
-        self._active = True
+        self._active = active

     def _get_graph(self, *args):
         if self._cached_graph:
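This one-word fix is behavioral: hybridize() previously forced self._active = True no matter what was passed, so hybridization could be switched on but never off again. A minimal sketch of the intended toggle (behavior as implied by this fix; net is any HybridLayer-based network such as the one built in resnet.py above):

net.hybridize()       # _active = True: forward pass may use the cached symbolic graph
net.hybridize(False)  # with the fix, _active = False: back to imperative execution
                      # (before the fix, this call still left _active = True)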
