
Commit 40d193e

Add the ReLU6, Tanhshrink, SELU, Softplus, Softshrink and Softsign for the api 2.0 (#26376)
1 parent 6e13e86 commit 40d193e

10 files changed: +993 −116 lines

paddle/fluid/operators/activation_op.cc

Lines changed: 30 additions & 8 deletions
@@ -317,13 +317,6 @@ The OP square each elements of the inputs.
 
 )DOC";
 
-UNUSED constexpr char SoftplusDoc[] = R"DOC(
-Softplus Activation Operator.
-
-$$out = \ln(1 + e^{x})$$
-
-)DOC";
-
 UNUSED constexpr char SoftsignDoc[] = R"DOC(
 Softsign Activation Operator.
 
@@ -396,6 +389,36 @@ LeakyRelu Activation Operator.
   }
 };
 
+class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "Input of Softplus operator, an N-D Tensor, with data type "
+             "float32, float64 or float16.");
+    AddOutput(
+        "Out",
+        "Output of Softplus operator, a Tensor with shape same as input.");
+    AddAttr<float>("beta", "The value of beta for Softplus.").SetDefault(1.0f);
+    AddAttr<float>("threshold", "The value of threshold for Softplus.")
+        .SetDefault(20.0f);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel.")
+        .SetDefault(false);
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default false) Only used in cudnn kernel, need install cudnn.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+:strong:`Softplus Activation Operator`
+
+.. math::
+    out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\
+    \text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold.
+
+)DOC");
+  }
+};
+
 class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -672,7 +695,6 @@ REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
 REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
 REGISTER_ACTIVATION_OP_MAKER(Log1p, Log1pDoc);
 REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
-REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
 REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
 
 template <ActBwdOpFwdDeps kDepValue>
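
For reference, the docstring formula added above can be sanity-checked with a short NumPy sketch (an illustration of the documented math only, not the operator's actual Eigen kernel; the helper name softplus_ref is made up):

    # Reference sketch of the formula in SoftplusOpMaker's AddComment above.
    import numpy as np

    def softplus_ref(x, beta=1.0, threshold=20.0):
        x = np.asarray(x, dtype=np.float64)
        x_beta = beta * x
        # Revert to the identity where beta * x exceeds the threshold, to avoid
        # overflow in exp(); otherwise use log(1 + exp(beta * x)) / beta.
        return np.where(x_beta > threshold, x,
                        np.log1p(np.exp(np.minimum(x_beta, threshold))) / beta)

    print(softplus_ref([-0.4, -0.2, 0.1, 0.3]))
    # ~[0.513015, 0.598139, 0.744397, 0.854355], matching the F.softplus sample further down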

paddle/fluid/operators/activation_op.h

Lines changed: 27 additions & 13 deletions
@@ -975,32 +975,46 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };
 
-// softplus(x) = log(1 + exp(x))
-// When x is a very large positive number, exp(x) may explode to inf,
-// Using trick below for numerical stability
-// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
-// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
+// For numerical stability, using the following formula instead of softplus(x) =
+// log(1 + exp(x))
+// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
+// 1, threshold = 20 by default), otherwise x
 template <typename T>
 struct SoftplusFunctor : public BaseActivationFunctor<T> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) {
-    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
-    out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
+    auto x_beta = static_cast<T>(beta) * x;
+    out.device(d) = (x_beta > static_cast<T>(threshold))
+                        .select(x, (static_cast<T>(1) + x_beta.exp()).log() /
+                                       static_cast<T>(beta));
   }
 };
 
-// d(softplus(x))/dx = exp(x) / (1 + exp(x))
-// For numerical stability:
-// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
-//                     exp(x - max(x, 0)))
+// For numerical stability, using the following formula instead of
+// d(softplus(x))/dx = 1 / (1 + exp(-x))
+// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
+// = 1, threshold = 20 by default), otherwise x
 template <typename T>
 struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) {
-    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
+    auto x_beta = static_cast<T>(beta) * x;
     dx.device(d) =
-        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
+        (x_beta > static_cast<T>(threshold))
+            .select(dout, dout / (static_cast<T>(1) + (-x_beta).exp()));
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
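
The gradient functor mirrors the same threshold switch: above it, dout is passed through unchanged (the derivative is treated as 1); below it, dx = dout * sigmoid(beta * x). A NumPy-only sketch, checked against a finite difference of the forward formula (an illustration, not the Eigen kernel):

    import numpy as np

    beta, threshold = 1.0, 20.0

    def fwd(x):
        x_beta = beta * x
        return np.where(x_beta > threshold, x,
                        np.log1p(np.exp(np.minimum(x_beta, threshold))) / beta)

    def grad(x, dout):
        x_beta = beta * x
        # dout passes through above the threshold; otherwise dout * sigmoid(beta * x)
        return np.where(x_beta > threshold, dout, dout / (1.0 + np.exp(-x_beta)))

    x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0, 25.0])
    eps = 1e-6
    numeric = (fwd(x + eps) - fwd(x - eps)) / (2 * eps)
    print(np.allclose(grad(x, np.ones_like(x)), numeric, atol=1e-5))  # True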

python/paddle/fluid/layers/nn.py

Lines changed: 2 additions & 7 deletions
@@ -8643,11 +8643,9 @@ def relu(x, name=None):
     return out
 
 
+@deprecated(since="2.0.0", update_to="paddle.nn.functional.selu")
 def selu(x, scale=None, alpha=None, name=None):
     """
-    :alias_main: paddle.nn.functional.selu
-    :alias: paddle.nn.functional.selu,paddle.nn.functional.activation.selu
-    :old_api: paddle.fluid.layers.selu
 
     Selu Operator.
 
@@ -9304,12 +9302,9 @@ def elu(x, alpha=1.0, name=None):
     return out
 
 
-@templatedoc()
+@deprecated(since="2.0.0", update_to="paddle.nn.functional.relu6")
 def relu6(x, threshold=6.0, name=None):
     """
-    :alias_main: paddle.nn.functional.relu6
-    :alias: paddle.nn.functional.relu6,paddle.nn.functional.activation.relu6
-    :old_api: paddle.fluid.layers.relu6
 
     ${comment}
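
With the old fluid entry points now emitting deprecation warnings, the replacements live under paddle.nn.functional. A brief usage sketch (assuming the 2.0 functional APIs added in this PR are available):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(np.array([-1.0, 0.3, 6.5]).astype('float32'))
    print(F.relu6(x).numpy())  # inputs clipped to the range [0, 6]
    print(F.selu(x).numpy())   # scaled exponential linear unit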

python/paddle/fluid/layers/ops.py

Lines changed: 22 additions & 20 deletions
@@ -20,6 +20,8 @@
 from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
 from paddle.utils import deprecated
 
+__deprecated_func_name__ = {'tanh_shrink': 'tanhshrink', }
+
 __activations_noattr__ = [
     'sigmoid',
     'logsigmoid',
@@ -64,14 +66,20 @@
 __all__ += __unary_func__
 
 for _OP in set(__activations_noattr__):
+    _new_OP = _OP
+    if _OP in __deprecated_func_name__:
+        _new_OP = __deprecated_func_name__[_OP]
     func = generate_activation_fn(_OP)
     func = deprecated(
-        since="2.0.0", update_to="paddle.nn.functional.%s" % (_OP))(func)
+        since="2.0.0", update_to="paddle.nn.functional.%s" % (_new_OP))(func)
     globals()[_OP] = func
 
 for _OP in set(__unary_func__):
+    _new_OP = _OP
+    if _OP in __deprecated_func_name__:
+        _new_OP = __deprecated_func_name__[_OP]
     func = generate_activation_fn(_OP)
-    func = deprecated(since="2.0.0", update_to="paddle.%s" % (_OP))(func)
+    func = deprecated(since="2.0.0", update_to="paddle.%s" % (_new_OP))(func)
     globals()[_OP] = func
 
 add_sample_code(globals()["sigmoid"], r"""
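
The loops above wrap each generated activation in a deprecation decorator whose update_to target is looked up in __deprecated_func_name__, so the old tanh_shrink name keeps working while the warning points users at the renamed tanhshrink. A standalone sketch of the same pattern (deprecated and generate_fn below are simplified stand-ins, not the actual Paddle helpers):

    import warnings

    __deprecated_func_name__ = {'tanh_shrink': 'tanhshrink', }

    def deprecated(since, update_to):
        def wrap(fn):
            def inner(*args, **kwargs):
                warnings.warn("deprecated since %s, use %s" % (since, update_to),
                              DeprecationWarning)
                return fn(*args, **kwargs)
            return inner
        return wrap

    def generate_fn(op_name):  # stand-in for generate_activation_fn
        return lambda x: (op_name, x)

    for _OP in ['tanh_shrink', 'softplus']:
        _new_OP = __deprecated_func_name__.get(_OP, _OP)
        func = deprecated("2.0.0", "paddle.nn.functional.%s" % _new_OP)(generate_fn(_OP))
        globals()[_OP] = func  # old name stays callable; the warning names the new one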
@@ -160,16 +168,14 @@
 Examples:
     .. code-block:: python
 
-        import numpy as np
         import paddle
         import paddle.nn.functional as F
+        import numpy as np
+
         paddle.disable_static()
 
-        x_data = np.array([-0.4, -0.2, 0.1, 0.3])
-        x = paddle.to_variable(x_data)
-        out = F.tanh_shrink(x)
-        print(out.numpy())
-        # [-0.02005104 -0.00262468  0.00033201  0.00868739]
+        x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
+        out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
 
 """)
 
@@ -401,33 +407,29 @@
 Examples:
     .. code-block:: python
 
-        import numpy as np
         import paddle
         import paddle.nn.functional as F
+        import numpy as np
+
         paddle.disable_static()
 
-        x_data = np.array([-0.4, -0.2, 0.1, 0.3])
-        x = paddle.to_variable(x_data)
-        out = F.softplus(x)
-        print(out.numpy())
-        # [0.51301525 0.59813887 0.74439666 0.85435524]
+        x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
+        out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355]
 
 """)
 
 add_sample_code(globals()["softsign"], r"""
 Examples:
     .. code-block:: python
 
-        import numpy as np
         import paddle
        import paddle.nn.functional as F
+        import numpy as np
+
         paddle.disable_static()
 
-        x_data = np.array([-0.4, -0.2, 0.1, 0.3])
-        x = paddle.to_variable(x_data)
-        out = F.softsign(x)
-        print(out.numpy())
-        # [-0.28571429 -0.16666667  0.09090909  0.23076923]
+        x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
+        out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
 
 """)
 
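The updated sample outputs can be reproduced from the standard closed-form definitions, tanhshrink(x) = x - tanh(x) and softsign(x) = x / (1 + |x|) (not restated in this commit's docs), with a quick NumPy check:

    import numpy as np

    x = np.array([-0.4, -0.2, 0.1, 0.3])
    print(x - np.tanh(x))       # ~[-0.020051, -0.0026247, 0.00033201, 0.0086874]
    print(x / (1 + np.abs(x)))  # ~[-0.285714, -0.166667, 0.0909091, 0.230769]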
