@@ -649,13 +649,18 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
                    "MovingAverageAbsMaxScale");
     OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                    "MovingAverageAbsMaxScale");
+
     if (ctx->HasOutput("OutState")) {
       ctx->SetOutputDim("OutState", {1});
     }
     if (ctx->HasOutput("OutAccum")) {
       ctx->SetOutputDim("OutAccum", {1});
     }
-    ctx->SetOutputDim("OutScale", {1});
+    if (ctx->HasOutput("Out")) {
+      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+      ctx->SetOutputDim("OutScale", {1});
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }
 
  protected:
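The new branch gives the op an optional pass-through output: when a program wires up "Out", it inherits the shape and LoD of "X", so older programs that never declare "Out" stay valid. Only InferShape is in this hunk; the body of the MovingAverageAbsMaxScaleKernel registered further down is not shown, so the following is a minimal sketch, assuming the dispensable "Out" is produced by copying "X" verbatim, with the existing scale update elided:

// Sketch only, not from this diff: assumes the forward kernel mirrors the
// InferShape change by forwarding X into the optional Out unchanged.
template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::LoDTensor>("X");
    if (context.HasOutput("Out")) {
      auto* out = context.Output<framework::LoDTensor>("Out");
      out->mutable_data<T>(context.GetPlace());
      // Identity pass-through; no quantization is applied here.
      framework::TensorCopy(*in, context.GetPlace(), out);
    }
    // ... existing moving-average abs-max scale update elided ...
  }
};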
@@ -673,6 +678,9 @@ class MovingAverageAbsMaxScaleOpMaker
     AddInput("X", "(Tensor) Input is float data type.");
     AddInput("InAccum", "Last accum.").AsDispensable();
     AddInput("InState", "Last state.").AsDispensable();
+    AddOutput("Out",
+              "(Tensor) Output tensor is just equivalent to the input tensor.")
+        .AsDispensable();
     AddOutput("OutScale", "Current scale");
     AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
     AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
@@ -693,17 +701,17 @@ And it will not quantize the input tensor.
   }
 };
 
-class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
+class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     auto out_grad_name = framework::GradVarName("Out");
     auto x_grad_name = framework::GradVarName("X");
     OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
 
     ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
   }
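The rename spells out the design: the backward of these fake quantize-dequantize ops is a straight-through estimator, i.e. quantization is treated as identity during backprop, so X@GRAD is just a copy of Out@GRAD (which is why InferShape gives the X gradient the Out gradient's dims). The StrightThroughEstimatorGradKernel registered below is not part of this diff; a minimal sketch under that assumption:

// Sketch only, not from this diff: assumes the grad kernel is an
// identity copy, dX = dOut, per the straight-through estimator.
template <typename DeviceContext, typename T>
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* d_out =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto* d_x =
        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
    d_x->mutable_data<T>(context.GetPlace());
    framework::TensorCopy(*d_out, context.GetPlace(), d_x);
  }
};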
@@ -717,13 +725,13 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
 };
 
 template <typename T>
-class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
+class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
 
  protected:
   void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("fake_quantize_dequantize_grad");
+    grad_op->SetType("stright_throuth_estimator_grad");
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
@@ -744,11 +752,11 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
-                  ops::FakeQuantOrWithDequantAbsMaxOp,
-                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                        ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
@@ -769,11 +777,12 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@@ -789,20 +798,22 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
 REGISTER_OPERATOR(
     moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
     ops::MovingAverageAbsMaxScaleOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
-REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
-                       ops::FakeQuantDequantGradKernel<CPU, float>);
+REGISTER_OPERATOR(stright_throuth_estimator_grad,
+                  ops::StrightThroughEstimatorGradOp);
+REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad,
+                       ops::StrightThroughEstimatorGradKernel<CPU, float>);
 
-REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_channel_wise_quantize_dequantize_abs_max,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
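Worth noting in the hunk above: moving_average_abs_max_scale previously used EmptyGradOpMaker, so no gradient could flow through it; registering StrightThroughEstimatorMaker instead gives it the same identity backward as the fake quantize-dequantize ops, which is what lets the op, together with its new pass-through "Out", sit inside a trainable dygraph QAT graph.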
@@ -820,4 +831,8 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
             "Out",
             "Delete output in order to make the inference model not "
             "save moving_average_abs_max_scale operator. This will "
-            "make the quantitative model be correctly applied in inference."));
+            "make the quantitative model be correctly applied in inference."))
+    .AddCheckpoint(
+        R"ROC(Incompatible upgrade of output [Out])ROC",
+        paddle::framework::compatible::OpVersionDesc().NewOutput(
+            "Out", "In order to support dygraph qat, add output again."));