Merge remote-tracking branch 'upstream/master'
Showing 7 changed files with 568 additions and 8 deletions.
@@ -0,0 +1,8 @@
# Multi-task learning example

This is a simple example to show how to use MXNet for multi-task learning. It uses MNIST as an example and mocks up a multi-label task.

## Usage
First, you need to write a multi-task iterator on your own. The iterator needs to generate multiple labels according to your application, and the label names should be specified in the `provide_label` function; they need to be consistent with the names of the output layers. A minimal sketch of such an iterator is shown below.
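For illustration, here is a minimal sketch of such an iterator (the class name `MultiTaskIter` is illustrative, not part of the example; the full version, `Multi_mnist_iterator`, appears in the training script later in this commit). It wraps a single-label MNIST iterator and simply reuses that label for two tasks, with label names matching output layers called `softmax1` and `softmax2`:

```python
import mxnet as mx

class MultiTaskIter(mx.io.DataIter):
    """Wrap a single-label iterator and expose one label per task."""

    def __init__(self, data_iter):
        super(MultiTaskIter, self).__init__()
        self.data_iter = data_iter
        self.batch_size = data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        # One (name, shape) entry per task; the names must match the
        # output layers, e.g. 'softmax1' -> 'softmax1_label'.
        shape = self.data_iter.provide_label[0][1]
        return [('softmax1_label', shape), ('softmax2_label', shape)]

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        # Mock-up: reuse the single MNIST label for both tasks; a real
        # application would generate a distinct label per task here.
        label = batch.label[0]
        return mx.io.DataBatch(data=batch.data, label=[label, label],
                               pad=batch.pad, index=batch.index)
```

Wrapping the standard MNIST train/validation iterators with a class like this is all the training script below does before calling `model.fit`.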
Then, if you want to report metrics for the different tasks separately, you need to write your own metric class and specify its `num` parameter. In the metric's `update` function, compute the metric separately for each task, as in the sketch below.
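A matching sketch of a per-task accuracy metric (again illustrative; the commit's training script defines the equivalent `Multi_Accuracy` class). It relies on the `num` argument of `mx.metric.EvalMetric` in the MXNet version this example targets, which turns `sum_metric` and `num_inst` into per-task lists:

```python
import mxnet as mx

class MultiTaskAccuracy(mx.metric.EvalMetric):
    """Accuracy per task; pass num=<number of tasks> to report them separately."""

    def __init__(self, num=None):
        super(MultiTaskAccuracy, self).__init__('multi-accuracy', num)

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')
            mx.metric.check_label_shapes(label, pred_label)
            if self.num is None:
                # single accumulator shared by all tasks
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                # one accumulator per task, reported separately
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)
```

With `num=2`, each evaluation reports one accuracy per softmax output; the training script in this commit passes `Multi_Accuracy(num=2)` as `eval_metric` to `model.fit` in exactly this way.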
@@ -0,0 +1,32 @@
# pylint: skip-file
""" data iterator for mnist """
import sys
import os
# code to automatically download dataset
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.append(os.path.join(curr_path, "../../tests/python/common"))
import get_data
import mxnet as mx

def mnist_iterator(batch_size, input_shape):
    """return train and val iterators for mnist"""
    # download data
    get_data.GetMNIST_ubyte()
    flat = False if len(input_shape) == 3 else True

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=flat)

    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        input_shape=input_shape,
        batch_size=batch_size,
        flat=flat)

    return (train_dataiter, val_dataiter)
@@ -0,0 +1,109 @@
# pylint: skip-file
import sys
sys.path.insert(0, "../../python/")
from data import mnist_iterator
import mxnet as mx
import numpy as np
import logging
import time

logging.basicConfig(level=logging.DEBUG)

def build_network():
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, name='fc2', num_hidden=64)
    act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, name='fc3', num_hidden=10)
    # two softmax outputs share the same final fully-connected layer
    sm1 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax1')
    sm2 = mx.symbol.SoftmaxOutput(data=fc3, name='softmax2')

    softmax = mx.symbol.Group([sm1, sm2])

    return softmax

class Multi_mnist_iterator(mx.io.DataIter):
    '''multi label mnist iterator'''

    def __init__(self, data_iter):
        super(Multi_mnist_iterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        # Different labels should be used here for an actual application
        return [('softmax1_label', provide_label[1]),
                ('softmax2_label', provide_label[1])]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        label = batch.label[0]

        return mx.io.DataBatch(data=batch.data, label=[label, label],
                               pad=batch.pad, index=batch.index)

class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        super(Multi_Accuracy, self).__init__('multi-accuracy', num)

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        if self.num is not None:
            assert len(labels) == self.num

        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            if self.num is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)


batch_size = 100
num_epochs = 100
device = mx.gpu(0)
lr = 0.01

network = build_network()
train, val = mnist_iterator(batch_size=batch_size, input_shape=(784,))
train = Multi_mnist_iterator(train)
val = Multi_mnist_iterator(val)

model = mx.model.FeedForward(
    ctx=device,
    symbol=network,
    num_epoch=num_epochs,
    learning_rate=lr,
    momentum=0.9,
    wd=0.00001,
    initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))

model.fit(
    X=train,
    eval_data=val,
    eval_metric=Multi_Accuracy(num=2),
    batch_end_callback=mx.callback.Speedometer(batch_size, 50))
Large diffs are not rendered by default.
Submodule mshadow updated 5 files:
  +38 −0  mshadow/cuda/tensor_gpu-inl.cuh
  +7 −1   mshadow/half.h
  +21 −0  mshadow/tensor.h
  +9 −0   mshadow/tensor_cpu-inl.h
  +6 −0   mshadow/tensor_gpu-inl.h