Skip to content

Commit

Permalink
Revert "account for batch padding when updating metrics (#7949)" (#8032)
Browse files Browse the repository at this point in the history

* Revert "Update executor_group.py (#8003)"

This reverts commit 4aaefa0.

* Revert "account for batch padding when updating metrics (#7949)"

This reverts commit 5db5da9.
  • Loading branch information
piiswrong authored Sep 25, 2017
1 parent 32c588f commit 2e0ffdb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 62 deletions.
46 changes: 18 additions & 28 deletions python/mxnet/module/executor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,18 +96,6 @@ def _merge_multi_context(outputs, major_axis):
return rets


def _slice_axis(arr, axis, islice):
"""Slices array along axis"""
if axis == 0:
# slicing NDArray along axis 0 can avoid copying
return arr[islice]
elif axis > 0:
# pylint: disable=no-member
return nd.slice(arr, axis=axis, begin=islice.start, end=islice.stop)
# pylint: enable=no-member
return arr


class DataParallelExecutorGroup(object):
"""A group of executors that lives on a group of devices.
This is a helper class used to implement data parallelization. Each mini-batch will
Expand Down Expand Up @@ -218,7 +206,6 @@ def __init__(self, symbol, contexts, workload, data_shapes, label_shapes, param_

# initialize some instance variables
self.batch_size = None
self.cur_batch_pad = None
self.slices = None
self.execs = []
self._default_execs = None
Expand Down Expand Up @@ -413,8 +400,6 @@ def forward(self, data_batch, is_train=None):
"""
_load_data(data_batch, self.data_arrays, self.data_layouts)
self.cur_batch_pad = getattr(data_batch, 'pad', None)

if is_train is None:
is_train = self.for_training

Expand Down Expand Up @@ -572,23 +557,28 @@ def update_metric(self, eval_metric, labels):
The metric used for evaluation.
labels : list of NDArray
Typically comes from `label` of a `DataBatch`.
begin : int
Starting index of used outputs.
end : int or None
Ending index of used outputs.
"""
pad = self.cur_batch_pad or 0
valid_stop = self.batch_size - pad
for texec, islice in zip(self.execs, self.slices):
labels_slice = []
outputs_slice = []
if islice.start >= valid_stop:
break
if islice.stop > valid_stop:
islice = slice(islice.start, valid_stop)
oslice = slice(0, islice.stop - islice.start)
for label, laxis in zip(labels, self.label_layouts):
labels_slice.append(_slice_axis(label, laxis, islice))
for output, oaxis in zip(texec.outputs, self.output_layouts):
outputs_slice.append(_slice_axis(output, oaxis, oslice))
for label, axis in zip(labels, self.label_layouts):
if axis == 0:
# slicing NDArray along axis 0 can avoid copying
labels_slice.append(label[islice])
elif axis > 0:
# pylint: disable=no-member
label_my_slice = nd.slice_axis(label, axis=axis, begin=islice.start,
end=islice.stop).as_in_context(label.context)
# pylint: enable=no-member
labels_slice.append(label_my_slice)
else:
labels_slice.append(label)

labels_ = OrderedDict(zip(self.label_names, labels_slice))
preds = OrderedDict(zip(self.output_names, outputs_slice))
preds = OrderedDict(zip(self.output_names, texec.outputs))
eval_metric.update_dict(labels_, preds)

def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
Expand Down
34 changes: 0 additions & 34 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,40 +676,6 @@ def test_forward_reshape():
assert mod.get_outputs()[0].shape == (3, 5)


def test_forward_update_metric_pad():
    """Accuracy must ignore padded samples when ``DataBatch.pad`` is set."""
    data_sym = mx.sym.Variable('data')
    fc_out = mx.sym.FullyConnected(data=data_sym, num_hidden=2)
    net = mx.sym.SoftmaxOutput(data=fc_out, name='softmax')

    batch_shape = (10, 2)
    batch_data = [mx.nd.zeros(batch_shape)]
    # 8 samples labelled 0 followed by 2 labelled 1.
    batch_label = [mx.nd.array(([0] * 8) + ([1] * 2))]

    metric = mx.metric.Accuracy()
    module = mx.mod.Module(symbol=net, data_names=['data'],
                           label_names=['softmax_label'])
    module.bind(data_shapes=[('data', batch_shape)],
                label_shapes=[('softmax_label', (10,))])
    module.init_params(initializer=mx.initializer.Zero())
    module.init_optimizer(optimizer_params={'learning_rate': 0.01})

    # Zero-initialized weights predict class 0 for every sample, so over
    # 8 zeros and 2 ones the unpadded accuracy is 0.8.
    batch = mx.io.DataBatch(data=batch_data, label=batch_label, pad=None)
    module.forward(batch)
    module.update_metric(metric, batch.label)
    assert metric.get()[1] == 0.8

    # With the last 2 samples marked as padding, only the 8 zero-labelled
    # samples should count, giving an accuracy of 1.0.
    metric.reset()
    batch = mx.io.DataBatch(data=batch_data, label=batch_label, pad=2)
    module.forward(batch)
    module.update_metric(metric, batch.label)
    assert metric.get()[1] == 1.0


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 2e0ffdb

Please sign in to comment.