@@ -68,21 +68,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetInput(framework::GradVarName("Out"),
+                 this->OutputGrad("Out"));  // dout
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));  // dx
    op->SetAttrMap(this->Attrs());

    if ((static_cast<int>(kDepValue) &
         static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
        FLAGS_use_mkldnn ||
        (op->HasAttr("use_mkldnn") &&
         BOOST_GET_CONST(bool, op->GetAttr("use_mkldnn")))) {
-      op->SetInput("X", this->Input("X"));
+      op->SetInput("X", this->Input("X"));  // x
    }

    if (static_cast<int>(kDepValue) &
        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
-      op->SetInput("Out", this->Output("Out"));
+      op->SetInput("Out", this->Output("Out"));  // out
    }
  }
};
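The kDepValue tests above treat ActBwdOpFwdDeps as a bit mask: each activation functor reports, through a static FwdDeps(), whether its backward pass needs the forward input X, the forward output Out, or neither. A minimal sketch of the flag layout these checks assume (the values are an assumption here; the authoritative definition lives in activation_op.h, which this diff does not show):

enum ActBwdOpFwdDeps {
  kNoDeps = 0x00,  // backward needs neither X nor Out (assumed value)
  kDepX = 0x01,    // backward needs the forward input X (assumed value)
  kDepOut = 0x02,  // backward needs the forward output Out (assumed value)
};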
@@ -119,6 +120,7 @@ class ActivationOp : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
+    VLOG(3) << "=========== in ActivationOp =========";
    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }
@@ -145,8 +147,9 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
+    VLOG(3) << "=========== in ActivationOpGrad =========";
    auto out_grad_name = framework::GradVarName("Out");
-    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
+    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));  // dout -> dx
    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

@@ -748,6 +751,7 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
+    VLOG(3) << "=========== in ActivationOpDoubleGrad =========";
    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
      if (ctx->HasOutput("DX")) {
        ctx->ShareDim("X", "DX");
@@ -804,6 +808,49 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
  }
};

+template <ActBwdOpFwdDeps kDepValue>
+class ActivationOpTribleGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    VLOG(3) << "=========== in ActivationOpTribleGrad =========";
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+      if (ctx->HasOutput("DX")) {
+        ctx->ShareDim("X", "DX");
+        ctx->ShareLoD("X", "DX");
+      }
+      if (ctx->HasOutput("DDOut")) {
+        ctx->ShareDim("X", "DDOut");
+        ctx->ShareLoD("X", "DDOut");
+      }
+    }
+    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+      if (ctx->HasOutput("D_DOut")) {
+        ctx->ShareDim("Out", "D_DOut");
+        ctx->ShareLoD("Out", "D_DOut");
+      }
+      if (ctx->HasOutput("D_OutNew")) {
+        ctx->ShareDim("Out", "D_OutNew");
+        ctx->ShareLoD("Out", "D_OutNew");
+      }
+      if (ctx->HasOutput("D_DDx")) {
+        ctx->ShareDim("DDX", "D_DDx");
+        ctx->ShareLoD("DDX", "D_DDx");
+      }
+      // op->SetOutput("D_OutNew", this->InputGrad("Out"));
+      // op->SetOutput("D_DOut", this->InputGrad("DOut"));
+      // op->SetOutput("D_DDx", this->InputGrad("DDX"));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return GetKernelType(ctx, *this, "DDX");
+  }
+};
+
template <typename T>
class SigmoidDoubleGradMaker
    : public ::paddle::framework::SingleGradOpMaker<T> {
@@ -812,6 +859,7 @@ class SigmoidDoubleGradMaker

 protected:
  void Apply(GradOpPtr<T> op) const override {
+    VLOG(3) << "=========== in SigmoidDoubleGradMaker =========";
    op->SetType("sigmoid_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
@@ -825,6 +873,37 @@ class SigmoidDoubleGradMaker
  }
};

+template <typename T>
+class SigmoidTribleGradMaker
+    : public ::paddle::framework::SingleGradOpMaker<T> {
+ public:
+  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    VLOG(3) << "=========== in SigmoidTribleGradMaker =========";
+    op->SetType("sigmoid_trible_grad");
+    // inputs:  Out, DDX, DOut, D_DDOut, D_DOut_New
+    // outputs: D_OutNew, D_DOut, D_DDx
+    // input1: Out
+    op->SetInput("Out", this->Input("Out"));
+    // input2: ddx
+    op->SetInput("DDX", this->Input("DDX"));
+    // input3: dout
+    op->SetInput("DOut", this->Input("DOut"));
+    // input4: d_ddout
+    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
+    // input5: d_dout_new
+    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
+    op->SetAttrMap(this->Attrs());
+
+    // outputs: d_out_new, d_dout, d_ddx
+    op->SetOutput("D_OutNew", this->InputGrad("Out"));
+    op->SetOutput("D_DOut", this->InputGrad("DOut"));
+    op->SetOutput("D_DDx", this->InputGrad("DDX"));
+  }
+};
+
template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
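For orientation, a sketch of the math this wiring assumes, using the standard sigmoid derivative chain. With out = sigmoid(x), the first-order grad is dx = dout * out * (1 - out); the double-grad op then produces ddout = ddx * out * (1 - out) and dout_new = ddx * dout * (1 - 2 * out). Differentiating those two outputs once more, with upstream gradients d_ddout and d_dout_new, yields the three outputs the maker registers above:

  d_out_new = (1 - 2 * out) * ddx * d_ddout - 2 * dout * ddx * d_dout_new
  d_dout    = (1 - 2 * out) * ddx * d_dout_new
  d_ddx     = out * (1 - out) * d_ddout + (1 - 2 * out) * dout * d_dout_new

This is a derivation sketch only; the actual formulas are implemented by SigmoidTribleGradFunctor, which this diff does not show.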
@@ -995,10 +1074,12 @@ class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
};

DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInferer,
-                           {framework::GradVarName("Out"),
-                            framework::GradVarName("X")});
+                           {framework::GradVarName("Out"),   // dout
+                            framework::GradVarName("X")});   // dx
DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
                           {"DDX", "DDOut"});
+DECLARE_INPLACE_OP_INFERER(ActivationTribleGradOpInplaceInferer,
+                           {"DDX", "D_DOut"});

template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
@@ -1121,13 +1202,21 @@ REGISTER_OPERATOR(
REGISTER_OPERATOR(sigmoid_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::SigmoidDoubleGradMaker<paddle::framework::OpDesc>,
-                  ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>)
+                  ops::SigmoidDoubleGradMaker<paddle::imperative::OpBase>);

// 3. Register Sigmoid DoubleGrad Operator
-REGISTER_OPERATOR(
-    sigmoid_grad_grad,
-    ops::ActivationOpDoubleGrad<ops::SigmoidGradFunctor<float>::FwdDeps()>,
-    ops::ActivationDoubleGradOpInplaceInferer);
+REGISTER_OPERATOR(sigmoid_grad_grad,
+                  ops::ActivationOpDoubleGrad<ops::SigmoidGradGradFunctor<
+                      float>::FwdDeps()>,  // should be SigmoidGradGradFunctor
+                  ops::ActivationDoubleGradOpInplaceInferer,
+                  ops::SigmoidTribleGradMaker<paddle::framework::OpDesc>,
+                  ops::SigmoidTribleGradMaker<paddle::imperative::OpBase>);
+
+// 4. Register Sigmoid TribleGrad Operator
+REGISTER_OPERATOR(sigmoid_trible_grad,
+                  ops::ActivationOpTribleGrad<
+                      ops::SigmoidTribleGradFunctor<float>::FwdDeps()>,
+                  ops::ActivationTribleGradOpInplaceInferer);

// Register Sigmoid/GradSigmoid Kernels
REGISTER_ACTIVATION_CPU_KERNEL(sigmoid, Sigmoid, SigmoidFunctor,
@@ -1143,6 +1232,16 @@ REGISTER_OP_CPU_KERNEL(
    ops::SigmoidDoubleGradKernel<plat::CPUDeviceContext,
                                 ops::SigmoidGradGradFunctor<plat::float16>>);

+// Register TribleGrad Kernel
+REGISTER_OP_CPU_KERNEL(
+    sigmoid_trible_grad,
+    ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
+                                 ops::SigmoidTribleGradFunctor<float>>,
+    ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
+                                 ops::SigmoidTribleGradFunctor<double>>,
+    ops::SigmoidTribleGradKernel<plat::CPUDeviceContext,
+                                 ops::SigmoidTribleGradFunctor<plat::float16>>);
+
/* ========================================================================== */

/* ========================== tanh register ============================= */
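The SigmoidTribleGradKernel and SigmoidTribleGradFunctor referenced by these registrations are defined outside this file, and the functor header is not part of the diff. As a hedged reference, here is a self-contained elementwise sketch of the math the CPU kernel is assumed to compute; the names are illustrative only, not Paddle's actual API.

#include <cstddef>
#include <vector>

// Plain elementwise sketch of the assumed sigmoid triple-grad math; the real
// functor is expected to express the same formulas as Eigen device code.
// Inputs mirror the maker above: out, ddx, dout, d_ddout, d_dout_new.
void SigmoidTribleGradReference(const std::vector<float>& out,
                                const std::vector<float>& ddx,
                                const std::vector<float>& dout,
                                const std::vector<float>& d_ddout,
                                const std::vector<float>& d_dout_new,
                                std::vector<float>* d_out_new,
                                std::vector<float>* d_dout,
                                std::vector<float>* d_ddx) {
  const std::size_t n = out.size();
  d_out_new->assign(n, 0.f);
  d_dout->assign(n, 0.f);
  d_ddx->assign(n, 0.f);
  for (std::size_t i = 0; i < n; ++i) {
    const float o = out[i];
    // grad w.r.t. Out: d(ddout)/d(out) = (1 - 2*out) * ddx,
    //                  d(dout_new)/d(out) = -2 * dout * ddx
    (*d_out_new)[i] = (1.f - 2.f * o) * ddx[i] * d_ddout[i] -
                      2.f * dout[i] * ddx[i] * d_dout_new[i];
    // grad w.r.t. DOut: d(dout_new)/d(dout) = (1 - 2*out) * ddx
    (*d_dout)[i] = (1.f - 2.f * o) * ddx[i] * d_dout_new[i];
    // grad w.r.t. DDX: d(ddout)/d(ddx) = out * (1 - out),
    //                  d(dout_new)/d(ddx) = (1 - 2*out) * dout
    (*d_ddx)[i] = o * (1.f - o) * d_ddout[i] +
                  (1.f - 2.f * o) * dout[i] * d_dout_new[i];
  }
}

A plain reference like this is handy for sanity-checking the functor's device expressions against finite differences before registering the op.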