6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -307,7 +307,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)

@@ -320,8 +320,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
RsqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(rsqrt_double_grad,
RsqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
SoftplusDoubleGradKernel)

2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -200,7 +200,7 @@ PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

PD_REGISTER_KERNEL(exp,
91 changes: 91 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -768,6 +768,15 @@ struct RsqrtFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct RsqrtFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.rsqrt();
}
};

template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
@@ -784,6 +793,24 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct RsqrtGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<ComplexType<T>>(-0.5) * dout *
(out * out * out).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};

// For numerical stability, using the following formula instead of
// softplus(x) =
// log(1 + exp(x))
@@ -3054,6 +3081,45 @@ struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct RsqrtGradGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* Out,
const DenseTensor* dX,
const DenseTensor* ddX,
DenseTensor* dOut,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));

// rsqrt GradGrad (complex): ddy = -0.5 * ddx * conj(y * y * y), dy = (3 / conj(y)) * dx * ddx
if (dOut) {
auto dx = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) =
(static_cast<ComplexType<T>>(3.0) / out.unaryExpr(Conj<T>())) * dx *
ddx;
}
if (ddOut) {
auto ddout = EigenVector<ComplexType<T>>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<ComplexType<T>>(-0.5) *
(out * out * out).unaryExpr(Conj<T>());
}
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};

template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
@@ -4044,6 +4110,16 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaRsqrtFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// rsqrt(x) = 1 / sqrt(x)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_x) const {
return static_cast<ComplexType<T>>(1.) / sqrt(arg_x);
}
};

template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -4062,6 +4138,21 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaRsqrtGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = -0.5 * dout * conj(out^3)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_dout, const ComplexType<T> arg_out) const {
return static_cast<ComplexType<T>>(-0.5) * arg_dout *
conj(arg_out * arg_out * arg_out);
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};

template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
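The complex backward formulas above appear to follow the usual conjugate-of-derivative convention used elsewhere in this file: for y = x^(-1/2) the analytic derivative is dy/dx = -0.5 * x^(-3/2) = -0.5 * y^3, and the VJP conjugates it, giving dx = -0.5 * dout * conj(out^3). A minimal eager-mode sketch of that first-order check (assuming a build that includes these kernels; it mirrors the pattern used by the new unit tests below):

```python
import numpy as np
import paddle

# Minimal sketch: check the complex rsqrt backward rule dx = -0.5 * dout * conj(out^3).
# Assumes a Paddle build that already contains the kernels registered in this PR.
shape = (4,)
x_np = (
    np.random.uniform(0.1, 1, shape) + 1j * np.random.uniform(0.1, 1, shape)
).astype(np.complex64)

x = paddle.to_tensor(x_np, stop_gradient=False)
y = paddle.rsqrt(x)

# paddle.grad fills grad_outputs with ones by default, so the expected gradient
# reduces to -0.5 * conj(x^(-1.5)) elementwise.
(dx,) = paddle.grad(outputs=[y], inputs=[x])
np.testing.assert_allclose(dx.numpy(), -0.5 * np.conj(x_np**-1.5), rtol=1e-3)
```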
5 changes: 3 additions & 2 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -388,8 +388,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
SoftplusDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(rsqrt_double_grad,
RsqrtDoubleGradKernel)

PD_REGISTER_KERNEL(exp_grad,
GPU,
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -249,7 +249,7 @@ PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

PD_REGISTER_KERNEL(exp,
14 changes: 12 additions & 2 deletions python/paddle/tensor/ops.py
@@ -861,7 +861,7 @@ def rsqrt(x, name=None):
out = \\frac{1}{\\sqrt{x}}

Args:
x (Tensor): Input of Rsqrt operator, an N-D Tensor, with data type float32, float64 or float16.
x (Tensor): Input of Rsqrt operator, an N-D Tensor, with data type float32, float64, float16, complex64 or complex128.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

Returns:
@@ -882,7 +882,17 @@ def rsqrt(x, name=None):
return _C_ops.rsqrt(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'rsqrt'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
Review comment (Contributor): Also add the two complex types to the docstring.
Reply (Contributor Author): Done.
],
'rsqrt',
)
helper = LayerHelper('rsqrt', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
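With the dtype whitelist extended above, paddle.rsqrt accepts complex inputs and returns 1/sqrt(x) elementwise. A short usage sketch with hypothetical values (assuming a build that includes this change):

```python
import numpy as np
import paddle

# Hypothetical inputs; complex128 behaves the same way.
x = paddle.to_tensor([1 + 1j, 4 + 0j, 2 - 2j], dtype='complex64')
y = paddle.rsqrt(x)  # elementwise 1 / sqrt(x)

np.testing.assert_allclose(y.numpy(), 1.0 / np.sqrt(x.numpy()), rtol=1e-3)
```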
53 changes: 53 additions & 0 deletions test/legacy_test/test_activation_op.py
@@ -1870,6 +1870,11 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(0.1, 1, self.shape)
+ 1j * np.random.uniform(0.1, 1, self.shape)
).astype(self.dtype)
out = 1.0 / np.sqrt(x)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -1910,6 +1915,54 @@ def if_enable_cinn(self):
self.enable_cinn = False


class TestRsqrt_Complex64(TestRsqrt):
Review comment (Contributor): There is no test for the complex128 type.
Reply (Contributor Author): Done.
def init_dtype(self):
self.dtype = np.complex64

def test_check_grad(self):
Review comment (Contributor): The GPU unit test is not being run; please also add a unit test that runs on GPU.
self.check_grad(
['X'],
'Out',
check_pir=True,
max_relative_error=0.007,
Review comment (Contributor): This tolerance is a bit large; see whether it can be reduced.
Reply (Contributor Author): The observed error during testing was just over 0.006, so it is set to 0.007 here.
check_pir_onednn=self.check_pir_onednn,
)

def test_api_complex(self):
with dynamic_guard():
for device in devices:
if device == 'cpu' or (
device == 'gpu' and paddle.is_compiled_with_cuda()
):
np_x = np.array([[2, 3, 4], [7, 8, 9]], dtype=self.dtype)
x = paddle.to_tensor(np_x, dtype=self.dtype, place=device)
y = paddle.rsqrt(x)
x_expect = 1.0 / np.sqrt(np_x)
np.testing.assert_allclose(y.numpy(), x_expect, rtol=1e-3)

def test_grad_grad(self):
with dynamic_guard():
x_numpy = (
np.random.uniform(0.1, 1, self.shape)
+ 1j * np.random.uniform(0.1, 1, self.shape)
).astype(self.dtype)

expected_ddx = 3.0 / 4 * np.conj(np.power(x_numpy, -2.5))

x = paddle.to_tensor(x_numpy, stop_gradient=False)
y = paddle.rsqrt(x)
dx = paddle.grad(
outputs=[y], inputs=[x], create_graph=True, retain_graph=True
)[0]
ddx = paddle.grad(outputs=[dx], inputs=[x], retain_graph=True)[0]
np.testing.assert_allclose(ddx.numpy(), expected_ddx, rtol=1e-3)


class TestRsqrt_Complex128(TestRsqrt_Complex64):
def init_dtype(self):
self.dtype = np.complex128


class TestAbs(TestActivation):
def setUp(self):
self.op_type = "abs"