Commit

Fixed bug in ReduceAxesBackwardUseInOut that caused finding the gradient of some reduce ops to fail (apache#4171).
Added test cases to exercise this code.

alex-weaver authored and piiswrong committed Dec 29, 2016
1 parent d5ec111 commit 209a4ac
Showing 2 changed files with 23 additions and 9 deletions.
14 changes: 8 additions & 6 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -315,21 +315,23 @@ void ReduceAxesBackwardUseInOut(const nnvm::NodeAttrs& attrs,
       Tensor<xpu, 2, DType> ograd =
         inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
       Tensor<xpu, 2, DType> data =
-        inputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+        inputs[1].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
       Tensor<xpu, 2, DType> out =
-        inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
-      ASSIGN_DISPATCH(igrad, req[0], ograd*F<OP>(data, broadcast_to(out, src_shape)));
+        inputs[2].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+      ASSIGN_DISPATCH(igrad, req[0],
+        broadcast_to(ograd, src_shape)*F<OP>(data, broadcast_to(out, src_shape)));
     } else {
       const int ndim = MXNET_SPECIAL_MAX_NDIM;
       Tensor<xpu, ndim, DType> igrad =
         outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
       Tensor<xpu, ndim, DType> ograd =
         inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
       Tensor<xpu, ndim, DType> data =
-        inputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+        inputs[1].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
       Tensor<xpu, ndim, DType> out =
-        inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
-      ASSIGN_DISPATCH(igrad, req[0], ograd*F<OP>(data, broadcast_to(out, src_shape)));
+        inputs[2].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+      ASSIGN_DISPATCH(igrad, req[0],
+        broadcast_to(ograd, src_shape)*F<OP>(data, broadcast_to(out, src_shape)));
     }
   });
 }
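The fix changes two things in the backward kernel: the forward input and forward output are now read from inputs[1] and inputs[2] instead of reusing inputs[0] for every tensor, and the output gradient is broadcast back to the input shape before it is multiplied by F<OP>(data, broadcast_to(out, src_shape)). As a rough illustration of the corrected rule for a max reduction, here is a minimal NumPy sketch; the helper name and the op_grad callback are hypothetical and not part of the MXNet API:

    import numpy as np

    def reduce_backward_use_in_out(ograd, data, out, op_grad):
        """Sketch of the corrected rule: broadcast both the output gradient and the
        forward output to the input shape, then mask with op_grad(data, out)."""
        out_b = np.broadcast_to(out, data.shape)
        ograd_b = np.broadcast_to(ograd, data.shape)
        return ograd_b * op_grad(data, out_b)

    data = np.array([[1., 3., 2.],
                     [4., 0., 4.]])
    out = data.max(axis=1, keepdims=True)      # forward result of the reduction
    ograd = np.ones_like(out)                  # incoming output gradient
    # For max/min, OP compares data with the broadcast output, so gradient flows
    # only to the positions that attained the extremum (ties all receive it).
    igrad = reduce_backward_use_in_out(ograd, data, out,
                                       op_grad=lambda d, o: (d == o).astype(d.dtype))
    print(igrad)   # [[0. 1. 0.]
                   #  [1. 0. 1.]]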
18 changes: 15 additions & 3 deletions tests/python/unittest/test_operator.py
@@ -1034,8 +1034,12 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym):
             sum_groundtruth = np.array([sum_groundtruth])
         grad_nd = mx.nd.empty(shape)
         outgrad_npy = np.array(np.random.rand(*sum_groundtruth.shape))
+
+        keepdim_shape = np_reduce(dat_npy, axes, 1, np.sum).shape
         grad_groundtruth = numpy_reduce_grad_func(outgrad=outgrad_npy, data=dat_npy,
-                                                  axis=axes, keepdims=keepdims)
+                                                  outdata=sum_groundtruth,
+                                                  axis=axes, keepdims=keepdims,
+                                                  keepdim_shape=keepdim_shape)
         net = b.bind(default_context(), args={'a': mx.nd.array(dat_npy)},
                      args_grad={'a': grad_nd})
         net.forward(is_train=True)
@@ -1046,9 +1050,17 @@ def test_reduce_inner(numpy_reduce_func, numpy_reduce_grad_func, mx_reduce_sym):
         err_backward = reldiff(grad_nd.asnumpy(), grad_groundtruth)
         assert err_backward < 1E-4
     test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.sum),
-                      lambda outgrad, data, axis, keepdims:
-                        outgrad.reshape(np_reduce(data, axis, 1, np.sum).shape),
+                      lambda outgrad, data, outdata, axis, keepdims, keepdim_shape:
+                        outgrad.reshape(keepdim_shape),
                       mx.symbol.sum)
+    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.max),
+                      lambda outgrad, data, outdata, axis, keepdims, keepdim_shape:
+                        outgrad.reshape(keepdim_shape) * (np.equal(data, outdata.reshape(keepdim_shape)).astype(np.float)),
+                      mx.symbol.max)
+    test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.min),
+                      lambda outgrad, data, outdata, axis, keepdims, keepdim_shape:
+                        outgrad.reshape(keepdim_shape) * (np.equal(data, outdata.reshape(keepdim_shape)).astype(np.float)),
+                      mx.symbol.min)
 
 def test_broadcast():
     sample_num = 200
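The new ground-truth lambdas reshape outgrad and outdata to keepdim_shape before the comparison, because with keepdims=0 the reduced axes are squeezed out and neither array would otherwise broadcast against data. A small standalone sketch of that check in plain NumPy (illustrative only, not part of the test file):

    import numpy as np

    data = np.random.rand(2, 3, 4)
    out = data.max(axis=1)                                  # keepdims=0 -> shape (2, 4)
    outgrad = np.random.rand(*out.shape)

    keepdim_shape = data.max(axis=1, keepdims=True).shape   # (2, 1, 4)

    # Restoring the reduced axis as size 1 lets both the output gradient and the
    # forward output broadcast against data, exactly as the new lambdas do.
    grad = outgrad.reshape(keepdim_shape) * \
           np.equal(data, out.reshape(keepdim_shape)).astype(data.dtype)

    assert grad.shape == data.shape
    assert np.allclose(grad.sum(axis=1), outgrad)  # holds when each maximum is unique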
