Stochastic Depth Training (#3468)
* sanity check stochastic depth mnist

* a cifar10 example (not tested)

* add description for sd cifar

* add doc for sd module

* add a simple random number queue

* add final numbers
pluskid authored and piiswrong committed Oct 7, 2016
1 parent a5aeb0c commit d09c080
Showing 7 changed files with 477 additions and 15 deletions.
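Background: stochastic depth (Huang et al., 2016) trains very deep residual networks by randomly dropping whole residual blocks during training, passing those mini-batches through the skip connection only; at test time every block is kept and its residual output is scaled by the block's survival probability. The snippet below is only a plain-NumPy illustration of that per-block rule, not the StochasticDepthModule implementation this commit adds in sd_module.py; the function and argument names are made up for the illustration.

import numpy as np

def stochastic_depth_block(x, f_residual, f_skip, death_rate, is_train):
    # f_residual: the compute branch (in sd_cifar10.py: BN->ReLU->Conv->BN->ReLU->Conv)
    # f_skip:     None for identity, or a projection when input/output shapes differ
    # death_rate: probability of dropping the compute branch during training
    skip = x if f_skip is None else f_skip(x)
    if is_train:
        if np.random.rand() < death_rate:
            return skip                                # block is dropped for this mini-batch
        return skip + f_residual(x)                    # block survives: normal residual update
    # inference: keep every block, scaled by its survival probability
    return skip + (1.0 - death_rate) * f_residual(x)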
8 changes: 4 additions & 4 deletions example/image-classification/train_cifar10.py
@@ -50,8 +50,7 @@ def _download(data_dir):
 net = importlib.import_module('symbol_' + args.network).get_symbol(10)
 
 # data
-def get_iterator(args, kv):
-    data_shape = (3, 28, 28)
+def get_iterator(args, kv, data_shape=(3, 28, 28)):
     if '://' not in args.data_dir:
         _download(args.data_dir)
 
@@ -77,5 +76,6 @@ def get_iterator(args, kv):
 
     return (train, val)
 
-# train
-train_model.fit(args, net, get_iterator)
+if __name__ == '__main__':
+    # train
+    train_model.fit(args, net, get_iterator)
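These two changes let other scripts reuse this file as a library: data_shape becomes a keyword argument, and the training call moves under a __main__ guard so that importing the module no longer kicks off a CIFAR-10 run. A minimal consumer, mirroring how sd_cifar10.py below builds a throwaway args object before calling get_iterator (get_iterator will download CIFAR-10 into data_dir if it is not already there):

import os
import mxnet as mx
from train_cifar10 import get_iterator   # safe to import now: nothing is trained on import

args = type('', (), {})()                # stand-in for the argparse namespace
args.batch_size = 128
args.data_dir = os.path.abspath('cifar10') + '/'

kv = mx.kvstore.create('local')
train, val = get_iterator(args, kv)      # data_shape defaults to (3, 28, 28)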
8 changes: 4 additions & 4 deletions example/module/sequential_module.py
@@ -44,13 +44,13 @@
 n_epoch = 2
 batch_size = 100
 train_dataiter = mx.io.MNISTIter(
-        image="data/train-images-idx3-ubyte",
-        label="data/train-labels-idx1-ubyte",
+        image="../image-classification/mnist/train-images-idx3-ubyte",
+        label="../image-classification/mnist/train-labels-idx1-ubyte",
         data_shape=(784,),
         batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
 val_dataiter = mx.io.MNISTIter(
-        image="data/t10k-images-idx3-ubyte",
-        label="data/t10k-labels-idx1-ubyte",
+        image="../image-classification/mnist/t10k-images-idx3-ubyte",
+        label="../image-classification/mnist/t10k-labels-idx1-ubyte",
         data_shape=(784,),
         batch_size=batch_size, shuffle=True, flat=True, silent=False)

13 changes: 7 additions & 6 deletions example/speech-demo/io_util.py
@@ -68,19 +68,19 @@ def __init__(self, train_sets, batch_size,
self.data_name = data_name
if has_label:
self.label_name = label_name

features = []
labels = []
utt_lens = []
utt_ids = []
buckets = []
self.has_label = has_label

if label_mean_sets is not None:
self.label_mean_sets.initialize_read()
(feats, tgts, utt_id) = self.label_mean_sets.load_next_seq()

self.label_mean = feats/np.sum(feats)
for i,v in enumerate(feats):
if v <= 1.0:
self.label_mean[i] = 1
@@ -103,7 +103,7 @@ def __init__(self, train_sets, batch_size,
labels.append(tgts+1)
if feats.shape[0] not in buckets:
buckets_map[feats.shape[0]] = feats.shape[0]

for k, v in buckets_map.iteritems():
buckets.append(k)

@@ -116,13 +116,13 @@ def __init__(self, train_sets, batch_size,
self.utt_lens = [[] for k in buckets]
self.feat_dim = feat_dim
self.default_bucket_key = max(buckets)

for i, feats in enumerate(features):
if has_label:
tgts = labels[i]
utt_len = utt_lens[i]
utt_id = utt_ids[i]

for i, bkt in enumerate(buckets):
if bkt >= utt_len:
i_bucket = i
@@ -620,3 +620,4 @@ def __iter__(self):

def reset(self):
self.bucket_curr_idx = [0 for x in self.data]

201 changes: 201 additions & 0 deletions example/stochastic-depth/sd_cifar10.py
@@ -0,0 +1,201 @@
###########################################################################################
# Implementation of the stochastic depth algorithm described in the paper
#
# Huang, Gao, et al. "Deep networks with stochastic depth." arXiv preprint arXiv:1603.09382 (2016).
#
# A reference Torch implementation can be found at https://github.com/yueatsprograms/Stochastic_Depth
#
# There are some differences in this implementation:
# - A BN->ReLU->Conv block is used for the skip connection when the input and output shapes
#   are different, as opposed to a padding layer.
# - The residual block is different: we use BN->ReLU->Conv->BN->ReLU->Conv, as opposed to
#   Conv->BN->ReLU->Conv->BN (with ReLU also applied to the skip connection).
# - We did not try to match the initialization, learning rate scheduling, etc.
#
#--------------------------------------------------------------------------------
# A sample from the running log (we achieved ~9.4% error after 500 epochs; more careful
# tuning of the hyper-parameters, and possibly the architecture, would be needed to reach
# the numbers reported in the paper):
#
# INFO:root:Epoch[80] Batch [50] Speed: 1020.95 samples/sec Train-accuracy=0.910080
# INFO:root:Epoch[80] Batch [100] Speed: 1013.41 samples/sec Train-accuracy=0.912031
# INFO:root:Epoch[80] Batch [150] Speed: 1035.48 samples/sec Train-accuracy=0.913438
# INFO:root:Epoch[80] Batch [200] Speed: 1045.00 samples/sec Train-accuracy=0.907344
# INFO:root:Epoch[80] Batch [250] Speed: 1055.32 samples/sec Train-accuracy=0.905937
# INFO:root:Epoch[80] Batch [300] Speed: 1071.71 samples/sec Train-accuracy=0.912500
# INFO:root:Epoch[80] Batch [350] Speed: 1033.73 samples/sec Train-accuracy=0.910937
# INFO:root:Epoch[80] Train-accuracy=0.919922
# INFO:root:Epoch[80] Time cost=48.348
# INFO:root:Saved checkpoint to "sd-110-0081.params"
# INFO:root:Epoch[80] Validation-accuracy=0.880142
# ...
# INFO:root:Epoch[115] Batch [50] Speed: 1037.04 samples/sec Train-accuracy=0.937040
# INFO:root:Epoch[115] Batch [100] Speed: 1041.12 samples/sec Train-accuracy=0.934219
# INFO:root:Epoch[115] Batch [150] Speed: 1036.02 samples/sec Train-accuracy=0.933125
# INFO:root:Epoch[115] Batch [200] Speed: 1057.49 samples/sec Train-accuracy=0.938125
# INFO:root:Epoch[115] Batch [250] Speed: 1060.56 samples/sec Train-accuracy=0.933438
# INFO:root:Epoch[115] Batch [300] Speed: 1046.25 samples/sec Train-accuracy=0.935625
# INFO:root:Epoch[115] Batch [350] Speed: 1043.83 samples/sec Train-accuracy=0.927188
# INFO:root:Epoch[115] Train-accuracy=0.938477
# INFO:root:Epoch[115] Time cost=47.815
# INFO:root:Saved checkpoint to "sd-110-0116.params"
# INFO:root:Epoch[115] Validation-accuracy=0.884415
# ...
# INFO:root:Saved checkpoint to "sd-110-0499.params"
# INFO:root:Epoch[498] Validation-accuracy=0.908554
# INFO:root:Epoch[499] Batch [50] Speed: 1068.28 samples/sec Train-accuracy=0.991422
# INFO:root:Epoch[499] Batch [100] Speed: 1053.10 samples/sec Train-accuracy=0.991094
# INFO:root:Epoch[499] Batch [150] Speed: 1042.89 samples/sec Train-accuracy=0.995156
# INFO:root:Epoch[499] Batch [200] Speed: 1066.22 samples/sec Train-accuracy=0.991406
# INFO:root:Epoch[499] Batch [250] Speed: 1050.56 samples/sec Train-accuracy=0.990781
# INFO:root:Epoch[499] Batch [300] Speed: 1032.02 samples/sec Train-accuracy=0.992500
# INFO:root:Epoch[499] Batch [350] Speed: 1062.16 samples/sec Train-accuracy=0.992969
# INFO:root:Epoch[499] Train-accuracy=0.994141
# INFO:root:Epoch[499] Time cost=47.401
# INFO:root:Saved checkpoint to "sd-110-0500.params"
# INFO:root:Epoch[499] Validation-accuracy=0.906050
# ###########################################################################################

import os
import sys
import mxnet as mx
import logging

import sd_module

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification")))
from train_cifar10 import get_iterator


def residual_module(death_rate, n_channel, name_scope, context, stride=1, bn_momentum=0.9):
    data = mx.sym.Variable(name_scope + '_data')

    # computation branch:
    # BN -> ReLU -> Conv -> BN -> ReLU -> Conv
    bn1 = mx.symbol.BatchNorm(data=data, name=name_scope + '_bn1', fix_gamma=False,
                              momentum=bn_momentum,
                              # Same as https://github.com/soumith/cudnn.torch/blob/master/BatchNormalization.lua
                              # cuDNN v5 does not allow an eps as small as 1e-5
                              eps=2e-5
                              )
    relu1 = mx.symbol.Activation(data=bn1, act_type='relu', name=name_scope+'_relu1')
    conv1 = mx.symbol.Convolution(data=relu1, num_filter=n_channel, kernel=(3, 3), pad=(1,1),
                                  stride=(stride, stride), name=name_scope+'_conv1')
    bn2 = mx.symbol.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_momentum,
                              eps=2e-5, name=name_scope+'_bn2')
    relu2 = mx.symbol.Activation(data=bn2, act_type='relu', name=name_scope+'_relu2')
    conv2 = mx.symbol.Convolution(data=relu2, num_filter=n_channel, kernel=(3, 3), pad=(1,1),
                                  stride=(1, 1), name=name_scope+'_conv2')
    sym_compute = conv2

    # skip branch
    if stride > 1:
        sym_skip = mx.symbol.BatchNorm(data=data, fix_gamma=False, momentum=bn_momentum,
                                       eps=2e-5, name=name_scope+'_skip_bn')
        sym_skip = mx.symbol.Activation(data=sym_skip, act_type='relu', name=name_scope+'_skip_relu')
        sym_skip = mx.symbol.Convolution(data=sym_skip, num_filter=n_channel, kernel=(3, 3), pad=(1, 1),
                                         stride=(stride, stride), name=name_scope+'_skip_conv')
    else:
        sym_skip = None

    mod = sd_module.StochasticDepthModule(sym_compute, sym_skip, data_names=[name_scope+'_data'],
                                          context=context, death_rate=death_rate)
    return mod


#################################################################################
# Build architecture
# Configurations
bn_momentum = 0.9
contexts = [mx.context.gpu(i) for i in range(1)]
n_residual_blocks = 18
death_rate = 0.5
death_mode = 'linear_decay' # 'linear_decay' or 'uniform'

n_classes = 10

def get_death_rate(i_res_block):
    n_total_res_blocks = n_residual_blocks * 3
    if death_mode == 'linear_decay':
        my_death_rate = float(i_res_block) / n_total_res_blocks * death_rate
    else:
        my_death_rate = death_rate
    return my_death_rate

# 0. base ConvNet
sym_base = mx.sym.Variable('data')
sym_base = mx.sym.Convolution(data=sym_base, num_filter=16, kernel=(3, 3), pad=(1, 1), name='conv1')
sym_base = mx.sym.BatchNorm(data=sym_base, name='bn1', fix_gamma=False, momentum=bn_momentum, eps=2e-5)
sym_base = mx.sym.Activation(data=sym_base, name='relu1', act_type='relu')
mod_base = mx.mod.Module(sym_base, context=contexts, label_names=None)

# 1. container
mod_seq = mx.mod.SequentialModule()
mod_seq.add(mod_base)

# 2. first group, 16 x 28 x 28
i_res_block = 0
for i in range(n_residual_blocks):
    mod_seq.add(residual_module(get_death_rate(i_res_block), 16, 'res_A_%d' % i, contexts), auto_wiring=True)
    i_res_block += 1

# 3. second group, 32 x 14 x 14
mod_seq.add(residual_module(get_death_rate(i_res_block), 32, 'res_AB', contexts, stride=2), auto_wiring=True)
i_res_block += 1

for i in range(n_residual_blocks-1):
    mod_seq.add(residual_module(get_death_rate(i_res_block), 32, 'res_B_%d' % i, contexts), auto_wiring=True)
    i_res_block += 1

# 4. third group, 64 x 7 x 7
mod_seq.add(residual_module(get_death_rate(i_res_block), 64, 'res_BC', contexts, stride=2), auto_wiring=True)
i_res_block += 1

for i in range(n_residual_blocks-1):
    mod_seq.add(residual_module(get_death_rate(i_res_block), 64, 'res_C_%d' % i, contexts), auto_wiring=True)
    i_res_block += 1

# 5. final module
sym_final = mx.sym.Variable('data')
sym_final = mx.sym.Pooling(data=sym_final, kernel=(7, 7), pool_type='avg', name='global_pool')
sym_final = mx.sym.FullyConnected(data=sym_final, num_hidden=n_classes, name='logits')
sym_final = mx.sym.SoftmaxOutput(data=sym_final, name='softmax')
mod_final = mx.mod.Module(sym_final, context=contexts)
mod_seq.add(mod_final, auto_wiring=True, take_labels=True)


#################################################################################
# Training
num_examples = 60000
batch_size = 128
base_lr = 0.008
lr_factor = 0.5
lr_factor_epoch = 100
momentum = 0.9
weight_decay = 0.00001
kv_store = 'local'

initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
num_epochs = 500

epoch_size = num_examples / batch_size
lr_scheduler = mx.lr_scheduler.FactorScheduler(step=max(int(epoch_size * lr_factor_epoch), 1), factor=lr_factor)

batch_end_callbacks = [mx.callback.Speedometer(batch_size, 50)]
epoch_end_callbacks = [mx.callback.do_checkpoint('sd-%d' % (n_residual_blocks * 6 + 2))]


args = type('', (), {})()
args.batch_size = batch_size
args.data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification", "cifar10")) + '/'
kv = mx.kvstore.create(kv_store)
train, val = get_iterator(args, kv)

logging.basicConfig(level=logging.DEBUG)
mod_seq.fit(train, val,
            optimizer_params={'learning_rate': base_lr, 'momentum': momentum,
                              'lr_scheduler': lr_scheduler, 'wd': weight_decay},
            num_epoch=num_epochs, batch_end_callback=batch_end_callbacks,
            epoch_end_callback=epoch_end_callbacks,
            initializer=initializer)
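With the configuration above (n_residual_blocks = 18, i.e. 54 stochastic-depth blocks across the three groups, and death_rate = 0.5), the 'linear_decay' mode gives block i a death rate of i / 54 * 0.5, so the first block is never dropped and the deepest block is dropped almost half the time. A quick check of that schedule, repeating the arithmetic of get_death_rate above:

n_total_res_blocks = 18 * 3      # 54 residual blocks in total
max_death_rate = 0.5

rates = [float(i) / n_total_res_blocks * max_death_rate for i in range(n_total_res_blocks)]
print(rates[0])      # 0.0      -> the first block always survives
print(rates[27])     # 0.25     -> block halfway through the stack
print(rates[-1])     # ~0.4907  -> the deepest block is dropped about 49% of the time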

86 changes: 86 additions & 0 deletions example/stochastic-depth/sd_mnist.py
@@ -0,0 +1,86 @@
################################################################################
# A sanity check, mainly for debugging purposes. See sd_cifar10.py for a non-trivial
# example of stochastic depth on CIFAR-10.
################################################################################

import os
import sys
import mxnet as mx
import logging

import sd_module

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification")))
from train_mnist import get_iterator
from symbol_resnet import get_conv

death_rates = [0.3]
contexts = [mx.context.cpu()]

data = mx.symbol.Variable('data')
conv = get_conv(
    name='conv0',
    data=data,
    num_filter=16,
    kernel=(3, 3),
    stride=(1, 1),
    pad=(1, 1),
    with_relu=True,
    bn_momentum=0.9
)

base_mod = mx.mod.Module(conv, label_names=None, context=contexts)
mod_seq = mx.mod.SequentialModule()
mod_seq.add(base_mod)

for i in range(len(death_rates)):
    conv = get_conv(
        name='conv0_%d' % i,
        data=mx.sym.Variable('data_%d' % i),
        num_filter=16,
        kernel=(3, 3),
        stride=(1, 1),
        pad=(1, 1),
        with_relu=True,
        bn_momentum=0.9
    )
    conv = get_conv(
        name='conv1_%d' % i,
        data=conv,
        num_filter=16,
        kernel=(3, 3),
        stride=(1, 1),
        pad=(1, 1),
        with_relu=False,
        bn_momentum=0.9
    )
    mod = sd_module.StochasticDepthModule(conv, data_names=['data_%d' % i],
                                          context=contexts, death_rate=death_rates[i])
    mod_seq.add(mod, auto_wiring=True)

act = mx.sym.Activation(mx.sym.Variable('data_final'), act_type='relu')
flat = mx.sym.Flatten(act)
pred = mx.sym.FullyConnected(flat, num_hidden=10)
softmax = mx.sym.SoftmaxOutput(pred, name='softmax')
mod_seq.add(mx.mod.Module(softmax, context=contexts, data_names=['data_final']),
            auto_wiring=True, take_labels=True)


n_epoch = 2
batch_size = 100


train = mx.io.MNISTIter(
    image="../image-classification/mnist/train-images-idx3-ubyte",
    label="../image-classification/mnist/train-labels-idx1-ubyte",
    input_shape=(1, 28, 28), flat=False,
    batch_size=batch_size, shuffle=True, silent=False, seed=10)
val = mx.io.MNISTIter(
    image="../image-classification/mnist/t10k-images-idx3-ubyte",
    label="../image-classification/mnist/t10k-labels-idx1-ubyte",
    input_shape=(1, 28, 28), flat=False,
    batch_size=batch_size, shuffle=True, silent=False)

logging.basicConfig(level=logging.DEBUG)
mod_seq.fit(train, val, optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            num_epoch=n_epoch, batch_end_callback=mx.callback.Speedometer(batch_size, 10))
