From d09c080ad9ad0a73e55f0f694a70e397ff96c91b Mon Sep 17 00:00:00 2001
From: Chiyuan Zhang
Date: Thu, 6 Oct 2016 23:04:07 -0400
Subject: [PATCH] Stochastic Depth Training (#3468)

* sanity check stochastic depth mnist
* a cifar10 example (not tested)
* add description for sd cifar
* add doc for sd module
* add a simple random number queue
* add final numbers
---
 example/image-classification/train_cifar10.py |   8 +-
 example/module/sequential_module.py           |   8 +-
 example/speech-demo/io_util.py                |  13 +-
 example/stochastic-depth/sd_cifar10.py        | 201 ++++++++++++++++++
 example/stochastic-depth/sd_mnist.py          |  86 ++++++++
 example/stochastic-depth/sd_module.py         | 172 +++++++++++++++
 python/mxnet/model.py                         |   4 +-
 7 files changed, 477 insertions(+), 15 deletions(-)
 create mode 100644 example/stochastic-depth/sd_cifar10.py
 create mode 100644 example/stochastic-depth/sd_mnist.py
 create mode 100644 example/stochastic-depth/sd_module.py

diff --git a/example/image-classification/train_cifar10.py b/example/image-classification/train_cifar10.py
index dc3580cd..b0c1c284 100644
--- a/example/image-classification/train_cifar10.py
+++ b/example/image-classification/train_cifar10.py
@@ -50,8 +50,7 @@ def _download(data_dir):
 net = importlib.import_module('symbol_' + args.network).get_symbol(10)
 
 # data
-def get_iterator(args, kv):
-    data_shape = (3, 28, 28)
+def get_iterator(args, kv, data_shape=(3, 28, 28)):
     if '://' not in args.data_dir:
         _download(args.data_dir)
 
@@ -77,5 +76,6 @@ def get_iterator(args, kv):
 
     return (train, val)
 
-# train
-train_model.fit(args, net, get_iterator)
+if __name__ == '__main__':
+    # train
+    train_model.fit(args, net, get_iterator)
diff --git a/example/module/sequential_module.py b/example/module/sequential_module.py
index def0558d..bc567af3 100644
--- a/example/module/sequential_module.py
+++ b/example/module/sequential_module.py
@@ -44,13 +44,13 @@ n_epoch = 2
 batch_size = 100
 
 train_dataiter = mx.io.MNISTIter(
-    image="data/train-images-idx3-ubyte",
-    label="data/train-labels-idx1-ubyte",
+    image="../image-classification/mnist/train-images-idx3-ubyte",
+    label="../image-classification/mnist/train-labels-idx1-ubyte",
     data_shape=(784,),
     batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
 val_dataiter = mx.io.MNISTIter(
-    image="data/t10k-images-idx3-ubyte",
-    label="data/t10k-labels-idx1-ubyte",
+    image="../image-classification/mnist/t10k-images-idx3-ubyte",
+    label="../image-classification/mnist/t10k-labels-idx1-ubyte",
     data_shape=(784,),
     batch_size=batch_size, shuffle=True, flat=True, silent=False)
diff --git a/example/speech-demo/io_util.py b/example/speech-demo/io_util.py
index 5ef8bf21..6def02dc 100644
--- a/example/speech-demo/io_util.py
+++ b/example/speech-demo/io_util.py
@@ -68,19 +68,19 @@ def __init__(self, train_sets, batch_size,
         self.data_name = data_name
         if has_label:
             self.label_name = label_name
-
+
         features = []
         labels = []
         utt_lens = []
         utt_ids = []
         buckets = []
         self.has_label = has_label
-
+
         if label_mean_sets is not None:
             self.label_mean_sets.initialize_read()
             (feats, tgts, utt_id) = self.label_mean_sets.load_next_seq()
-            self.label_mean = feats/np.sum(feats)
+            self.label_mean = feats/np.sum(feats)
             for i,v in enumerate(feats):
                 if v <= 1.0:
                     self.label_mean[i] = 1
@@ -103,7 +103,7 @@ def __init__(self, train_sets, batch_size,
                 labels.append(tgts+1)
             if feats.shape[0] not in buckets:
                 buckets_map[feats.shape[0]] = feats.shape[0]
-
+
         for k, v in buckets_map.iteritems():
             buckets.append(k)
 
@@ -116,13 +116,13 @@ def __init__(self, train_sets, batch_size,
         self.utt_lens = [[] for k in buckets]
         self.feat_dim = feat_dim
         self.default_bucket_key = max(buckets)
-
+
         for i, feats in enumerate(features):
             if has_label:
                 tgts = labels[i]
             utt_len = utt_lens[i]
             utt_id = utt_ids[i]
-
+
             for i, bkt in enumerate(buckets):
                 if bkt >= utt_len:
                     i_bucket = i
@@ -620,3 +620,4 @@ def __iter__(self):
 
     def reset(self):
         self.bucket_curr_idx = [0 for x in self.data]
+
diff --git a/example/stochastic-depth/sd_cifar10.py b/example/stochastic-depth/sd_cifar10.py
new file mode 100644
index 00000000..995601d4
--- /dev/null
+++ b/example/stochastic-depth/sd_cifar10.py
@@ -0,0 +1,201 @@
+###########################################################################################
+# Implementation of the stochastic depth algorithm described in the paper
+#
+# Huang, Gao, et al. "Deep networks with stochastic depth." arXiv preprint arXiv:1603.09382 (2016).
+#
+# A reference Torch implementation can be found at https://github.com/yueatsprograms/Stochastic_Depth
+#
+# There are some differences in this implementation:
+# - A BN->ReLU->Conv is used for the skip connection when the input and output shapes are
+#   different, as opposed to a padding layer.
+# - The residual block is different: we use BN->ReLU->Conv->BN->ReLU->Conv, as opposed to
+#   Conv->BN->ReLU->Conv->BN (with ->ReLU also applied to the skip connection).
+# - We did not try to match the original initialization, learning rate scheduling, etc.
+#
+#--------------------------------------------------------------------------------
+# A sample from the running log (we achieved ~9.4% error after 500 epochs; more careful
+# tuning of the hyper-parameters, and possibly of the architecture, is needed to reach
+# the numbers reported in the paper):
+#
+# INFO:root:Epoch[80] Batch [50] Speed: 1020.95 samples/sec Train-accuracy=0.910080
+# INFO:root:Epoch[80] Batch [100] Speed: 1013.41 samples/sec Train-accuracy=0.912031
+# INFO:root:Epoch[80] Batch [150] Speed: 1035.48 samples/sec Train-accuracy=0.913438
+# INFO:root:Epoch[80] Batch [200] Speed: 1045.00 samples/sec Train-accuracy=0.907344
+# INFO:root:Epoch[80] Batch [250] Speed: 1055.32 samples/sec Train-accuracy=0.905937
+# INFO:root:Epoch[80] Batch [300] Speed: 1071.71 samples/sec Train-accuracy=0.912500
+# INFO:root:Epoch[80] Batch [350] Speed: 1033.73 samples/sec Train-accuracy=0.910937
+# INFO:root:Epoch[80] Train-accuracy=0.919922
+# INFO:root:Epoch[80] Time cost=48.348
+# INFO:root:Saved checkpoint to "sd-110-0081.params"
+# INFO:root:Epoch[80] Validation-accuracy=0.880142
+# ...
+# INFO:root:Epoch[115] Batch [50] Speed: 1037.04 samples/sec Train-accuracy=0.937040
+# INFO:root:Epoch[115] Batch [100] Speed: 1041.12 samples/sec Train-accuracy=0.934219
+# INFO:root:Epoch[115] Batch [150] Speed: 1036.02 samples/sec Train-accuracy=0.933125
+# INFO:root:Epoch[115] Batch [200] Speed: 1057.49 samples/sec Train-accuracy=0.938125
+# INFO:root:Epoch[115] Batch [250] Speed: 1060.56 samples/sec Train-accuracy=0.933438
+# INFO:root:Epoch[115] Batch [300] Speed: 1046.25 samples/sec Train-accuracy=0.935625
+# INFO:root:Epoch[115] Batch [350] Speed: 1043.83 samples/sec Train-accuracy=0.927188
+# INFO:root:Epoch[115] Train-accuracy=0.938477
+# INFO:root:Epoch[115] Time cost=47.815
+# INFO:root:Saved checkpoint to "sd-110-0116.params"
+# INFO:root:Epoch[115] Validation-accuracy=0.884415
+# ...
+# INFO:root:Saved checkpoint to "sd-110-0499.params"
+# INFO:root:Epoch[498] Validation-accuracy=0.908554
+# INFO:root:Epoch[499] Batch [50] Speed: 1068.28 samples/sec Train-accuracy=0.991422
+# INFO:root:Epoch[499] Batch [100] Speed: 1053.10 samples/sec Train-accuracy=0.991094
+# INFO:root:Epoch[499] Batch [150] Speed: 1042.89 samples/sec Train-accuracy=0.995156
+# INFO:root:Epoch[499] Batch [200] Speed: 1066.22 samples/sec Train-accuracy=0.991406
+# INFO:root:Epoch[499] Batch [250] Speed: 1050.56 samples/sec Train-accuracy=0.990781
+# INFO:root:Epoch[499] Batch [300] Speed: 1032.02 samples/sec Train-accuracy=0.992500
+# INFO:root:Epoch[499] Batch [350] Speed: 1062.16 samples/sec Train-accuracy=0.992969
+# INFO:root:Epoch[499] Train-accuracy=0.994141
+# INFO:root:Epoch[499] Time cost=47.401
+# INFO:root:Saved checkpoint to "sd-110-0500.params"
+# INFO:root:Epoch[499] Validation-accuracy=0.906050
+# ###########################################################################################
+
+import os
+import sys
+import mxnet as mx
+import logging
+
+import sd_module
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification")))
+from train_cifar10 import get_iterator
+
+
+def residual_module(death_rate, n_channel, name_scope, context, stride=1, bn_momentum=0.9):
+    data = mx.sym.Variable(name_scope + '_data')
+
+    # computation branch:
+    #   BN -> ReLU -> Conv -> BN -> ReLU -> Conv
+    bn1 = mx.symbol.BatchNorm(data=data, name=name_scope + '_bn1', fix_gamma=False,
+                              momentum=bn_momentum,
+                              # Same as https://github.com/soumith/cudnn.torch/blob/master/BatchNormalization.lua
+                              # cuDNN v5 does not allow an eps as small as 1e-5
+                              eps=2e-5
+                              )
+    relu1 = mx.symbol.Activation(data=bn1, act_type='relu', name=name_scope+'_relu1')
+    conv1 = mx.symbol.Convolution(data=relu1, num_filter=n_channel, kernel=(3, 3), pad=(1, 1),
+                                  stride=(stride, stride), name=name_scope+'_conv1')
+    bn2 = mx.symbol.BatchNorm(data=conv1, fix_gamma=False, momentum=bn_momentum,
+                              eps=2e-5, name=name_scope+'_bn2')
+    relu2 = mx.symbol.Activation(data=bn2, act_type='relu', name=name_scope+'_relu2')
+    conv2 = mx.symbol.Convolution(data=relu2, num_filter=n_channel, kernel=(3, 3), pad=(1, 1),
+                                  stride=(1, 1), name=name_scope+'_conv2')
+    sym_compute = conv2
+
+    # skip branch
+    if stride > 1:
+        sym_skip = mx.symbol.BatchNorm(data=data, fix_gamma=False, momentum=bn_momentum,
+                                       eps=2e-5, name=name_scope+'_skip_bn')
+        sym_skip = mx.symbol.Activation(data=sym_skip, act_type='relu', name=name_scope+'_skip_relu')
+        sym_skip = mx.symbol.Convolution(data=sym_skip, num_filter=n_channel, kernel=(3, 3), pad=(1, 1),
+                                         stride=(stride, stride), name=name_scope+'_skip_conv')
+    else:
+        sym_skip = None
+
+    mod = sd_module.StochasticDepthModule(sym_compute, sym_skip, data_names=[name_scope+'_data'],
+                                          context=context, death_rate=death_rate)
+    return mod
+
+
+#################################################################################
+# Build architecture
+# Configurations
+bn_momentum = 0.9
+contexts = [mx.context.gpu(i) for i in range(1)]
+n_residual_blocks = 18
+death_rate = 0.5
+death_mode = 'linear_decay'  # 'linear_decay' or 'uniform'
+
+n_classes = 10
+
+
+def get_death_rate(i_res_block):
+    n_total_res_blocks = n_residual_blocks * 3
+    if death_mode == 'linear_decay':
+        my_death_rate = float(i_res_block) / n_total_res_blocks * death_rate
+    else:
+        my_death_rate = death_rate
+    return my_death_rate
+
+
+# 0. base ConvNet
+sym_base = mx.sym.Variable('data')
+sym_base = mx.sym.Convolution(data=sym_base, num_filter=16, kernel=(3, 3), pad=(1, 1), name='conv1')
+sym_base = mx.sym.BatchNorm(data=sym_base, name='bn1', fix_gamma=False, momentum=bn_momentum, eps=2e-5)
+sym_base = mx.sym.Activation(data=sym_base, name='relu1', act_type='relu')
+mod_base = mx.mod.Module(sym_base, context=contexts, label_names=None)
+
+# 1. container
+mod_seq = mx.mod.SequentialModule()
+mod_seq.add(mod_base)
+
+# 2. first group, 16 x 28 x 28
+i_res_block = 0
+for i in range(n_residual_blocks):
+    mod_seq.add(residual_module(get_death_rate(i_res_block), 16, 'res_A_%d' % i, contexts), auto_wiring=True)
+    i_res_block += 1
+
+# 3. second group, 32 x 14 x 14
+mod_seq.add(residual_module(get_death_rate(i_res_block), 32, 'res_AB', contexts, stride=2), auto_wiring=True)
+i_res_block += 1
+
+for i in range(n_residual_blocks-1):
+    mod_seq.add(residual_module(get_death_rate(i_res_block), 32, 'res_B_%d' % i, contexts), auto_wiring=True)
+    i_res_block += 1
+
+# 4. third group, 64 x 7 x 7
+mod_seq.add(residual_module(get_death_rate(i_res_block), 64, 'res_BC', contexts, stride=2), auto_wiring=True)
+i_res_block += 1
+
+for i in range(n_residual_blocks-1):
+    mod_seq.add(residual_module(get_death_rate(i_res_block), 64, 'res_C_%d' % i, contexts), auto_wiring=True)
+    i_res_block += 1
+
+# 5. final module
+sym_final = mx.sym.Variable('data')
+sym_final = mx.sym.Pooling(data=sym_final, kernel=(7, 7), pool_type='avg', name='global_pool')
+sym_final = mx.sym.FullyConnected(data=sym_final, num_hidden=n_classes, name='logits')
+sym_final = mx.sym.SoftmaxOutput(data=sym_final, name='softmax')
+mod_final = mx.mod.Module(sym_final, context=contexts)
+mod_seq.add(mod_final, auto_wiring=True, take_labels=True)
+
+
+#################################################################################
+# Training
+num_examples = 60000
+batch_size = 128
+base_lr = 0.008
+lr_factor = 0.5
+lr_factor_epoch = 100
+momentum = 0.9
+weight_decay = 0.00001
+kv_store = 'local'
+
+initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
+num_epochs = 500
+
+epoch_size = num_examples / batch_size
+lr_scheduler = mx.lr_scheduler.FactorScheduler(step=max(int(epoch_size * lr_factor_epoch), 1), factor=lr_factor)
+
+batch_end_callbacks = [mx.callback.Speedometer(batch_size, 50)]
+epoch_end_callbacks = [mx.callback.do_checkpoint('sd-%d' % (n_residual_blocks * 6 + 2))]
+
+
+args = type('', (), {})()
+args.batch_size = batch_size
+args.data_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification", "cifar10")) + '/'
+kv = mx.kvstore.create(kv_store)
+train, val = get_iterator(args, kv)
+
+logging.basicConfig(level=logging.DEBUG)
+mod_seq.fit(train, val,
+            optimizer_params={'learning_rate': base_lr, 'momentum': momentum,
+                              'lr_scheduler': lr_scheduler, 'wd': weight_decay},
+            num_epoch=num_epochs, batch_end_callback=batch_end_callbacks,
+            epoch_end_callback=epoch_end_callbacks,
+            initializer=initializer)
+
diff --git a/example/stochastic-depth/sd_mnist.py b/example/stochastic-depth/sd_mnist.py
new file mode 100644
index 00000000..66529a27
--- /dev/null
+++ b/example/stochastic-depth/sd_mnist.py
@@ -0,0 +1,86 @@
+################################################################################
+# A sanity check, mainly for debugging purposes. See sd_cifar10.py for a
+# non-trivial example of stochastic depth on cifar10.
+################################################################################
+
+import os
+import sys
+import mxnet as mx
+import logging
+
+import sd_module
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "image-classification")))
+from train_mnist import get_iterator
+from symbol_resnet import get_conv
+
+death_rates = [0.3]
+contexts = [mx.context.cpu()]
+
+data = mx.symbol.Variable('data')
+conv = get_conv(
+    name='conv0',
+    data=data,
+    num_filter=16,
+    kernel=(3, 3),
+    stride=(1, 1),
+    pad=(1, 1),
+    with_relu=True,
+    bn_momentum=0.9
+)
+
+base_mod = mx.mod.Module(conv, label_names=None, context=contexts)
+mod_seq = mx.mod.SequentialModule()
+mod_seq.add(base_mod)
+
+for i in range(len(death_rates)):
+    conv = get_conv(
+        name='conv0_%d' % i,
+        data=mx.sym.Variable('data_%d' % i),
+        num_filter=16,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        with_relu=True,
+        bn_momentum=0.9
+    )
+    conv = get_conv(
+        name='conv1_%d' % i,
+        data=conv,
+        num_filter=16,
+        kernel=(3, 3),
+        stride=(1, 1),
+        pad=(1, 1),
+        with_relu=False,
+        bn_momentum=0.9
+    )
+    mod = sd_module.StochasticDepthModule(conv, data_names=['data_%d' % i],
+                                          context=contexts, death_rate=death_rates[i])
+    mod_seq.add(mod, auto_wiring=True)
+
+act = mx.sym.Activation(mx.sym.Variable('data_final'), act_type='relu')
+flat = mx.sym.Flatten(act)
+pred = mx.sym.FullyConnected(flat, num_hidden=10)
+softmax = mx.sym.SoftmaxOutput(pred, name='softmax')
+mod_seq.add(mx.mod.Module(softmax, context=contexts, data_names=['data_final']),
+            auto_wiring=True, take_labels=True)
+
+
+n_epoch = 2
+batch_size = 100
+
+
+train = mx.io.MNISTIter(
+    image="../image-classification/mnist/train-images-idx3-ubyte",
+    label="../image-classification/mnist/train-labels-idx1-ubyte",
+    input_shape=(1, 28, 28), flat=False,
+    batch_size=batch_size, shuffle=True, silent=False, seed=10)
+val = mx.io.MNISTIter(
+    image="../image-classification/mnist/t10k-images-idx3-ubyte",
+    label="../image-classification/mnist/t10k-labels-idx1-ubyte",
+    input_shape=(1, 28, 28), flat=False,
+    batch_size=batch_size, shuffle=True, silent=False)
+
+logging.basicConfig(level=logging.DEBUG)
+mod_seq.fit(train, val, optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
+            num_epoch=n_epoch, batch_end_callback=mx.callback.Speedometer(batch_size, 10))
diff --git a/example/stochastic-depth/sd_module.py b/example/stochastic-depth/sd_module.py
new file mode 100644
index 00000000..ae8cfe0b
--- /dev/null
+++ b/example/stochastic-depth/sd_module.py
@@ -0,0 +1,172 @@
+import logging
+import mxnet as mx
+import numpy as np
+
+
+class RandomNumberQueue(object):
+    def __init__(self, pool_size=1000):
+        self._pool = np.random.rand(pool_size)
+        self._index = 0
+
+    def get_sample(self):
+        if self._index >= len(self._pool):
+            self._pool = np.random.rand(len(self._pool))
+            self._index = 0
+        self._index += 1
+        return self._pool[self._index-1]
+
+
+class StochasticDepthModule(mx.module.BaseModule):
+    """A stochastic depth module is a two-branch computation: one branch is the actual
+    computation and the other is a skip connection (usually an identity map). This is
+    similar to a residual block, except that a random variable is used to randomly turn
+    off the computation branch, in order to save computation during training.
+
+    Parameters
+    ----------
+    symbol_compute: Symbol
+        The computation branch.
+    symbol_skip: Symbol
+        The skip branch. Could be None, in which case an identity map will be
+        automatically used.
+        Note that the two branches should produce exactly the same output shapes.
+    data_names: list of str
+        Default is `['data']`, indicating the input names. Note that if `symbol_skip` is
+        not None, it should have the same input names as `symbol_compute`.
+    label_names: list of str
+        Default is None, indicating that this module does not take labels.
+    death_rate: float
+        Default is 0, the probability of turning off the computation branch.
+    """
+    def __init__(self, symbol_compute, symbol_skip=None,
+                 data_names=('data',), label_names=None,
+                 logger=logging, context=mx.context.cpu(),
+                 work_load_list=None, fixed_param_names=None,
+                 death_rate=0):
+        super(StochasticDepthModule, self).__init__(logger=logger)
+
+        self._module_compute = mx.module.Module(
+            symbol_compute, data_names=data_names,
+            label_names=label_names, logger=logger,
+            context=context, work_load_list=work_load_list,
+            fixed_param_names=fixed_param_names)
+
+        if symbol_skip is not None:
+            self._module_skip = mx.module.Module(
+                symbol_skip, data_names=data_names,
+                label_names=label_names, logger=logger,
+                context=context, work_load_list=work_load_list,
+                fixed_param_names=fixed_param_names)
+        else:
+            self._module_skip = None
+
+        self._open_rate = 1 - death_rate
+        self._gate_open = True
+        self._outputs = None
+        self._input_grads = None
+        self._rnd_queue = RandomNumberQueue()
+
+    @property
+    def data_names(self):
+        return self._module_compute.data_names
+
+    @property
+    def output_names(self):
+        return self._module_compute.output_names
+
+    @property
+    def data_shapes(self):
+        return self._module_compute.data_shapes
+
+    @property
+    def label_shapes(self):
+        return self._module_compute.label_shapes
+
+    @property
+    def output_shapes(self):
+        return self._module_compute.output_shapes
+
+    def get_params(self):
+        params = self._module_compute.get_params()
+        if self._module_skip:
+            params = [x.copy() for x in params]
+            skip_params = self._module_skip.get_params()
+            for a, b in zip(params, skip_params):
+                # make sure they do not contain duplicated param names
+                assert len(set(a.keys()) & set(b.keys())) == 0
+                a.update(b)
+        return params
+
+    def init_params(self, *args, **kwargs):
+        self._module_compute.init_params(*args, **kwargs)
+        if self._module_skip:
+            self._module_skip.init_params(*args, **kwargs)
+
+    def bind(self, *args, **kwargs):
+        self._module_compute.bind(*args, **kwargs)
+        if self._module_skip:
+            self._module_skip.bind(*args, **kwargs)
+
+    def init_optimizer(self, *args, **kwargs):
+        self._module_compute.init_optimizer(*args, **kwargs)
+        if self._module_skip:
+            self._module_skip.init_optimizer(*args, **kwargs)
+
+    def borrow_optimizer(self, shared_module):
+        self._module_compute.borrow_optimizer(shared_module._module_compute)
+        if self._module_skip:
+            self._module_skip.borrow_optimizer(shared_module._module_skip)
+
+    def forward(self, data_batch, is_train=None):
+        if is_train is None:
+            is_train = self._module_compute.for_training
+
+        if self._module_skip:
+            self._module_skip.forward(data_batch, is_train=True)
+            self._outputs = self._module_skip.get_outputs()
+        else:
+            self._outputs = data_batch.data
+
+        if is_train:
+            self._gate_open = self._rnd_queue.get_sample() < self._open_rate
+            if self._gate_open:
+                self._module_compute.forward(data_batch, is_train=True)
+                computed_outputs = self._module_compute.get_outputs()
+                for i in range(len(self._outputs)):
+                    self._outputs[i] += computed_outputs[i]
+
+        else:  # do expectation for prediction
+            self._module_compute.forward(data_batch, is_train=False)
+            computed_outputs = self._module_compute.get_outputs()
+            for i in range(len(self._outputs)):
+                self._outputs[i] += self._open_rate * computed_outputs[i]
+
+    def backward(self, out_grads=None):
+        if self._module_skip:
+            self._module_skip.backward(out_grads=out_grads)
+            self._input_grads = self._module_skip.get_input_grads()
+        else:
+            self._input_grads = out_grads
+
+        if self._gate_open:
+            self._module_compute.backward(out_grads=out_grads)
+            computed_input_grads = self._module_compute.get_input_grads()
+            for i in range(len(self._input_grads)):
+                self._input_grads[i] += computed_input_grads[i]
+
+    def update(self):
+        self._module_compute.update()
+        if self._module_skip:
+            self._module_skip.update()
+
+    def update_metric(self, eval_metric, labels):
+        self._module_compute.update_metric(eval_metric, labels)
+        if self._module_skip:
+            self._module_skip.update_metric(eval_metric, labels)
+
+    def get_outputs(self, merge_multi_context=True):
+        assert merge_multi_context, "Force merging for now"
+        return self._outputs
+
+    def get_input_grads(self, merge_multi_context=True):
+        assert merge_multi_context, "Force merging for now"
+        return self._input_grads
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index bd215018..f98d48af 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -325,7 +325,9 @@ def save_checkpoint(prefix, epoch, symbol, arg_params, aux_params):
     - ``prefix-symbol.json`` will be saved for symbol.
     - ``prefix-epoch.params`` will be saved for parameters.
     """
-    symbol.save('%s-symbol.json' % prefix)
+    if symbol is not None:
+        symbol.save('%s-symbol.json' % prefix)
+
     save_dict = {('arg:%s' % k) : v for k, v in arg_params.items()}
     save_dict.update({('aux:%s' % k) : v for k, v in aux_params.items()})
     param_name = '%s-%04d.params' % (prefix, epoch)
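
Note on the gating rule implemented by StochasticDepthModule.forward: during training the
skip branch is always evaluated, and the computation branch is added to it with probability
1 - death_rate (the "open" rate); at prediction time the computation branch is always added,
but scaled by 1 - death_rate so that the output matches the training-time expectation. The
short NumPy sketch below restates that rule outside of MXNet, purely for illustration; the
compute_branch and skip_branch callables are hypothetical placeholders, not part of this patch.

    import numpy as np

    def sd_forward(x, compute_branch, skip_branch, death_rate, is_train, rng=np.random):
        # skip branch is always evaluated (identity map when no skip symbol is given)
        out = skip_branch(x)
        open_rate = 1.0 - death_rate          # survival probability of the block
        if is_train:
            if rng.rand() < open_rate:        # gate open: add the residual computation
                out = out + compute_branch(x)
            # gate closed: the block degenerates to the skip branch alone
        else:
            # prediction: add the expected contribution of the residual computation
            out = out + open_rate * compute_branch(x)
        return out

    # toy check with an identity skip and a scaled-copy "computation" branch
    x = np.ones((2, 3))
    print(sd_forward(x, lambda v: 0.1 * v, lambda v: v, death_rate=0.5, is_train=False))

With death_mode = 'linear_decay' in sd_cifar10.py, block l out of the L = 3 * n_residual_blocks
residual blocks uses death rate (l / L) * death_rate, so deeper blocks are dropped more often,
similar to the linear-decay schedule suggested in the paper.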