Skip to content

Commit

Permalink
improve error messages of fake_dequantize_op. test=develop (PaddlePad…
Browse files Browse the repository at this point in the history
  • Loading branch information
wzzju authored Apr 9, 2020
1 parent 467ce0b commit 1cf64e0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
22 changes: 9 additions & 13 deletions paddle/fluid/operators/fake_dequantize_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeDequantizeMaxAbs");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeDequantizeMaxAbs");

ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
Expand Down Expand Up @@ -125,15 +124,12 @@ class FakeChannelWiseDequantizeMaxAbsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasInputs("Scales"),
"Input(Scales) of FakeChannelWiseDequantizeMaxAbsOp "
"should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
"FakeChannelWiseDequantizeMaxAbs");
OP_INOUT_CHECK(ctx->HasInputs("Scales"), "Input", "Scales",
"FakeChannelWiseDequantizeMaxAbs");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"FakeChannelWiseDequantizeMaxAbs");

ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
Expand Down
24 changes: 15 additions & 9 deletions paddle/fluid/operators/fake_dequantize_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,25 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
if (scale_num == 1) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[0],
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only one "
"element.");
platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with "
"first dimension value of Input(X) when the `Scales` has only "
"one element, but %ld != %ld here.",
scales[0]->numel(), in->dims()[0]));
max_range *= (std::pow(2, quant_bits[0] - 1) - 1);
} else if (scale_num == 2) {
PADDLE_ENFORCE_EQ(
scales[0]->numel(), in->dims()[1],
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements.");
PADDLE_ENFORCE_EQ(
scales[1]->numel(), 1,
"The second scale tensor should only have one value at now.");
platform::errors::PreconditionNotMet(
"The number of first scale values must be the same with "
"second dimension value of Input(X) when the `Scales` has two "
"elements, but %ld != %ld here.",
scales[0]->numel(), in->dims()[1]));
PADDLE_ENFORCE_EQ(scales[1]->numel(), 1,
platform::errors::PreconditionNotMet(
"The second scale tensor should only have one "
"value at now, but it has %ld values here.",
scales[1]->numel()));
max_range *= (std::pow(2, quant_bits[0] - 1) - 1) *
(std::pow(2, quant_bits[1] - 1) - 1);
}
Expand Down

0 comments on commit 1cf64e0

Please sign in to comment.