@@ -876,6 +876,8 @@ void FakeInitializeOutputsForFunctionKernel(
dtype = GetInputDType(runtime_ctx, "X");
} else if (op_type == "dequantize_linear") {
dtype = GetInputDType(runtime_ctx, "Scale");
} else if (op_type == "quantize_linear") {
dtype = GetInputDType(runtime_ctx, "Scale");
} else if (op_type == "lamb") {
bool multi_precision = op.Attr<bool>("multi_precision");
dtype = GetInputDType(runtime_ctx, "Moment1");
14 changes: 13 additions & 1 deletion paddle/fluid/operators/ops_signature/quantize_linear_sig.cc
@@ -16,16 +16,28 @@ limitations under the License. */

namespace phi {

KernelSignature QuantizeLinearOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature(
"quantize_linear_deprecated",
{"X", "Scale", "ZeroPoint", "InAccum", "InState"},
{"quant_axis", "bit_length", "round_type", "is_test", "only_observer"},
{"Y", "OutState", "OutAccum", "OutScale"});
}

KernelSignature DeQuantizeLinearOpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature(
"dequantize_linear",
"dequantize_linear_deprecated",
{"X", "Scale", "ZeroPoint", "InAccum", "InState"},
{"quant_axis", "bit_length", "round_type", "is_test", "only_observer"},
{"Y", "OutState", "OutAccum", "OutScale"});
}

} // namespace phi

PD_REGISTER_ARG_MAPPING_FN(quantize_linear,
phi::QuantizeLinearOpArgumentMapping);

PD_REGISTER_ARG_MAPPING_FN(dequantize_linear,
phi::DeQuantizeLinearOpArgumentMapping);
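
Note: both argument mappings above now route the legacy fluid ops to *_deprecated PHI kernel names while keeping the same input/attribute/output order, and PD_REGISTER_ARG_MAPPING_FN ties each fluid op name to its mapping function. A minimal, self-contained sketch of what such a mapping produces (a plain C++ stand-in, not the real phi::KernelSignature API; all names below are illustrative only):

#include <iostream>
#include <string>
#include <vector>

// Stand-in for a kernel signature: a kernel name plus ordered
// input, attribute, and output parameter lists.
struct SigSketch {
  std::string kernel_name;
  std::vector<std::string> inputs;
  std::vector<std::string> attrs;
  std::vector<std::string> outputs;
};

// Mirrors QuantizeLinearOpArgumentMapping above: the fluid op
// "quantize_linear" dispatches to the "quantize_linear_deprecated" kernel.
SigSketch MapQuantizeLinear() {
  return {"quantize_linear_deprecated",
          {"X", "Scale", "ZeroPoint", "InAccum", "InState"},
          {"quant_axis", "bit_length", "round_type", "is_test", "only_observer"},
          {"Y", "OutState", "OutAccum", "OutScale"}};
}

int main() {
  std::cout << "quantize_linear -> " << MapQuantizeLinear().kernel_name << '\n';
  return 0;
}
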
159 changes: 23 additions & 136 deletions paddle/fluid/operators/quantize_linear_op.cc
@@ -9,119 +9,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/quantize_linear_op.h"

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/common/ddim.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"

namespace paddle {
namespace operators {

template <typename T>
struct DequantizeFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext &dev_ctx,
const phi::DenseTensor *in,
const phi::DenseTensor *scale,
T max_range,
phi::DenseTensor *out) {
auto in_e = phi::EigenVector<T>::Flatten(*in);
const T *scale_factor = scale->data<T>();
auto out_e = phi::EigenVector<T>::Flatten(*out);

auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * scale_factor[0] / max_range;
}
};

template <typename T>
struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
void operator()(const phi::CPUContext &dev_ctx,
const phi::DenseTensor *in,
const phi::DenseTensor *scale,
T max_range,
const int quant_axis,
phi::DenseTensor *out) {
// Dequant op is before quantized op
// Dequantize the weight of quantized op
auto in_dims = in->dims();
const int64_t channel = in_dims[quant_axis];
const T *scale_factor = scale->data<T>();
if (quant_axis == 0) {
for (int64_t i = 0; i < channel; i++) {
T s = scale_factor[i];
phi::DenseTensor one_channel_in = in->Slice(i, i + 1);
phi::DenseTensor one_channel_out = out->Slice(i, i + 1);
auto in_e = phi::EigenVector<T>::Flatten(one_channel_in);
auto out_e = phi::EigenVector<T>::Flatten(one_channel_out);
auto &dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * s / max_range;
}
} else if (quant_axis == 1) {
int64_t out_iter = 1;
for (int i = 0; i < quant_axis; i++) {
out_iter *= in_dims[i];
}
int64_t step_i = in->numel() / out_iter;
int64_t step_j = in->numel() / (out_iter * channel);
auto *in_data = in->data<T>();
auto *out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
for (int64_t i = 0; i < out_iter; i++) {
for (int64_t j = 0; j < channel; j++) {
auto *cur_in = in_data + i * step_i + j * step_j;
auto *cur_out = out_data + i * step_i + j * step_j;
T s = scale_factor[j];
for (int64_t k = 0; k < step_j; k++) {
*cur_out = (*cur_in) * s / max_range;
++cur_in;
++cur_out;
}
}
}
}
}
};

template struct DequantizeFunctor<phi::CPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::CPUContext, float>;
template struct DequantizeFunctor<phi::CPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext,
phi::dtype::float16>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, double>;

class QuantizeLinearOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
OP_INOUT_CHECK(
ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear");
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
int quant_axis = ctx->Attrs().Get<int>("quant_axis");
if (ctx->HasOutput("OutScale")) {
if (quant_axis < 0) {
ctx->SetOutputDim("OutScale", {1});
} else {
ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
}
}
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->ShareLoD("X", /*->*/ "Y");
}

protected:
phi::KernelKey GetExpectedKernelType(
@@ -130,7 +39,6 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
ctx.GetPlace());
}
};

class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -159,42 +67,16 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"(int, default 0) The axis for quantization. "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis.")
- .SetDefault(0)
- .AddCustomChecker([](const int &quant_axis) {
- PADDLE_ENFORCE_EQ(
- quant_axis == 0 || quant_axis == 1 || quant_axis == -1,
- true,
- phi::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
- "the received is %d",
- quant_axis));
- });
- AddAttr<int>("bit_length", "(int, default 8)")
- .SetDefault(8)
- .AddCustomChecker([](const int &bit_length) {
- PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
- true,
- phi::errors::InvalidArgument(
- "'bit_length' should be between 1 and 16, but "
- "the received is %d",
- bit_length));
- });
+ .SetDefault(0);
+ AddAttr<int>("bit_length", "(int, default 8)").SetDefault(8);
AddAttr<int>(
"round_type",
"(int, default 0) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"0: rounding to nearest ties to even. Eg: round(1.5)=2, "
"round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3")
- .SetDefault(0)
- .AddCustomChecker([](const int &round_type) {
- PADDLE_ENFORCE_EQ(
- round_type == 0 || round_type == 1,
- true,
- phi::errors::InvalidArgument(
- "'round_type' should be 0 or 1, 0 rounding to "
- "nearest ties to even and 1 is rounding to nearest "
- "ties away from zero.but the received is %d",
- round_type));
- });
+ .SetDefault(0);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
@@ -216,25 +98,30 @@ In above three formulas, the range value of c is as follow:
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(quantize_linear,
QuantizeLinearInferShapeFunctor,
PD_INFER_META(phi::QuantizeLinearInferMeta));

DECLARE_INFER_SHAPE_FUNCTOR(dequantize_linear,
DeQuantizeLinearInferShapeFunctor,
PD_INFER_META(phi::QuantizeLinearInferMeta));
REGISTER_OPERATOR(
quantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
- paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-
- PD_REGISTER_STRUCT_KERNEL(
- quantize_linear, CPU, ALL_LAYOUT, ops::QuantizeLinearKernel, float) {}
+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+ QuantizeLinearInferShapeFunctor);

REGISTER_OPERATOR(
dequantize_linear,
ops::QuantizeLinearOp,
ops::QuantizeLinearOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
- paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+ DeQuantizeLinearInferShapeFunctor);
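
Note: the hand-written InferShape deleted above encoded simple shape rules that the new PD_INFER_META(phi::QuantizeLinearInferMeta) functor is expected to reproduce: Y takes X's dims, OutScale holds a single element for per-tensor quantization (quant_axis < 0) or X's extent along quant_axis for per-channel quantization, and OutState/OutAccum are scalars. A rough, self-contained sketch of those rules (illustrative types and names only; the actual phi::QuantizeLinearInferMeta signature may differ):

#include <cassert>
#include <cstdint>
#include <vector>

// Output shapes produced by quantize_linear / dequantize_linear.
struct QuantizeLinearOutDims {
  std::vector<int64_t> y;          // same shape as X
  std::vector<int64_t> out_scale;  // {1} or {X.dims()[quant_axis]}
  std::vector<int64_t> out_state;  // {1}
  std::vector<int64_t> out_accum;  // {1}
};

QuantizeLinearOutDims InferQuantizeLinearShapes(
    const std::vector<int64_t> &x_dims, int quant_axis) {
  // The deleted checker allowed quant_axis in {-1, 0, 1}.
  assert(quant_axis == -1 || quant_axis == 0 || quant_axis == 1);
  QuantizeLinearOutDims out;
  out.y = x_dims;  // Y mirrors X
  out.out_scale = (quant_axis < 0)
                      ? std::vector<int64_t>{1}                      // per-tensor
                      : std::vector<int64_t>{x_dims[quant_axis]};    // per-channel
  out.out_state = {1};
  out.out_accum = {1};
  return out;
}

int main() {
  const auto dims = InferQuantizeLinearShapes({64, 3, 3, 3}, /*quant_axis=*/0);
  return dims.out_scale[0] == 64 ? 0 : 1;
}
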
131 changes: 0 additions & 131 deletions paddle/fluid/operators/quantize_linear_op.cu

This file was deleted.
