Merge pull request #9 from dmlc/master
Merge Back
hetong007 committed Nov 2, 2015
2 parents 6b920f7 + 5f21337 commit 428ed60
Showing 5 changed files with 91 additions and 68 deletions.
7 changes: 7 additions & 0 deletions example/numpy-ops/README.md
@@ -0,0 +1,7 @@
+# Training MNIST With NumpyOp
+
+Uses the same setup as example/mnist/mlp.py, except that the loss symbol is
+custom-defined with NumpyOp. mxnet.operator.NumpyOp helps move the computation
+in a symbol's forward/backward operations to the Python frontend. This is for
+fast implementation/experimentation of non-performance-critical symbols.
+If it becomes a bottleneck, please consider writing a C++/CUDA version.
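
As an aside, a minimal sketch of what such a NumpyOp-defined loss can look like, following the subclass interface this example relies on (list_arguments/list_outputs/infer_shape/forward/backward). The softmax here is illustrative rather than a copy of the example file, and fc3 in the usage comment is a hypothetical symbol:

import numpy as np
import mxnet as mx

class NumpySoftmax(mx.operator.NumpyOp):
    """A softmax loss written in the Python frontend (illustrative)."""
    def __init__(self):
        # A loss layer supplies its own gradient, so no top gradient is needed.
        super(NumpySoftmax, self).__init__(need_top_grad=False)

    def list_arguments(self):
        return ['data', 'label']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)   # one integer label per example
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape]

    def forward(self, in_data, out_data):
        x, y = in_data[0], out_data[0]
        y[:] = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
        y /= y.sum(axis=1).reshape((x.shape[0], 1))

    def backward(self, out_grad, in_data, out_data, in_grad):
        label, y, dx = in_data[1], out_data[0], in_grad[0]
        dx[:] = y   # gradient of cross-entropy through softmax
        dx[np.arange(label.shape[0]), label.astype('int32')] -= 1.0

# Usage on top of an MLP (fc3 is a hypothetical final fully-connected symbol):
# mlp = NumpySoftmax()(data=fc3, name='softmax')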
138 changes: 74 additions & 64 deletions python/mxnet/io.py
@@ -1,13 +1,12 @@
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments
+# pylint: disable=invalid-name, protected-access, fixme, too-many-arguments, W0221, W0201

 """NDArray interface of mxnet"""
 from __future__ import absolute_import

 import ctypes
 import sys
 import numpy as np
-import math
 import logging
 from .base import _LIB
 from .base import c_array, c_str, mx_uint, py_str
@@ -55,9 +54,14 @@ def iter_next(self):
         """
         pass

-    def getdata(self):
+    def getdata(self, index=0):
         """Get data of current batch.

+        Parameters
+        ----------
+        index : int
+            The index of data source to retrieve.
+
         Returns
         -------
         data : NDArray
@@ -73,7 +77,7 @@ def getlabel(self):
         label : NDArray
             The label of current batch.
         """
-        pass
+        return self.getdata(-1)

     def getpad(self):
         """Get the number of padding examples in current batch.
@@ -91,11 +95,8 @@ class NDArrayIter(DataIter):
     Parameters
     ----------
-    data : NDArray or numpy.ndarray
-        NDArray for data
-    label : NDArray or numpy.ndarray
-        NDArray for label
+    data_list or data, label: a list of, or two separate NDArray or numpy.ndarray
+        list of NDArray for data. The last one is treated as label.

     batch_size: int
         Batch Size
@@ -109,87 +110,96 @@ class NDArrayIter(DataIter):
     label_pad_value: float, optional
         Padding value for label
+    last_batch_handle: 'pad', 'discard' or 'roll_over'
+        How to handle the last batch

     Note
     ----
-    This iterator will pad the last batch if
-    the size of data does not match batch_size.
+    This iterator will pad, discard or roll over the last batch if
+    the size of data does not match batch_size. Roll over is intended
+    for training and can cause problems if used for prediction.
     """
-    def __init__(self, data, label,
-                 batch_size,
-                 shuffle=False,
-                 data_pad_value=0,
-                 label_pad_value=0):
+    def __init__(self, *args, **kwargs):
         super(NDArrayIter, self).__init__()
-        if isinstance(data, NDArray):
-            data = data.asnumpy()
-        if isinstance(label, NDArray):
-            label = label.asnumpy()
+        if isinstance(args[0], list) or isinstance(args[0], tuple):
+            self._init(*args, **kwargs)
+        else:
+            self._init((args[0], args[1]), *args[2:], **kwargs)
+
+    def _init(self, data_list,
+              batch_size=1,
+              shuffle=False,
+              last_batch_handle='pad'):
+        """Actual constructor"""
+        # pylint: disable=W0201
+        self.num_source = len(data_list)
+        assert self.num_source > 0, "Need at least one data source."
+        data_list = list(data_list)
+        for i in range(self.num_source):
+            if isinstance(data_list[i], NDArray):
+                data_list[i] = data_list[i].asnumpy()

         # shuffle data
         if shuffle:
-            idx = np.arange(data.shape[0])
+            idx = np.arange(data_list[0].shape[0])
             np.random.shuffle(idx)
-            new_data = np.zeros(data.shape)
-            new_label = np.zeros(label.shape)
-            for i in range(data.shape[0]):
-                new_data[i] = data[idx[i]]
-                new_label[i] = label[idx[i]]
-            data = new_data
-            label = new_label
+            for i in range(self.num_source):
+                assert data_list[i].shape[0] == len(idx)
+                data_list[i] = data_list[i][idx]

         # batching
-        self.batch_num = int(math.ceil(float(data.shape[0]) / batch_size))
-        batch_data_shape = []
-        batch_data_shape.append(self.batch_num)
-        batch_data_shape.append(batch_size)
-        for i in range(1, len(data.shape)):
-            batch_data_shape.append(data.shape[i])
-        batch_label_shape = []
-        batch_label_shape.append(self.batch_num)
-        batch_label_shape.append(batch_size)
-        for i in range(1, len(label.shape)):
-            batch_label_shape.append(label.shape[i])
-        self.batch_data = np.ones(batch_data_shape, dtype='float32') * data_pad_value
-        self.batch_label = np.ones(batch_label_shape, dtype='float32') * label_pad_value
-        loc = 0
-        for i in range(self.batch_num):
-            actual_size = min(data.shape[0] - loc, batch_size)
-            self.batch_data[i, 0:actual_size, ::] = data[loc:loc+actual_size, ::]
-            self.batch_label[i, 0:actual_size] = label[loc:loc+actual_size]
-            loc += batch_size
-        self.num_pad = batch_size - data.shape[0] % batch_size
-        if data.shape[0] % batch_size == 0:
-            self.num_pad = 0
-        self.out_data = None
-        self.out_label = None
-        self.current_batch = -1
+        if last_batch_handle == 'discard':
+            new_n = data_list[0].shape[0] - data_list[0].shape[0] % batch_size
+            for i in range(self.num_source):
+                data_list[i] = data_list[i][:new_n]
+        self.num_data = data_list[0].shape[0]
+        assert self.num_data > batch_size, \
+            "batch_size needs to be smaller than data size when not padding."
+        self.cursor = -batch_size
+        self.data_list = data_list
+        self.batch_size = batch_size
+        self.last_batch_handle = last_batch_handle
+
+    def hard_reset(self):
+        """Ignore roll over data and set to start"""
+        self.cursor = -self.batch_size

     def reset(self):
-        self.current_batch = -1
+        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
+            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+        else:
+            self.cursor = -self.batch_size

     def iter_next(self):
-        if self.current_batch < self.batch_num - 1:
-            self.current_batch += 1
+        self.cursor += self.batch_size
+        if self.cursor < self.num_data:
             return True
         else:
             return False

     def next(self):
         if self.iter_next():
-            return self.getdata(), self.getlabel()
+            return (self.getdata(i) for i in range(self.num_source))
         else:
             raise StopIteration

-    def getdata(self):
-        assert(self.current_batch >= 0)
-        return array(self.batch_data[self.current_batch])
+    def getdata(self, index=0):
+        assert(index < self.num_source)
+        assert(self.cursor < self.num_data), "DataIter needs reset."
+        if self.cursor + self.batch_size <= self.num_data:
+            return array(self.data_list[index][self.cursor:self.cursor+self.batch_size])
+        else:
+            pad = self.batch_size - self.num_data + self.cursor
+            return array(np.concatenate((self.data_list[index][self.cursor:],
+                                         self.data_list[index][:pad]),
+                                        axis=0))

     def getlabel(self):
-        assert(self.current_batch >= 0)
-        return array(self.batch_label[self.current_batch])
+        return self.getdata(-1)

     def getpad(self):
-        if self.current_batch == self.batch_num - 1:
-            return self.num_pad
+        if self.last_batch_handle == 'pad' and \
+           self.cursor + self.batch_size > self.num_data:
+            return self.cursor + self.batch_size - self.num_data
         else:
             return 0

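Stepping back from the hunks: the rewrite drops the pre-copied batch_data/batch_label tensors in favor of a single cursor into the raw arrays, slicing (or wrapping) each batch on demand. A sketch of the resulting API under the three last_batch_handle modes, assuming only the code shown in this diff (array sizes are illustrative):

import numpy as np
import mxnet as mx

data = np.arange(20, dtype='float32').reshape((10, 2))
label = np.arange(10, dtype='float32')

# 'pad': 10 examples with batch_size 4 give 3 batches; the last batch wraps
# around to the start and getpad() reports how many of its rows are padding.
it = mx.io.NDArrayIter([data, label], batch_size=4, last_batch_handle='pad')
while it.iter_next():
    x = it.getdata(0)              # first data source
    y = it.getlabel()              # alias for it.getdata(-1)
    print(x.shape, it.getpad())    # getpad() is 2 only on the final batch

# 'discard': the 10 % 4 == 2 leftover examples are dropped up front, so each
# epoch yields exactly 2 full batches and getpad() is always 0.

# 'roll_over': the final batch also wraps, but getpad() stays 0; reset() then
# rewinds the cursor to a negative offset so the next epoch resumes after the
# wrapped examples. Fine for training, misleading for prediction.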
7 changes: 6 additions & 1 deletion python/mxnet/model.py
@@ -612,7 +612,12 @@ def _init_iter(self, X, y, is_train):
                 y = y.flatten()
             if y.ndim != 1:
                 raise ValueError("Label must be 1D or 2D (with 2nd dimension being 1)")
-            return io.NDArrayIter(X, y, self.numpy_batch_size, shuffle=is_train)
+            if is_train:
+                return io.NDArrayIter(X, y, self.numpy_batch_size,
+                                      shuffle=is_train, last_batch_handle='roll_over')
+            else:
+                return io.NDArrayIter(X, y, self.numpy_batch_size,
+                                      shuffle=is_train)
         if not isinstance(X, io.DataIter):
             raise TypeError('X must be DataIter, NDArray or numpy.ndarray')
         return X
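The branch on is_train follows from the Note added to NDArrayIter above: under roll_over, getpad() never reports the wrapped rows, so a prediction pass could not strip them. A worked comparison (the numbers mirror the unit test below):

# 1000 examples, batch size 128:
#
# training with last_batch_handle='roll_over':
#     the short final batch wraps into the next epoch; a few duplicated
#     examples only perturb gradient noise, so this is harmless.
#
# prediction with the default last_batch_handle='pad':
#     the final batch reports getpad() == 8 * 128 - 1000 == 24, letting the
#     caller drop the 24 wrap-around predictions and keep exactly 1000 outputs.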
2 changes: 1 addition & 1 deletion src/operator/elementwise_sum-inl.h
@@ -29,7 +29,7 @@ enum ElementWiseSumOpOutputs {kOut};
 struct ElementWiseSumParam : public dmlc::Parameter<ElementWiseSumParam> {
   int num_args;
   DMLC_DECLARE_PARAMETER(ElementWiseSumParam) {
-    DMLC_DECLARE_FIELD(num_args).set_range(1, 100)
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
     .describe("Number of inputs to be summed.");
   }
 };
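This one-line change lifts the old hard cap of 100 summands; only num_args >= 1 is still enforced. From the Python frontend this is what the relaxation permits (a sketch, assuming the era's ElementWiseSum symbol API with positional inputs):

import mxnet as mx

# set_range(1, 100) used to reject more than 100 inputs;
# set_lower_bound(1) accepts any positive count.
xs = [mx.symbol.Variable('x%d' % i) for i in range(150)]
total = mx.symbol.ElementWiseSum(*xs, name='total')   # 150 inputs now pass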
5 changes: 3 additions & 2 deletions tests/python/unittest/test_io.py
@@ -68,16 +68,17 @@ def test_NDArrayIter():
     for i in range(1000):
         datas[i] = i / 100
         labels[i] = i / 100
-    dataiter = mx.io.NDArrayIter(datas, labels, 128, True)
+    dataiter = mx.io.NDArrayIter(datas, labels, 128, True, last_batch_handle='pad')
     batchidx = 0
     for data, label in dataiter:
         batchidx += 1
     assert(batchidx == 8)
-    dataiter.reset()
+    dataiter = mx.io.NDArrayIter(datas, labels, 128, False, last_batch_handle='pad')
     batchidx = 0
     labelcount = [0 for i in range(10)]
     for data, label in dataiter:
         label = label.asnumpy().flatten()
+        assert((data.asnumpy()[:,0,0] == label).all())
         for i in range(label.shape[0]):
             labelcount[int(label[i])] += 1
     for i in range(10):
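The asserted batch count is plain arithmetic under 'pad' handling; a quick check using only the numbers in the test:

import math

num_data, batch_size = 1000, 128
num_batches = int(math.ceil(float(num_data) / batch_size))   # 8, as asserted
last_pad = num_batches * batch_size - num_data               # 24 padded rows
assert num_batches == 8
assert last_pad == 24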
