
Commit a668849

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into trt_affine_channels_op
2 parents: 99c286a + b47478e

File tree

17 files changed: +1038 −898 lines

paddle/fluid/framework/section_worker.cc

Lines changed: 10 additions & 10 deletions

@@ -39,13 +39,13 @@ void SectionWorker::RunForward(
     int op_role = op->Attr<int>(std::string("op_role"));
     // We run op with op_role = kLRSched only for the first microbatch
     // to avoid increasing the @LR_DECAY_STEP@ multiple times.
-    bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
-                            op_role == (static_cast<int>(OpRole::kForward) |
-                                        static_cast<int>(OpRole::kLoss)) ||
-                            op_role == static_cast<int>(OpRole::kLRSched);
-    bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
-                      op_role == (static_cast<int>(OpRole::kForward) |
-                                  static_cast<int>(OpRole::kLoss));
+    bool run_first_mbatch = (op_role == static_cast<int>(OpRole::kForward)) ||
+                            (op_role == (static_cast<int>(OpRole::kForward) |
+                                         static_cast<int>(OpRole::kLoss))) ||
+                            (op_role == static_cast<int>(OpRole::kLRSched));
+    bool run_others = (op_role == static_cast<int>(OpRole::kForward)) ||
+                      (op_role == (static_cast<int>(OpRole::kForward) |
+                                   static_cast<int>(OpRole::kLoss)));
     if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
       VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
               << micro_id;
@@ -64,9 +64,9 @@ void SectionWorker::RunBackward(
                      &unused_vars_) {
   for (auto &op : ops_) {
     int op_role = op->Attr<int>(std::string("op_role"));
-    if (op_role == static_cast<int>(OpRole::kBackward) ||
-        op_role == (static_cast<int>(OpRole::kBackward) |
-                    static_cast<int>(OpRole::kLoss))) {
+    if ((op_role == static_cast<int>(OpRole::kBackward)) ||
+        (op_role == (static_cast<int>(OpRole::kBackward) |
+                     static_cast<int>(OpRole::kLoss)))) {
       VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
               << micro_id;
       op->Run(*microbatch_scopes_[micro_id], place_);
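
The change in this file is purely syntactic: each equality test is wrapped in parentheses so the mix of == and bitwise | in the op-role checks reads unambiguously. A minimal Python sketch of the same micro-batch scheduling predicate, with illustrative flag values rather than Paddle's real OpRole enum:

# Illustrative bit flags; Paddle's OpRole enum defines the real values.
FORWARD, BACKWARD, LOSS, LR_SCHED = 1, 2, 4, 8

def run_on_first_microbatch(op_role: int) -> bool:
    # An op tagged both Forward and Loss carries the OR-ed bitmask, so it must
    # be compared against (FORWARD | LOSS), not against either flag alone.
    return (op_role == FORWARD) or (op_role == (FORWARD | LOSS)) \
        or (op_role == LR_SCHED)

def run_on_later_microbatches(op_role: int) -> bool:
    # kLRSched is deliberately excluded: LR decay must advance once per step,
    # not once per micro-batch.
    return (op_role == FORWARD) or (op_role == (FORWARD | LOSS))

assert run_on_first_microbatch(FORWARD | LOSS)
assert not run_on_later_microbatches(LR_SCHED)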

paddle/fluid/operators/fake_quantize_op.cc

Lines changed: 42 additions & 27 deletions

@@ -649,13 +649,18 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
                    "MovingAverageAbsMaxScale");
     OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                    "MovingAverageAbsMaxScale");
+
     if (ctx->HasOutput("OutState")) {
       ctx->SetOutputDim("OutState", {1});
     }
     if (ctx->HasOutput("OutAccum")) {
       ctx->SetOutputDim("OutAccum", {1});
     }
-    ctx->SetOutputDim("OutScale", {1});
+    if (ctx->HasOutput("Out")) {
+      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+      ctx->SetOutputDim("OutScale", {1});
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }

  protected:
@@ -673,6 +678,9 @@ class MovingAverageAbsMaxScaleOpMaker
     AddInput("X", "(Tensor) Input is float data type.");
     AddInput("InAccum", "Last accum.").AsDispensable();
     AddInput("InState", "Last state.").AsDispensable();
+    AddOutput("Out",
+              "(Tensor) Output tensor is just equivalent to the input tensor.")
+        .AsDispensable();
     AddOutput("OutScale", " Current scale");
     AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
     AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
@@ -693,17 +701,17 @@ And it will not quantize the input tensor.
 }
 };

-class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
+class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
     auto out_grad_name = framework::GradVarName("Out");
     auto x_grad_name = framework::GradVarName("X");
     OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");
     OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
-                   "FakeQuantDequantGradOp");
+                   "StrightThroughEstimatorGradOp");

     ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
   }
@@ -717,13 +725,13 @@ class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
 };

 template <typename T>
-class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
+class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker<T> {
  public:
   using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  protected:
   void Apply(GradOpPtr<T> grad_op) const override {
-    grad_op->SetType("fake_quantize_dequantize_grad");
+    grad_op->SetType("stright_throuth_estimator_grad");
     grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
     grad_op->SetAttrMap(this->Attrs());
@@ -744,11 +752,11 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
-                  ops::FakeQuantOrWithDequantAbsMaxOp,
-                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                        ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);

@@ -769,11 +777,12 @@ REGISTER_OPERATOR(
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
-                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@@ -789,20 +798,22 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
 REGISTER_OPERATOR(
     moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
     ops::MovingAverageAbsMaxScaleOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CPU, float>);

-REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
-REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
-                       ops::FakeQuantDequantGradKernel<CPU, float>);
+REGISTER_OPERATOR(stright_throuth_estimator_grad,
+                  ops::StrightThroughEstimatorGradOp);
+REGISTER_OP_CPU_KERNEL(stright_throuth_estimator_grad,
+                       ops::StrightThroughEstimatorGradKernel<CPU, float>);

-REGISTER_OPERATOR(fake_channel_wise_quantize_dequantize_abs_max,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
-                  ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
-                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
-                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(
+    fake_channel_wise_quantize_dequantize_abs_max,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp,
+    ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker,
+    ops::StrightThroughEstimatorMaker<paddle::framework::OpDesc>,
+    ops::StrightThroughEstimatorMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CPU, float>);
@@ -820,4 +831,8 @@ REGISTER_OP_VERSION(moving_average_abs_max_scale)
             "Out",
             "Delete output in order to make the inference model not "
             "save moving_average_abs_max_scale operator. This will "
-            "make the quantitative model be correctly applied in inference."));
+            "make the quantitative model be correctly applied in inference."))
+    .AddCheckpoint(
+        R"ROC(Incompatible upgrade of output [Out])ROC",
+        paddle::framework::compatible::OpVersionDesc().NewOutput(
+            "Out", "In order to support dygraph qat, add output again."));

paddle/fluid/operators/fake_quantize_op.cu

Lines changed: 2 additions & 2 deletions

@@ -543,8 +543,8 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
-REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
-                        ops::FakeQuantDequantGradKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad,
+                        ops::StrightThroughEstimatorGradKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(
     fake_channel_wise_quantize_dequantize_abs_max,
     ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);

paddle/fluid/operators/fake_quantize_op.h

Lines changed: 11 additions & 5 deletions

@@ -314,6 +314,12 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
     auto* in = context.Input<framework::Tensor>("X");
     auto& dev_ctx = context.template device_context<DeviceContext>();

+    if (context.HasOutput("Out")) {
+      auto* out = context.Output<framework::Tensor>("Out");
+      out->mutable_data<T>(context.GetPlace());
+      framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
+    }
+
     bool is_test = context.Attr<bool>("is_test");
     // testing
     if (is_test) {
@@ -344,17 +350,17 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
 };

 template <typename DeviceContext, typename T>
-class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
+class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* d_out =
         context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
     auto x_grad_name = framework::GradVarName("X");
     auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
-    PADDLE_ENFORCE_NOT_NULL(
-        d_x, platform::errors::PreconditionNotMet(
-                 "FakeQuantDequantGradOp doesn't have the output named %s.",
-                 x_grad_name));
+    PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet(
+                                     "StrightThroughEstimatorGradKernel "
+                                     "doesn't have the output named %s.",
+                                     x_grad_name));

     // Initialize dx as same as d_out
     d_x->mutable_data<T>(context.GetPlace());
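
With this kernel change, moving_average_abs_max_scale acts as an identity op on its data path (when Out is requested it is a plain copy of X) while the scale statistics are updated as a side effect, which is what lets it sit inside a dygraph QAT graph without altering activations. A hedged Python sketch of the training-time update, using the conventional moving-average form; the exact decay handling inside Paddle's functor is assumed here, not quoted:

import numpy as np

def moving_average_abs_max_scale(x, state, accum, decay=0.9):
    """Identity on the data path; running abs-max statistics updated on the side."""
    out = x.copy()                            # Out: the new TensorCopy of X
    state = decay * state + 1.0               # OutState
    accum = decay * accum + np.abs(x).max()   # OutAccum
    scale = accum / state                     # OutScale, shape {1}
    return out, state, accum, scale

x = np.random.randn(8).astype("float32")
out, state, accum, scale = moving_average_abs_max_scale(x, state=1.0,
                                                        accum=np.abs(x).max())
assert np.array_equal(out, x)                 # data path untouched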

paddle/fluid/pybind/op_function_generator.cc

Lines changed: 4 additions & 2 deletions

@@ -84,7 +84,8 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"matrix_nms", {"Out", "Index", "RoisNum"}},
     {"distribute_fpn_proposals",
      {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}},
-    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
+    {"moving_average_abs_max_scale",
+     {"Out", "OutScale", "OutAccum", "OutState"}},
     {"multiclass_nms3", {"Out", "NmsRoisNum"}},
     {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
     {"momentum", {"ParamOut", "VelocityOut"}},
@@ -137,7 +138,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"check_finite_and_unscale", {"Out", "FoundInfinite"}},
     {"update_loss_scaling",
      {"Out", "LossScaling", "OutGoodSteps", "OutBadSteps"}},
-    {"moving_average_abs_max_scale", {"OutScale", "OutAccum", "OutState"}},
+    {"moving_average_abs_max_scale",
+     {"Out", "OutScale", "OutAccum", "OutState"}},
     {"lamb",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"rnn", {"DropoutState"}},

python/paddle/distributed/fleet/meta_optimizers/common.py

Lines changed: 38 additions & 3 deletions

@@ -47,7 +47,7 @@ def is_optimizer_op(op):


 class CollectiveHelper(object):
-    def __init__(self, role_maker, nrings=1, wait_port='6174'):
+    def __init__(self, role_maker, nrings=1, wait_port=True):
         self.nrings = nrings
         self.wait_port = wait_port
         self.role_maker = role_maker
@@ -65,14 +65,48 @@ def update_startup_program(self, startup_program=None):
                 self.role_maker._worker_index(), ring_id, self.wait_port)
         self._broadcast_params()

-    def _init_communicator(self, program, current_endpoint, endpoints, rank,
-                           ring_id, wait_port):
+    def _init_communicator(self,
+                           program,
+                           current_endpoint,
+                           endpoints,
+                           rank,
+                           ring_id,
+                           wait_port,
+                           global_ring_id=None,
+                           sync=True):
         nranks = len(endpoints)
         other_endpoints = endpoints[:]
         other_endpoints.remove(current_endpoint)
         if rank == 0 and wait_port:
             wait_server_ready(other_endpoints)

+        def _add_sync_by_allreduce(block):
+            sync_var = block.create_var(
+                name=unique_name.generate('sync_var'),
+                dtype=core.VarDesc.VarType.INT32,
+                persistable=False,
+                stop_gradient=True)
+            block.append_op(
+                type='fill_constant',
+                inputs={},
+                outputs={'Out': [sync_var]},
+                attrs={
+                    'shape': [1],
+                    'dtype': sync_var.dtype,
+                    'value': 1,
+                    'force_cpu': False,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+            block.append_op(
+                type='c_allreduce_sum',
+                inputs={'X': [sync_var]},
+                outputs={'Out': [sync_var]},
+                attrs={
+                    'ring_id': global_ring_id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+
         block = program.global_block()
         if core.is_compiled_with_cuda():
             comm_id_var = block.create_var(
@@ -128,6 +162,7 @@ def _init_communicator(self, program, current_endpoint, endpoints, rank,
             raise ValueError(
                 "comm_id must be generated in paddlepaddle-xpu or paddlepaddle-xpu."
             )
+        if sync: _add_sync_by_allreduce(block)

     def _wait(self, current_endpoint, endpoints):
         assert (self.wait_port)
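
The widened _init_communicator signature lets a meta-optimizer append a fill_constant plus c_allreduce_sum on a throw-away INT32 variable right after the NCCL ring is created; since that allreduce can only complete once every rank reaches it, it doubles as a barrier against ranks racing ahead of communicator initialization. A hypothetical call site, just to show the new parameters (the endpoint and ring values below are placeholders, not taken from the commit):

# Hypothetical usage; role_maker, startup_program and the endpoints are stand-ins.
helper = CollectiveHelper(role_maker, nrings=1, wait_port=True)
helper._init_communicator(
    program=startup_program,
    current_endpoint="127.0.0.1:6170",
    endpoints=["127.0.0.1:6170", "127.0.0.1:6171"],
    rank=role_maker._worker_index(),
    ring_id=0,
    wait_port=True,
    global_ring_id=0,   # ring used by the barrier allreduce
    sync=True)          # append fill_constant + c_allreduce_sum as a barrier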
