Skip to content

Commit 16c048a

Browse files
authored
Merge pull request #2 from veyron95/ops_derivative
native commit for triple grad of sigmoid
2 parents 4c2a06d + d52b81c commit 16c048a

File tree

5 files changed

+806
-257
lines changed

5 files changed

+806
-257
lines changed

paddle/fluid/operators/activation_op.cc

Lines changed: 111 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,21 +68,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
6868
protected:
6969
void Apply(GradOpPtr<T> op) const override {
7070
op->SetType(this->ForwardOpType() + "_grad");
71-
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
72-
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
71+
op->SetInput(framework::GradVarName("Out"),
72+
this->OutputGrad("Out")); // dout
73+
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); // dx
7374
op->SetAttrMap(this->Attrs());
7475

7576
if ((static_cast<int>(kDepValue) &
7677
static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
7778
FLAGS_use_mkldnn ||
7879
(op->HasAttr("use_mkldnn") &&
7980
BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
80-
op->SetInput("X", this->Input("X"));
81+
op->SetInput("X", this->Input("X")); // x
8182
}
8283

8384
if (static_cast<int>(kDepValue) &
8485
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
85-
op->SetInput("Out", this->Output("Out"));
86+
op->SetInput("Out", this->Output("Out")); // out
8687
}
8788
}
8889
};
@@ -119,6 +120,7 @@ class ActivationOp : public framework::OperatorWithKernel {
119120
using framework::OperatorWithKernel::OperatorWithKernel;
120121

121122
void InferShape(framework::InferShapeContext* ctx) const override {
123+
VLOG(3) << "=========== in ActivationOp =========";
122124
ctx->ShareDim("X", /*->*/ "Out");
123125
ctx->ShareLoD("X", /*->*/ "Out");
124126
}
@@ -145,8 +147,9 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
145147
using framework::OperatorWithKernel::OperatorWithKernel;
146148

147149
void InferShape(framework::InferShapeContext* ctx) const override {
150+
VLOG(3) << "=========== in ActivationOpGrad =========";
148151
auto out_grad_name = framework::GradVarName("Out");
149-
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
152+
ctx->ShareDim(out_grad_name, framework::GradVarName("X")); // dout -> dx
150153
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
151154
}
152155

@@ -748,6 +751,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
748751
using framework::OperatorWithKernel::OperatorWithKernel;
749752

750753
void InferShape(framework::InferShapeContext* ctx) const override {
754+
VLOG(3) << "=========== in ActivationOpDoubleGrad =========";
751755
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
752756
if (ctx->HasOutput("DX")) {
753757
ctx->ShareDim("X", "DX");
@@ -804,6 +808,49 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
804808
}
805809
};
806810

811+
template <ActBwdOpFwdDeps kDepValue>
812+
class ActivationOpTribleGrad : public framework::OperatorWithKernel {
813+
public:
814+
using framework::OperatorWithKernel::OperatorWithKernel;
815+
816+
void InferShape(framework::InferShapeContext* ctx) const override {
817+
VLOG(3) << "=========== in ActivationOpTribleGrad =========";
818+
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
819+
if (ctx->HasOutput("DX")) {
820+
ctx->ShareDim("X", "DX");
821+
ctx->ShareLoD("X", "DX");
822+
}
823+
if (ctx->HasOutput("DDOut")) {
824+
ctx->ShareDim("X", "DDOut");
825+
ctx->ShareLoD("X", "DDOut");
826+
}
827+
}
828+
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
829+
if (ctx->HasOutput("D_DOut")) {
830+
ctx->ShareDim("Out", "D_DOut");
831+
ctx->ShareLoD("Out", "D_DOut");
832+
}
833+
if (ctx->HasOutput("D_OutNew")) {
834+
ctx->ShareDim("Out", "D_OutNew");
835+
ctx->ShareLoD("Out", "D_OutNew");
836+
}
837+
if (ctx->HasOutput("D_DDx")) {
838+
ctx->ShareDim("DDX", "D_DDx");
839+
ctx->ShareLoD("DDX", "D_DDx");
840+
}
841+
// op->SetOutput("D_OutNew", this->InputGrad("Out"));
842+
// op->SetOutput("D_DOut", this->InputGrad("DOut"));
843+
// op->SetOutput("D_DDx", this->InputGrad("DDX"));
844+
}
845+
}
846+
847+
protected:
848+
framework::OpKernelType GetExpectedKernelType(
849+
const framework::ExecutionContext& ctx) const override {
850+
return GetKernelType(ctx, *this, "DDX");
851+
}
852+
};
853+
807854
template <typename T>
808855
class SigmoidDoubleGradMaker
809856
: public ::paddle::framework::SingleGradOpMaker<T> {
@@ -812,6 +859,7 @@ class SigmoidDoubleGradMaker
812859

813860
protected:
814861
void Apply(GradOpPtr<T> op) const override {
862+
VLOG(3) << "=========== in SigmoidDoubleGradMaker =========";
815863
op->SetType("sigmoid_grad_grad");
816864
// input1: Out
817865
op->SetInput("Out", this->Input("Out"));
@@ -825,6 +873,37 @@ class SigmoidDoubleGradMaker
825873
}
826874
};
827875

