Skip to content

Commit

Permalink
add min max sum
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong committed Dec 9, 2015
1 parent f8123cd commit 66c6a54
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mshadow
31 changes: 30 additions & 1 deletion src/ndarray/unary_function-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ void L2Norm(const TBlob &src,
mshadow::VectorDot(out, in, in);
out = mshadow::expr::F<mxnet::op::mshadow_op::square_root>(out);
}

template<typename xpu, typename Reducer>
void Reduce(const TBlob &src,
TBlob *ret,
OpReqType req,
RunContext ctx) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
mshadow::Tensor<xpu, 1> out = ret->get<xpu, 1, real_t>(s);
mshadow::Tensor<xpu, 2> in =
src.get_with_shape<xpu, 2, real_t>(mshadow::Shape2(1, src.shape_.Size()), s);
out = mshadow::expr::reduce_except_dim<0, Reducer>(in);
}
// Register all unary operations here
// The true means inplace can be enabled.
// abs
Expand Down Expand Up @@ -148,7 +160,24 @@ MXNET_REGISTER_TBLOB_FUN(norm, XPU)
.set_shape_infer(ScalarShape)
.describe("Take L2 norm of the src."
"The result will be ndarray of shape (1,) on the same device.");

// Max
MXNET_REGISTER_TBLOB_FUN(max, XPU)
.set_function(XPU::kDevMask, Reduce<XPU, mshadow::red::maximum>, false, false)
.set_shape_infer(ScalarShape)
.describe("Take max of the src."
"The result will be ndarray of shape (1,) on the same device.");
// Min
MXNET_REGISTER_TBLOB_FUN(min, XPU)
.set_function(XPU::kDevMask, Reduce<XPU, mshadow::red::minimum>, false, false)
.set_shape_infer(ScalarShape)
.describe("Take min of the src."
"The result will be ndarray of shape (1,) on the same device.");
// Sum
MXNET_REGISTER_TBLOB_FUN(sum, XPU)
.set_function(XPU::kDevMask, Reduce<XPU, mshadow::red::sum>, false, false)
.set_shape_infer(ScalarShape)
.describe("Take sum of the src."
"The result will be ndarray of shape (1,) on the same device.");
} // namespace ndarray
} // namespace mxnet
#endif // MXNET_NDARRAY_UNARY_FUNCTION_INL_H_

0 comments on commit 66c6a54

Please sign in to comment.