
[cherry-pick] elu support alpha < 0 (#37316) #37437


Merged · 1 commit · Nov 23, 2021
37 changes: 29 additions & 8 deletions paddle/fluid/operators/activation_op.cc
@@ -560,6 +560,22 @@ Applies the following element-wise computation on the input according to
}
};

template <typename T>
class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elu_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("Out", this->Output("Out"));
op->SetInput("X", this->Input("X"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -1233,13 +1249,11 @@ REGISTER_OP_CPU_KERNEL(
/* ========================================================================== */

/* ======================== elu register ============================ */
REGISTER_OPERATOR(
elu, ops::ActivationOp, ops::ELUOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu, ops::ActivationOp, ops::ELUOpMaker,
ops::ActivationOpInferVarType,
ops::ELUGradOpMaker<paddle::framework::OpDesc>,
ops::ELUGradOpMaker<paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
@@ -1249,7 +1263,14 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CPU_KERNEL(elu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ELUFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ELUFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
ops::ELUGradGradFunctor<float>>,
91 changes: 76 additions & 15 deletions paddle/fluid/operators/activation_op.cu
@@ -1161,11 +1161,12 @@ struct CudaELUFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}};
}

// elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
// elu(x) = x, if x > 0
// elu(x) = alpha * (e^x - 1), if x <= 0
__device__ __forceinline__ T operator()(const T& arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
CT res = x > zero ? x : temp;
return static_cast<T>(res);
}
};
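
A minimal NumPy sketch (illustrative only, not part of the patch; the names elu_old/elu_new are ours) of why the forward formula changed: the pre-PR `max(0, x) + min(0, alpha * (exp(x) - 1))` form agrees with the piecewise definition only when alpha >= 0, while the new branch used above is valid for any alpha.

```python
import numpy as np

def elu_old(x, alpha):
    # pre-PR formulation: only valid for alpha >= 0
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))

def elu_new(x, alpha):
    # piecewise formulation used by this PR: valid for any alpha
    return np.where(x > 0, x, alpha * (np.exp(x) - 1))

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(np.allclose(elu_old(x, 1.0), elu_new(x, 1.0)))    # True: both agree for alpha >= 0
print(np.allclose(elu_old(x, -0.2), elu_new(x, -0.2)))  # False: old form breaks for alpha < 0
```
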
@@ -1174,34 +1175,84 @@ template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}

// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <= 0
// case 1: alpha >= 0
// dx = dout, if out > 0
// dx = dout * (out + alpha), if out <= 0
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_out) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
MPType a = static_cast<MPType>(alpha);
MPType out_pos = static_cast<MPType>(out > zero);
MPType out_neg = static_cast<MPType>(out <= zero);
return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
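
The rewritten CudaELUGradFunctor needs only Out when alpha >= 0: out > 0 exactly when x > 0, and on the non-positive branch d(out)/dx = alpha * e^x = out + alpha, which is the expression computed above. A quick NumPy check of that identity (a sketch with our own variable names, not Paddle code):

```python
import numpy as np

x = np.random.uniform(-3, 3, 1000)
alpha = 1.5  # any alpha >= 0

out = np.where(x > 0, x, alpha * (np.exp(x) - 1))
# gradient written in terms of x ...
grad_from_x = np.where(x > 0, 1.0, alpha * np.exp(x))
# ... and the same gradient written in terms of out, as the functor does
grad_from_out = np.where(out > 0, 1.0, out + alpha)
print(np.allclose(grad_from_x, grad_from_out))  # True for alpha >= 0
```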

template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
float alpha;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}

// case 2: alpha < 0
// dx = dout, if x > 0
// dx = dout * (out + alpha), if x <= 0
__device__ __forceinline__ T operator()(const T& arg_dout, const T& arg_out,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
temp_a_neg * temp_x_pos * (one + a * exp(x))));
MPType x_pos = static_cast<MPType>(x > zero);
MPType x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
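
For alpha < 0 the sign of out no longer identifies the branch: alpha * (e^x - 1) is non-negative when x <= 0, so out is non-negative on both sides of zero and the backward pass has to branch on x itself. That is why this functor takes arg_x and why ELUGradOpMaker earlier in the diff registers "X" as an input of elu_grad. A small NumPy illustration (sketch only, names assumed):

```python
import numpy as np

x = np.random.uniform(-3, 3, 1000)
alpha = -0.2

out = np.where(x > 0, x, alpha * (np.exp(x) - 1))
print((out >= 0).all())     # True: out is non-negative on *both* branches when alpha < 0

# the correct gradient must branch on x, not on out
grad = np.where(x > 0, 1.0, out + alpha)     # out + alpha == alpha * exp(x) on the x <= 0 branch
wrong = np.where(out > 0, 1.0, out + alpha)  # branching on out picks the wrong branch
print(np.allclose(grad, wrong))              # False
```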

template <typename DeviceContext, typename T>
class ELUGradCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = ctx.Input<framework::Tensor>("Out");
auto* x = ctx.Input<framework::Tensor>("X");
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const float alpha = ctx.Attr<float>("alpha");

auto& dev_ctx = ctx.device_context<DeviceContext>();
std::vector<const framework::Tensor*> ins = {d_out, out};
std::vector<framework::Tensor*> outs = {d_x};
if (alpha > 0) {
CudaELUGradFunctor<T> functor;
functor.alpha = alpha;
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, functor);
} else {
CudaELUGradNegativeAlphaFunctor<T> functor;
functor.alpha = alpha;
ins.push_back(x);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, functor);
}
}
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1330,7 +1381,17 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */

/* ======================== elu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
elu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaELUFunctor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaELUFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaELUFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
elu_grad, ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);

REGISTER_OP_CUDA_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
73 changes: 59 additions & 14 deletions paddle/fluid/operators/activation_op.h
@@ -1311,25 +1311,70 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();

// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <= 0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
temp_a_neg * temp_x_pos;
// case 1: alpha >= 0
// dx = dout, if out > 0
// dx = dout * (out + alpha), if out <= 0
dx.device(d) = (out > static_cast<T>(0))
.select(dout, dout * (out + static_cast<T>(alpha)));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
// case 2: alpha < 0
// dx = dout, if x > 0
// dx = dout * (out + alpha), if x <= 0
dx.device(d) = (x > static_cast<T>(0))
.select(dout, dout * static_cast<T>(alpha) * x.exp());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
const float alpha = context.Attr<float>("alpha");
dX->mutable_data<T>(context.GetPlace());

auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
auto* place =
context.template device_context<DeviceContext>().eigen_device();

if (alpha > 0) {
ELUGradFunctor<T> functor;
functor.alpha = alpha;
functor(*place, x, out, dout, dx);
} else {
ELUGradNegativeAlphaFunctor<T> functor;
functor.alpha = alpha;
functor(*place, x, out, dout, dx);
}
}
};

// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
2 changes: 1 addition & 1 deletion paddle/fluid/operators/inplace_abn_op.h
@@ -104,7 +104,7 @@ class InplaceABNActivation {
auto temp2 = (y * temp / static_cast<T>(alpha) + static_cast<T>(1)).log();
x.device(d) = (y * temp1 + temp2).template cast<T>();

ELUGradFunctor<T> functor;
ELUGradNegativeAlphaFunctor<T> functor;
compute(ctx, &functor, d, x, y, dy, dx);
} else {
PADDLE_THROW(
18 changes: 16 additions & 2 deletions python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -1742,7 +1742,7 @@ def test_errors(self):


def elu(x, alpha):
out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
out_ref = np.where(x > 0, x, alpha * (np.exp(x) - 1))
return out_ref.astype(x.dtype)


@@ -1753,7 +1753,7 @@ def setUp(self):

np.random.seed(1024)
x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
alpha = 1.
alpha = self.get_alpha()
out = elu(x, alpha)
# Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
# is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
@@ -1766,6 +1766,14 @@ def test_check_grad(self):
return
self.check_grad(['X'], 'Out')

def get_alpha(self):
return 1.


class TestELUAlpha(TestELU):
def get_alpha(self):
return -0.2
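
TestELUAlpha runs the existing check_grad machinery with alpha = -0.2; roughly the same idea as a standalone central-difference check (pure NumPy sketch, independent of the OpTest framework):

```python
import numpy as np

def elu(x, alpha):
    return np.where(x > 0, x, alpha * (np.exp(x) - 1))

def elu_grad(x, alpha):
    # analytic gradient; on the x <= 0 branch this is alpha * e^x,
    # the same quantity ELUGradNegativeAlphaFunctor computes
    return np.where(x > 0, 1.0, alpha * np.exp(x))

alpha = -0.2
x = np.random.uniform(-3, 3, 100)
x = x[np.abs(x) > 1e-3]  # keep the finite difference away from the kink at 0
eps = 1e-6
numeric = (elu(x + eps, alpha) - elu(x - eps, alpha)) / (2 * eps)
print(np.allclose(numeric, elu_grad(x, alpha), atol=1e-4))  # True
```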


class TestELUAPI(unittest.TestCase):
# test paddle.nn.ELU, paddle.nn.functional.elu
@@ -1832,6 +1840,12 @@ class TestELUInplaceAPI(TestELUAPI):
def executed_api(self):
self.elu = F.elu_

def test_alpha_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
self.assertRaises(Exception, F.elu_, x, -0.2)
paddle.enable_static()


class TestReciprocal(TestActivation):
def setUp(self):
9 changes: 8 additions & 1 deletion python/paddle/nn/functional/activation.py
@@ -37,7 +37,13 @@ def elu(x, alpha=1.0, name=None):

.. math::

elu(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
elu(x)=
    \left\{
        \begin{array}{lcl}
            x,& &\text{if } \ x > 0 \\
            \alpha * (e^{x} - 1),& &\text{if } \ x \leq 0
        \end{array}
    \right.

Parameters:
x (Tensor): The input Tensor with data type float32, float64.
@@ -80,6 +86,7 @@ def elu_(x, alpha=1.0, name=None):
Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_elu`.
"""
assert alpha >= 0., "elu_ only supports alpha >= 0, please use elu instead."
return _C_ops.elu_(x, 'alpha', alpha)
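
A short usage sketch of the resulting Python behavior, assuming a Paddle build that contains this change: the out-of-place F.elu accepts a negative alpha, while the in-place F.elu_ is guarded by the assertion added above.

```python
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor([-2.0, -0.5, 0.0, 1.0])
print(F.elu(x, alpha=-0.2))   # negative alpha is now accepted out of place
try:
    F.elu_(x, alpha=-0.2)     # the in-place variant asserts alpha >= 0
except AssertionError as e:
    print(e)
```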


8 changes: 7 additions & 1 deletion python/paddle/nn/layer/activation.py
@@ -31,7 +31,13 @@ class ELU(Layer):

.. math::

ELU(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
ELU(x)=
    \left\{
        \begin{array}{lcl}
            x,& &\text{if } \ x > 0 \\
            \alpha * (e^{x} - 1),& &\text{if } \ x \leq 0
        \end{array}
    \right.

Parameters:
alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
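
The layer form behaves the same way (again a minimal sketch, assuming this change is in the installed Paddle; printed values are approximate):

```python
import paddle
import paddle.nn as nn

m = nn.ELU(alpha=-0.2)
x = paddle.to_tensor([-2.0, -0.5, 0.0, 1.0])
print(m(x))  # approx [0.1729, 0.0787, 0.0000, 1.0000]
```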