4 changes: 4 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -1513,6 +1513,10 @@ REGISTER_ACTIVATION_OP(tanh_shrink,
TanhShrinkFunctor,
TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
REGISTER_ACTIVATION_OP(softsign,
Softsign,
SoftsignFunctor,
SoftsignGradFunctor);
REGISTER_ACTIVATION_OP(hard_sigmoid,
HardSigmoid,
HardSigmoidFunctor,
32 changes: 3 additions & 29 deletions paddle/fluid/operators/activation_op.h
@@ -290,6 +290,7 @@ USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Softsign)
USE_PHI_FUNCTOR(Sigmoid)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sigmoid)
USE_PHI_TRIPLE_GRAD_FUNCTOR(Sigmoid)
@@ -493,35 +494,8 @@ inline void ExtractDoubleGradTensorWithInputDOut(
}
}

template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function

template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

} // namespace operators
} // namespace paddle

#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);
24 changes: 1 addition & 23 deletions paddle/fluid/operators/activation_op.kps
@@ -66,29 +66,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
return x / (one + abs(x));
}
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// dx = dout / (1 + abs(x))^2
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T temp = one + abs(x);
return dout / (temp * temp);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -174,6 +151,7 @@ USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)
USE_PHI_FUNCTOR(CudaSoftsign)
USE_PHI_FUNCTOR(CudaSigmoid)
USE_PHI_FUNCTOR(CudaLogSigmoid)
USE_PHI_FUNCTOR(CudaHardSigmoid)
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
@@ -2152,6 +2152,17 @@
use_gpudnn : true
backward : softmax_grad

# softsign
- api : softsign
args : (Tensor x)
output : Tensor
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : softsign
backward : softsign_grad

- api : spectral_norm
args : (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps)
output : Tensor
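A note on the `softsign` entry added in this file (editor's note, not part of the diff): `UnchangedInferMeta` means the output inherits the input's shape and dtype, and the entry binds the forward to the phi `softsign` kernel with `softsign_grad` as its backward. A minimal sketch of that contract from the Python side, assuming an eager-mode Paddle build that includes this PR:

```python
# Sketch only: illustrates the yaml contract (UnchangedInferMeta + softsign kernel).
# Assumes a Paddle build containing this PR; values follow the docstring example below.
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
out = F.softsign(x)          # forward kernel registered by this yaml entry
assert out.shape == x.shape  # UnchangedInferMeta: same shape ...
assert out.dtype == x.dtype  # ... and same dtype as the input
```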
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -2037,6 +2037,17 @@
func : softmax_grad
use_gpudnn : true

- backward_api : softsign_grad
forward : softsign (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : softsign_grad
inplace : (out_grad -> x_grad)

- backward_api : spectral_norm_grad
forward : spectral_norm (Tensor weight, Tensor u, Tensor v, int dim, int power_iters, float eps) -> Tensor(out)
args : (Tensor weight, Tensor u, Tensor v, Tensor out_grad, int dim, int power_iters, float eps)
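On `inplace : (out_grad -> x_grad)` in the `softsign_grad` entry (my reading, not stated in the diff): the gradient is computed element-wise from `x` and `out_grad` only, so `x_grad` can safely reuse `out_grad`'s buffer. A NumPy sketch of why the element-wise overwrite is safe:

```python
# Sketch: dx = dout / (1 + |x|)^2 reads each dout element exactly once,
# so writing the result back into the dout buffer (the inplace mapping) is safe.
import numpy as np

x = np.random.uniform(-1, 1, [10, 12]).astype("float32")
dout = np.random.uniform(-1, 1, [10, 12]).astype("float32")

expected = dout / np.square(1.0 + np.abs(x))
dout /= np.square(1.0 + np.abs(x))  # in-place update, mimicking out_grad -> x_grad
np.testing.assert_allclose(dout, expected, rtol=1e-6)
```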
1 change: 1 addition & 0 deletions paddle/phi/kernels/activation_grad_kernel.h
@@ -213,6 +213,7 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Atanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Silu);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Square);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Softsign);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Log2);
1 change: 1 addition & 0 deletions paddle/phi/kernels/activation_kernel.h
@@ -61,6 +61,7 @@ DECLARE_ACTIVATION_KERNEL(Reciprocal)
DECLARE_ACTIVATION_KERNEL(Square)
DECLARE_ACTIVATION_KERNEL(Sqrt)
DECLARE_ACTIVATION_KERNEL(Rsqrt)
DECLARE_ACTIVATION_KERNEL(Softsign)
DECLARE_ACTIVATION_KERNEL(Sigmoid)
DECLARE_ACTIVATION_KERNEL(LogSigmoid)
DECLARE_ACTIVATION_KERNEL(Log)
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -136,6 +136,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, Expm1GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, ReciprocalGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, RsqrtGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, SoftsignGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, LogSigmoidGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, LogGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, Log2GradFunctor);
@@ -335,6 +336,7 @@ PD_REGISTER_KERNEL(square_double_grad,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
2 changes: 2 additions & 0 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -80,6 +80,7 @@ DEFINE_CPU_ACTIVATION_KERNEL(Reciprocal, ReciprocalFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Square, SquareFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sqrt, SqrtFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Rsqrt, RsqrtFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Softsign, SoftsignFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sigmoid, SigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(LogSigmoid, LogSigmoidFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
@@ -173,6 +174,7 @@ PD_REGISTER_KERNEL(expm1,
PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
PD_REGISTER_KERNEL(
square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
51 changes: 51 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -1364,6 +1364,32 @@ struct SiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};

// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function

template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
@@ -3019,6 +3045,31 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// softsign(x) = x / (1 + abs(x))
__device__ __forceinline__ T operator()(const T x) const {
// Using abs directly will cause namespace conflict
return x / (one + (x > -x ? x : -x));
}
};

template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);

// dx = dout / (1 + abs(x))^2
__device__ __forceinline__ T operator()(const T dout, const T x) const {
// Using abs directly will cause namespace conflict
T temp = one + (x > -x ? x : -x);
return dout / (temp * temp);
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct CudaSigmoidFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
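For reference (editor's note, not part of the diff), a NumPy transcription of the two formulas these functors implement: forward `x / (1 + |x|)` and backward `dout / (1 + |x|)^2`. The CUDA variants only spell `abs(x)` as `(x > -x ? x : -x)` to avoid the namespace clash mentioned in the comments; the math is identical. A small self-check, assuming float32 inputs:

```python
# Sketch: NumPy reference for SoftsignFunctor / SoftsignGradFunctor.
import numpy as np

def softsign(x):
    # softsign(x) = x / (1 + |x|)
    return x / (1.0 + np.abs(x))

def softsign_grad(x, dout):
    # d(softsign(x))/dx = 1 / (1 + |x|)^2
    return dout / np.square(1.0 + np.abs(x))

x = np.array([-1.0, 0.0, 1.0, 3.0], dtype="float32")
np.testing.assert_allclose(softsign(x), [-0.5, 0.0, 0.5, 0.75])
np.testing.assert_allclose(softsign_grad(x, np.ones_like(x)), [0.25, 1.0, 0.25, 0.0625])
# The CUDA code writes |x| as (x > -x ? x : -x); for real inputs this is just abs:
np.testing.assert_allclose(np.where(x > -x, x, -x), np.abs(x))
```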
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -195,6 +195,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, CudaExpm1GradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, CudaReciprocalGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, CudaSqrtGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, CudaRsqrtGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, CudaSoftsignGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(LogSigmoid, CudaLogSigmoidGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, CudaLogGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Log2, CudaLog2GradFunctor);
@@ -415,6 +416,7 @@ PD_REGISTER_KERNEL(square_double_grad,
phi::dtype::float16,
phi::dtype::bfloat16) {}

PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/activation_kernel.cu
@@ -97,6 +97,7 @@ DEFINE_GPU_ACTIVATION_KERNEL(Reciprocal, CudaReciprocalFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Square, CudaSquareFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sqrt, CudaSqrtFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Rsqrt, CudaRsqrtFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Softsign, CudaSoftsignFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sigmoid, CudaSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(LogSigmoid, CudaLogSigmoidFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
@@ -241,6 +242,7 @@ PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
2 changes: 2 additions & 0 deletions paddle/phi/ops/compat/activation_sig.cc
@@ -62,6 +62,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Softsign, "softsign", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LogSigmoid, "logsigmoid", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", ); // NOLINT
@@ -294,6 +295,7 @@ PD_REGISTER_ARG_MAPPING_FN(elu, phi::EluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad, phi::EluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad_grad, phi::EluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(silu_grad, phi::SiluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softsign_grad, phi::SoftsignGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_grad, phi::SigmoidGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_grad_grad,
phi::SigmoidDoubleGradOpArgumentMapping);
3 changes: 2 additions & 1 deletion python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -2795,6 +2795,7 @@ class TestSoftsign(TestActivation):
def setUp(self):
self.op_type = "softsign"
self.init_dtype()
self.python_api = paddle.nn.functional.softsign

np.random.seed(1024)
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
@@ -2805,7 +2806,7 @@ def setUp(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_eager=True)


class TestSoftsignAPI(unittest.TestCase):
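The `check_eager=True` change above reruns the gradient check through the eager (final-state) path registered by the new yaml entries. Conceptually the test compares the analytic gradient against finite differences; a hedged NumPy sketch of that idea, not the actual OpTest machinery:

```python
# Sketch: finite-difference check of the softsign gradient, i.e. the property
# that check_grad(['X'], 'Out', check_eager=True) verifies via OpTest.
import numpy as np

def softsign(x):
    return x / (1.0 + np.abs(x))

x = np.random.uniform(-1, 1, [10, 12]).astype("float64")
eps = 1e-6
numeric = (softsign(x + eps) - softsign(x - eps)) / (2 * eps)  # central difference
analytic = 1.0 / np.square(1.0 + np.abs(x))                    # 1 / (1 + |x|)^2
np.testing.assert_allclose(numeric, analytic, rtol=1e-5, atol=1e-8)
```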
2 changes: 2 additions & 0 deletions python/paddle/nn/functional/activation.py
@@ -1275,6 +1275,8 @@ def softsign(x, name=None):
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
"""
if in_dygraph_mode():
return _C_ops.final_state_softsign(x)
if in_dynamic_mode():
return _C_ops.softsign(x)

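Callers are unaffected; only the dygraph dispatch now tries `final_state_softsign` first. The docstring example, checked against the NumPy reference (assuming a build that includes this PR):

```python
# Sketch: eager-mode call, now routed through final_state_softsign, vs. NumPy.
import numpy as np
import paddle
import paddle.nn.functional as F

x_np = np.array([-0.4, -0.2, 0.1, 0.3], dtype="float32")
out = F.softsign(paddle.to_tensor(x_np))
np.testing.assert_allclose(out.numpy(), x_np / (1.0 + np.abs(x_np)), rtol=1e-6)
```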