diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index 65c261bb..0f3c079f 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -96,18 +96,6 @@ def _merge_multi_context(outputs, major_axis):
     return rets
 
 
-def _slice_axis(arr, axis, islice):
-    """Slices array along axis"""
-    if axis == 0:
-        # slicing NDArray along axis 0 can avoid copying
-        return arr[islice]
-    elif axis > 0:
-        # pylint: disable=no-member
-        return nd.slice(arr, axis=axis, begin=islice.start, end=islice.stop)
-        # pylint: enable=no-member
-    return arr
-
-
 class DataParallelExecutorGroup(object):
     """A group of executors that lives on a group of devices.
     This is a helper class used to implement data parallelization. Each mini-batch will
@@ -218,7 +206,6 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_names,
 
         # initialize some instance variables
         self.batch_size = None
-        self.cur_batch_pad = None
         self.slices = None
         self.execs = []
         self._default_execs = None
@@ -413,8 +400,6 @@ def forward(self, data_batch, is_train=None):
 
         """
         _load_data(data_batch, self.data_arrays, self.data_layouts)
-        self.cur_batch_pad = getattr(data_batch, 'pad', None)
-
         if is_train is None:
             is_train = self.for_training
 
@@ -572,23 +557,28 @@ def update_metric(self, eval_metric, labels):
             The metric used for evaluation.
         labels : list of NDArray
             Typically comes from `label` of a `DataBatch`.
+        begin : int
+            Starting index of used outputs.
+        end : int or None
+            Ending index of used outputs.
         """
-        pad = self.cur_batch_pad or 0
-        valid_stop = self.batch_size - pad
         for texec, islice in zip(self.execs, self.slices):
             labels_slice = []
-            outputs_slice = []
-            if islice.start >= valid_stop:
-                break
-            if islice.stop > valid_stop:
-                islice = slice(islice.start, valid_stop)
-            oslice = slice(0, islice.stop - islice.start)
-            for label, laxis in zip(labels, self.label_layouts):
-                labels_slice.append(_slice_axis(label, laxis, islice))
-            for output, oaxis in zip(texec.outputs, self.output_layouts):
-                outputs_slice.append(_slice_axis(output, oaxis, oslice))
+            for label, axis in zip(labels, self.label_layouts):
+                if axis == 0:
+                    # slicing NDArray along axis 0 can avoid copying
+                    labels_slice.append(label[islice])
+                elif axis > 0:
+                    # pylint: disable=no-member
+                    label_my_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
+                                                   end=islice.stop).as_in_context(label.context)
+                    # pylint: enable=no-member
+                    labels_slice.append(label_my_slice)
+                else:
+                    labels_slice.append(label)
+
             labels_ = OrderedDict(zip(self.label_names, labels_slice))
-            preds = OrderedDict(zip(self.output_names, outputs_slice))
+            preds = OrderedDict(zip(self.output_names, texec.outputs))
             eval_metric.update_dict(labels_, preds)
 
     def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 7cbbbabc..6813c48a 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -676,40 +676,6 @@ def test_forward_reshape():
     assert mod.get_outputs()[0].shape == (3, 5)
 
 
-def test_forward_update_metric_pad():
-    d = mx.sym.Variable('data')
-    fc = mx.sym.FullyConnected(data=d, num_hidden=2)
-    sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
-
-    dshape = (10, 2)
-    data = [mx.nd.zeros(dshape)]
-    lshape = (10,)
-    label = [mx.nd.array(([0] * 8) + ([1] * 2))]
-
-    eval_metric = mx.metric.Accuracy()
-    mod = mx.mod.Module(symbol=sym, data_names=['data'],
-                        label_names=['softmax_label'])
-    mod.bind(data_shapes=[('data', dshape)],
-             label_shapes=[('softmax_label', lshape)])
-    mod.init_params(initializer=mx.initializer.Zero())
-    mod.init_optimizer(optimizer_params={'learning_rate': 0.01})
-
-    # Test that accuracy is 0.8 without padding
-    pad = None
-    data_batch = mx.io.DataBatch(data=data, label=label, pad=pad)
-    mod.forward(data_batch)
-    mod.update_metric(eval_metric, data_batch.label)
-    assert eval_metric.get()[1] == 0.8
-
-    # Test that accuracy is 1.0 with padding
-    pad = 2
-    eval_metric.reset()
-    data_batch = mx.io.DataBatch(data=data, label=label, pad=pad)
-    mod.forward(data_batch)
-    mod.update_metric(eval_metric, data_batch.label)
-    assert eval_metric.get()[1] == 1.0
-
-
 if __name__ == '__main__':
     import nose
     nose.runmodule()
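
Note (not part of the patch itself): with this change, update_metric no longer trims padded samples. It slices each label along its batch axis for the per-device slice and passes the executor outputs through to the metric unsliced, which is why the pad-specific test above is removed. Below is a minimal standalone sketch of that label-slicing step, assuming the MXNet 1.x NDArray API; the names label, islice, and axis mirror the diff, and the values are made up for illustration.

    # Sketch only, not part of the patch: slice one label NDArray for a single
    # device, the way the new update_metric loop does. Assumes MXNet 1.x.
    from mxnet import nd

    label = nd.array([0, 0, 0, 1, 1, 1])    # hypothetical label batch
    islice = slice(2, 4)                     # rows assigned to one device
    axis = 0                                 # batch axis of this label

    if axis == 0:
        # basic indexing along axis 0 avoids copying the data
        label_slice = label[islice]
    elif axis > 0:
        # nd.slice_axis produces a copy; as_in_context keeps the result on the
        # label's device, matching the patch
        label_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
                                    end=islice.stop).as_in_context(label.context)
    else:
        label_slice = label

    print(label_slice.asnumpy())             # [0. 1.]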