876+
template <typename T>
877+
class SigmoidTribleGradMaker
878+
: public ::paddle::framework::SingleGradOpMaker<T> {
879+
public:
880+
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
881+
882+
protected:
883+
void Apply(GradOpPtr<T> op) const override {
884+
VLOG(3) << "=========== in SigmoidTribleGradMaker =========";
885+
op->SetType("sigmoid_trible_grad");
886+
// Out, DDX, DOut, D_DDOut, D_DOut_New // input
887+
// D_OutNew, D_DOut, D_DDx // output
888+
// input1: Out
889+
op->SetInput("Out", this->Input("Out"));
890+
// input2: ddx
891+
op->SetInput("DDX", this->Input("DDX"));
892+
// input3: dout
893+
op->SetInput("DOut", this->Input("DOut"));
894+
// input4: d_ddout
895+
op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
896+
// input5: d_dout_new
897+
op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
898+
op->SetAttrMap(this->Attrs());
899+
900+
// output: d_dOut, d_OutNew, d_ddx
901+
op->SetOutput("D_OutNew", this->InputGrad("Out"));
902+
op->SetOutput("D_DOut", this->InputGrad("DOut"));
903+
op->SetOutput("D_DDx", this->InputGrad("DDX"));
904+
}
905+
};
906+
828907
template <typename T>
829908
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
830909
public:
@@ -995,10 +1074,12 @@ class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
9951074
};
9961075

9971076
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
998-
{framework::GradVarName("Out"),
999-
framework::GradVarName("X")});
1077+
{framework::GradVarName("Out"), // dout
1078+
framework::GradVarName("X")}); // dx
10001079
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
10011080
{"DDX", "DDOut"});
1081+
DECLARE_INPLACE_OP_INFERER(ActivationTribleGradOpInplaceInferer,
1082+
{"DDX", "D_DOut"});
10021083

10031084
template <typename T>
10041085
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
@@ -1121,13 +1202,21 @@ REGISTER_OPERATOR(
11211202
REGISTER_OPERATOR(sigmoid_grad, ops::ActivationOpGrad,
11221203
ops::ActivationGradOpInplaceInferer,
11231204
ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
1124-
ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>)
1205+
ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>);
11251206

11261207
// 3. Register Sigmoid DoubleGrad Operator
1127-
REGISTER_OPERATOR(
1128-
sigmoid_grad_grad,
1129-
ops::ActivationOpDoubleGrad<ops::SigmoidGradFunctor<float>::FwdDeps()>,
1130-
ops::ActivationDoubleGradOpInplaceInferer);
1208+
REGISTER_OPERATOR(sigmoid_grad_grad,
1209+
ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<
1210+
float>::FwdDeps()>, // should be SigmoidGradGradFunctor
1211+
ops::ActivationDoubleGradOpInplaceInferer,
1212+
ops::SigmoidTribleGradMaker<paddle::framework::OpDesc>,
1213+
ops::SigmoidTribleGradMaker<paddle::imperative::OpBase>);
1214+
1215+
// 4. Register Sigmoid TribleGrad Operator
1216+
REGISTER_OPERATOR(sigmoid_trible_grad,
1217+
ops::ActivationOpTribleGrad<
1218+
ops::SigmoidTribleGradFunctor<float>::FwdDeps()>,
1219+
ops::ActivationTribleGradOpInplaceInferer);
11311220

11321221
// Register Sigmoid/GradSigmoid Kernels
11331222
REGISTER_ACTIVATION_CPU_KERNEL(sigmoid, Sigmoid, SigmoidFunctor,
@@ -1143,6 +1232,16 @@ REGISTER_OP_CPU_KERNEL(
11431232
ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
11441233
ops::SigmoidGradGradFunctor<plat::float16>>);
11451234

1235+
// Register TribleGrad Kernel
1236+
REGISTER_OP_CPU_KERNEL(
1237+
sigmoid_trible_grad,
1238+
ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
1239+
ops::SigmoidTribleGradFunctor<float>>,
1240+
ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
1241+
ops::SigmoidTribleGradFunctor<double>>,
1242+
ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
1243+
ops::SigmoidTribleGradFunctor<plat::float16>>);
1244+
11461245
/* ========================================================================== */
11471246

11481247
/* ========================== tanh register ============================= */

paddle/fluid/operators/activation_op.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,15 @@ REGISTER_OP_CUDA_KERNEL(
13981398
ops::SigmoidGradGradFunctor<double>>,
13991399
ops::SigmoidDoubleGradKernel<plat::CUDADeviceContext,
14001400
ops::SigmoidGradGradFunctor<plat::float16>>);
1401+
1402+
REGISTER_OP_CUDA_KERNEL(
1403+
sigmoid_trible_grad,
1404+
ops::SigmoidTribleGradKernel<paddle::platform::CUDADeviceContext,
1405+
ops::SigmoidTribleGradFunctor<float>>,
1406+
ops::SigmoidTribleGradKernel<paddle::platform::CUDADeviceContext,
1407+
ops::SigmoidTribleGradFunctor<double>>,
1408+
ops::SigmoidTribleGradKernel<plat::CUDADeviceContext,
1409+
ops::SigmoidTribleGradFunctor<plat::float16>>);
14011410
/* ========================================================================== */
14021411

14031412
/* =========================== tanh register ============================ */

0 commit comments

Comments
 (0)