[complex] add complex support for silu (#56903)
BeingGod authored Sep 7, 2023
1 parent 5ed7c8a commit 1d4e938
Showing 8 changed files with 82 additions and 7 deletions.
10 changes: 10 additions & 0 deletions paddle/phi/common/complex.h
@@ -496,6 +496,16 @@ HOSTDEVICE inline complex<T> log(const complex<T>& a) {
#endif
}

+template <typename T>
+HOSTDEVICE inline complex<T> exp(const complex<T>& a) {
+#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
+    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
+  return complex<T>(thrust::exp(thrust::complex<T>(a)));
+#else
+  return complex<T>(std::exp(std::complex<T>(a)));
+#endif
+}

template <typename T>
inline std::ostream& operator<<(std::ostream& os, const complex<T>& a) {
os << "real:" << a.real << " imag:" << a.imag;
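The new `exp` overload dispatches to `thrust::exp` when compiling device code and to `std::exp` on the host; both compute the standard complex exponential. A quick numpy sketch of the identity these library calls implement (illustrative only, not part of the commit):

# Illustrative only: std::exp / thrust::exp compute e^(a+bi) = e^a * (cos b + i sin b).
import numpy as np

z = 0.3 + 1.7j
assert np.allclose(np.exp(z), np.exp(z.real) * (np.cos(z.imag) + 1j * np.sin(z.imag)))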
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -301,7 +301,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -195,7 +195,7 @@ PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
36 changes: 36 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -1847,6 +1847,25 @@ struct SiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

+template <typename T>
+struct SiluGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto temp1 = static_cast<ComplexType<T>>(1) + (-x).exp();  // 1+e^(-x)
+    auto temp2 = x * (-x).exp();  // x*e^(-x)
+    dx.device(d) = dout * ((static_cast<ComplexType<T>>(1) / temp1) *
+                           (static_cast<ComplexType<T>>(1) + (temp2 / temp1)))
+                              .unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};

template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
@@ -3793,6 +3812,23 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

+template <typename T>
+struct CudaSiluGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
+
+  // dx = dout * conj((1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
+    ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
+    ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
+    ComplexType<T> temp = one / (one + exp(-x));
+    return dout * conj(temp * (one + x * (one - temp)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
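Both new specializations implement the same math. With sigmoid(x) = 1/(1 + e^(-x)), silu(x) = x * sigmoid(x) and silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))); since silu is holomorphic, the backward kernels return dout * conj(silu'(x)), which is why the code applies `Conj<T>` / `conj` (the conjugate convention Paddle uses for complex gradients, per Wirtinger calculus). A numpy sketch (illustrative, not part of the commit) checking that the CPU form, the CUDA form, and a finite-difference estimate agree:

# Illustrative only: the CPU and CUDA functors use algebraically identical
# expressions for silu'(x); verify them numerically against a finite difference.
import numpy as np

x = np.array([0.5 - 0.25j, -1.0 + 2.0j], dtype=np.complex128)
silu = lambda z: z / (np.exp(-z) + 1)

# CPU functor form: (1/temp1) * (1 + temp2/temp1) with temp1 = 1+e^(-x), temp2 = x*e^(-x)
temp1 = 1 + np.exp(-x)
temp2 = x * np.exp(-x)
cpu_form = (1 / temp1) * (1 + temp2 / temp1)

# CUDA functor form: temp * (1 + x*(1 - temp)) with temp = sigmoid(x)
temp = 1 / (1 + np.exp(-x))
cuda_form = temp * (1 + x * (1 - temp))

# Forward-difference estimate of the complex derivative (silu is holomorphic here)
h = 1e-7
fd = (silu(x + h) - silu(x)) / h

assert np.allclose(cpu_form, cuda_form)
assert np.allclose(cpu_form, fd, atol=1e-5)
# The kernels then return dout * conj(silu'(x)).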
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -403,7 +403,7 @@ PD_REGISTER_KERNEL(exp_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(logit_grad, LogitCUDAGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -287,7 +287,7 @@ PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
-PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
14 changes: 12 additions & 2 deletions python/paddle/nn/functional/activation.py
@@ -1034,7 +1034,7 @@ def silu(x, name=None):
Where :math:`x` is the input Tensor.
Parameters:
-        x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64.
+        x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, complex64, complex128.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
@@ -1057,7 +1057,17 @@ def silu(x, name=None):
return _C_ops.silu(x)
else:
check_variable_and_dtype(
-            x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'silu'
+            x,
+            'x',
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'complex64',
+                'complex128',
+            ],
+            'silu',
)
helper = LayerHelper("silu", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
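With the kernel registrations and the dtype check above in place, `paddle.nn.functional.silu` accepts complex input directly. A minimal usage sketch, assuming a Paddle build that contains this commit:

# Minimal usage sketch (assumes a Paddle build containing this commit).
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(np.array([0.5 + 1.0j, -1.0 - 2.0j], dtype=np.complex64))
y = F.silu(x)  # computes x * sigmoid(x) in complex arithmetic

# Reference value computed with numpy, matching the formula used in the tests below:
ref = x.numpy() / (np.exp(-x.numpy()) + 1)
print(np.allclose(y.numpy(), ref, atol=1e-6))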
21 changes: 20 additions & 1 deletion test/legacy_test/test_activation_op.py
@@ -384,6 +384,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
out = x / (np.exp(-x) + 1)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
@@ -397,14 +402,28 @@ def if_enable_cinn(self):
pass

def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
+        # TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype
+        if self.dtype == np.complex64 or self.dtype == np.complex128:
+            self.check_grad(['X'], 'Out', check_prim=False)
+        else:
+            self.check_grad(['X'], 'Out', check_prim=True)


class TestSilu_ZeroDim(TestSilu):
def init_shape(self):
self.shape = []


+class TestSilu_Complex64(TestSilu):
+    def init_dtype(self):
+        self.dtype = np.complex64


+class TestSilu_Complex128(TestSilu):
+    def init_dtype(self):
+        self.dtype = np.complex128


class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
def setUp(self):
