Add Mathematical functions (#3317)
yajiedesign authored and piiswrong committed Dec 29, 2016
1 parent 758d529 commit 550fff1
Showing 4 changed files with 174 additions and 4 deletions.
64 changes: 60 additions & 4 deletions src/operator/mshadow_op.h
@@ -154,6 +154,20 @@ struct log_grad {
  }
};

struct sin {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(sinf(a));
  }
};

// d/dx sin(x) = cos(x); `a` is the forward input.
struct sin_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(cosf(a));
  }
};

struct cos {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
@@ -168,17 +182,59 @@ struct cos_grad {
  }
};

(Old/new pairs from the diff, with markers restored: the `sin`/`sin_grad` definitions formerly at this spot, now moved above `cos`, give way to `tan`/`tan_grad`. Note that `tan_grad` receives the forward output y = tan(x), so y^2 + 1 equals sec^2(x).)

-struct sin {
+struct tan {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
-    return DType(sinf(a));
+    return DType(tanf(a));
   }
 };

-struct sin_grad {
+struct tan_grad {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
-    return DType(cosf(a));
+    return DType(powf(a, 2) + 1);
   }
 };

struct arcsin {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(asinf(a));
  }
};

// d/dx arcsin(x) = 1 / sqrt(1 - x^2); `a` is the forward input (defined for |a| < 1).
struct arcsin_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(1.0 / (sqrtf(1 - a*a)));
  }
};

struct arccos {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(acosf(a));
  }
};

// d/dx arccos(x) = -1 / sqrt(1 - x^2); `a` is the forward input (defined for |a| < 1).
struct arccos_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(-1.0 / (sqrtf(1 - a*a)));
  }
};

struct arctan {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(atanf(a));
  }
};

// d/dx arctan(x) = 1 / (1 + x^2); `a` is the forward input.
struct arctan_grad {
  template<typename DType>
  MSHADOW_XINLINE static DType Map(DType a) {
    return DType(1 / (a*a + 1));
  }
};

struct square {
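As a sanity check on the derivative identities above, here is a minimal NumPy sketch (not part of the commit; `check_grad` is a hypothetical helper) comparing each analytic gradient against a central finite difference:

import numpy as np

def check_grad(f, g, x, eps=1e-5):
    # Central finite difference of f at x vs. the analytic gradient g(x).
    numeric = (f(x + eps) - f(x - eps)) / (2 * eps)
    assert abs(numeric - g(x)) < 1e-4, (f.__name__, numeric, g(x))

check_grad(np.tan,    lambda x: np.tan(x) ** 2 + 1,      0.3)
check_grad(np.arcsin, lambda x: 1 / np.sqrt(1 - x * x),  0.3)
check_grad(np.arccos, lambda x: -1 / np.sqrt(1 - x * x), 0.3)
check_grad(np.arctan, lambda x: 1 / (1 + x * x),         0.3)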
36 changes: 36 additions & 0 deletions src/operator/tensor/elemwise_unary_op.cc
@@ -114,5 +114,41 @@ MXNET_OPERATOR_REGISTER_UNARY(sin)
MXNET_OPERATOR_REGISTER_BINARY(_backward_sin)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::sin_grad> >);

// tan
MXNET_OPERATOR_REGISTER_UNARY(tan)
.MXNET_DESCRIBE("Take tan of the src")
.set_attr<FCompute>("FCompute<cpu>", UnaryCompute<cpu, mshadow_op::tan>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tan" });

MXNET_OPERATOR_REGISTER_BINARY(_backward_tan)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::tan_grad> >);

// arcsin
MXNET_OPERATOR_REGISTER_UNARY(arcsin)
.MXNET_DESCRIBE("Take arcsin of the src")
.set_attr<FCompute>("FCompute<cpu>", UnaryCompute<cpu, mshadow_op::arcsin>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arcsin" });

MXNET_OPERATOR_REGISTER_BINARY(_backward_arcsin)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::arcsin_grad> >);

// arccos
MXNET_OPERATOR_REGISTER_UNARY(arccos)
.MXNET_DESCRIBE("Take arccos of the src")
.set_attr<FCompute>("FCompute<cpu>", UnaryCompute<cpu, mshadow_op::arccos>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arccos" });

MXNET_OPERATOR_REGISTER_BINARY(_backward_arccos)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::arccos_grad> >);

// arctan
MXNET_OPERATOR_REGISTER_UNARY(arctan)
.MXNET_DESCRIBE("Take arctan of the src")
.set_attr<FCompute>("FCompute<cpu>", UnaryCompute<cpu, mshadow_op::arctan>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arctan" });

MXNET_OPERATOR_REGISTER_BINARY(_backward_arctan)
.set_attr<FCompute>("FCompute<cpu>", BinaryCompute<cpu, unary_bwd<mshadow_op::arctan_grad> >);
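Note the asymmetry above: `tan` registers its gradient with `ElemwiseGradUseOut`, so `tan_grad` receives the forward output y = tan(x) and computes y^2 + 1 (which equals sec^2(x)), while `arcsin`/`arccos`/`arctan` use `ElemwiseGradUseIn` and their grad functors receive the input x. A quick NumPy check of the output-based identity (an illustration, not part of the commit):

import numpy as np

x = 0.7
y = np.tan(x)
# d/dx tan(x) = sec^2(x), recoverable from the output alone as 1 + y^2.
assert abs((1 + y ** 2) - 1 / np.cos(x) ** 2) < 1e-12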

} // namespace op
} // namespace mxnet
28 changes: 28 additions & 0 deletions src/operator/tensor/elemwise_unary_op.cu
@@ -89,5 +89,33 @@ NNVM_REGISTER_OP(sin)
NNVM_REGISTER_OP(_backward_sin)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, unary_bwd<mshadow_op::sin_grad> >);

// tan
NNVM_REGISTER_OP(tan)
.set_attr<FCompute>("FCompute<gpu>", UnaryCompute<gpu, mshadow_op::tan>);

NNVM_REGISTER_OP(_backward_tan)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, unary_bwd<mshadow_op::tan_grad> >);

// arcsin
NNVM_REGISTER_OP(arcsin)
.set_attr<FCompute>("FCompute<gpu>", UnaryCompute<gpu, mshadow_op::arcsin>);

NNVM_REGISTER_OP(_backward_arcsin)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, unary_bwd<mshadow_op::arcsin_grad> >);

// arccos
NNVM_REGISTER_OP(arccos)
.set_attr<FCompute>("FCompute<gpu>", UnaryCompute<gpu, mshadow_op::arccos>);

NNVM_REGISTER_OP(_backward_arccos)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, unary_bwd<mshadow_op::arccos_grad> >);

// arctan
NNVM_REGISTER_OP(arctan)
.set_attr<FCompute>("FCompute<gpu>", UnaryCompute<gpu, mshadow_op::arctan>);

NNVM_REGISTER_OP(_backward_arctan)
.set_attr<FCompute>("FCompute<gpu>", BinaryCompute<gpu, unary_bwd<mshadow_op::arctan_grad> >);

} // namespace op
} // namespace mxnet
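With these GPU kernels registered, the new operators should be callable on CUDA devices through the auto-generated frontends. A hedged sketch, assuming a CUDA build of MXNet and GPU 0 available:

import mxnet as mx

# The registered unary ops are exposed under mx.nd/mx.sym; here via NDArray on GPU 0.
a = mx.nd.array([0.0, 0.5, 1.0], ctx=mx.gpu(0))
print(mx.nd.arctan(a).asnumpy())  # -> [0., 0.4636, 0.7854]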
50 changes: 50 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -1619,6 +1619,55 @@ def test_sequence_mask():
check_sequence_mask(shape1, default_context(), 2.1)
check_sequence_mask(shape2, default_context(), 0.1)

def mathematical_core(name, forward_mxnet_call, forward_numpy_call,
                      backward_numpy_call, data_init=5., grad_init=2.):
    data = mx.symbol.Variable('data')
    shape = (3, 4)
    data_tmp = np.ones(shape)
    data_tmp[:] = data_init
    arr_data = mx.nd.array(data_tmp)
    arr_grad = mx.nd.empty(shape)
    arr_grad[:] = 3

    # Forward: compare the symbol's output against the NumPy reference.
    test = forward_mxnet_call(data)
    exe_test = test.bind(mx.cpu(), args=[arr_data], args_grad=[arr_grad])
    exe_test.forward()
    out = exe_test.outputs[0].asnumpy()
    npout = forward_numpy_call(data_tmp)
    assert reldiff(out, npout) < 1e-6, "%s mathematical forward failed\n%s\n\n%s" % (name, out, npout)

    # Backward: by the chain rule, the expected input gradient is out_grad * f'(x).
    out_grad = mx.nd.empty(shape)
    out_grad[:] = grad_init
    npout_grad = out_grad.asnumpy()
    temp = backward_numpy_call(data_tmp)
    npout_grad = npout_grad * temp
    exe_test.backward(out_grad)
    arr_grad = arr_grad.asnumpy()
    assert reldiff(arr_grad, npout_grad) < 1e-6, "%s mathematical backward failed\n%s\n\n%s" % (
        name, arr_grad, npout_grad)


def test_mathematical():
    # rsqrt
    mathematical_core("rsqrt",
                      lambda x: mx.sym.rsqrt(x),
                      lambda x: 1 / np.sqrt(x),
                      lambda x: -(1.0 / (2.0 * x * np.sqrt(x))))
    # tan
    mathematical_core("tan", lambda x: mx.sym.tan(x), lambda x: np.tan(x),
                      lambda x: np.tan(x) ** 2 + 1)
    # arcsin/arccos are only defined for |x| < 1, hence data_init=0.5, grad_init=0.5.
    mathematical_core("arcsin", lambda x: mx.sym.arcsin(x), lambda x: np.arcsin(x),
                      lambda x: 1. / (1. - x ** 2) ** (1. / 2.), 0.5, 0.5)
    # arccos
    mathematical_core("arccos", lambda x: mx.sym.arccos(x), lambda x: np.arccos(x),
                      lambda x: -1. / (1. - x ** 2.) ** (1. / 2.), 0.5, 0.5)
    # arctan
    mathematical_core("arctan", lambda x: mx.sym.arctan(x), lambda x: np.arctan(x),
                      lambda x: 1. / (x ** 2. + 1.), 0.5, 0.5)


if __name__ == '__main__':
    test_expand_dims()
    test_slice_axis()
@@ -1664,3 +1713,4 @@ def test_sequence_mask():
    test_instance_normalization()
    test_l2_normalization()
    test_sequence_mask()
    test_mathematical()
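To run only the new test, assuming the nose runner MXNet's test suite used at the time:

nosetests tests/python/unittest/test_operator.py:test_mathematical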
