diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index ac58f499e91caf..d639d212b0aa90 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -876,6 +876,8 @@ void FakeInitializeOutputsForFunctionKernel( dtype = GetInputDType(runtime_ctx, "X"); } else if (op_type == "dequantize_linear") { dtype = GetInputDType(runtime_ctx, "Scale"); + } else if (op_type == "quantize_linear") { + dtype = GetInputDType(runtime_ctx, "Scale"); } else if (op_type == "lamb") { bool multi_precision = op.Attr("multi_precision"); dtype = GetInputDType(runtime_ctx, "Moment1"); diff --git a/paddle/fluid/operators/ops_signature/quantize_linear_sig.cc b/paddle/fluid/operators/ops_signature/quantize_linear_sig.cc index 75e523bf55367d..b35fe7a89113d6 100644 --- a/paddle/fluid/operators/ops_signature/quantize_linear_sig.cc +++ b/paddle/fluid/operators/ops_signature/quantize_linear_sig.cc @@ -16,10 +16,19 @@ limitations under the License. 
*/ namespace phi { +KernelSignature QuantizeLinearOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature( + "quantize_linear_deprecated", + {"X", "Scale", "ZeroPoint", "InAccum", "InState"}, + {"quant_axis", "bit_length", "round_type", "is_test", "only_observer"}, + {"Y", "OutState", "OutAccum", "OutScale"}); +} + KernelSignature DeQuantizeLinearOpArgumentMapping( const ArgumentMappingContext& ctx UNUSED) { return KernelSignature( - "dequantize_linear", + "dequantize_linear_deprecated", {"X", "Scale", "ZeroPoint", "InAccum", "InState"}, {"quant_axis", "bit_length", "round_type", "is_test", "only_observer"}, {"Y", "OutState", "OutAccum", "OutScale"}); @@ -27,5 +36,8 @@ KernelSignature DeQuantizeLinearOpArgumentMapping( } // namespace phi +PD_REGISTER_ARG_MAPPING_FN(quantize_linear, + phi::QuantizeLinearOpArgumentMapping); + PD_REGISTER_ARG_MAPPING_FN(dequantize_linear, phi::DeQuantizeLinearOpArgumentMapping); diff --git a/paddle/fluid/operators/quantize_linear_op.cc b/paddle/fluid/operators/quantize_linear_op.cc index 1ccd8496c85eaa..71ef66355da558 100644 --- a/paddle/fluid/operators/quantize_linear_op.cc +++ b/paddle/fluid/operators/quantize_linear_op.cc @@ -9,119 +9,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/operators/quantize_linear_op.h" - #include #include #include +#include "paddle/common/ddim.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/common/transform.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/impl/clip_kernel_impl.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/binary.h" +#include "paddle/phi/infermeta/multiary.h" + namespace paddle { namespace operators { - -template -struct DequantizeFunctor { - void operator()(const phi::CPUContext &dev_ctx, - const phi::DenseTensor *in, - const phi::DenseTensor *scale, - T max_range, - phi::DenseTensor *out) { - auto in_e = phi::EigenVector::Flatten(*in); - const T *scale_factor = scale->data(); - auto out_e = phi::EigenVector::Flatten(*out); - - auto &dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * scale_factor[0] / max_range; - } -}; - -template -struct ChannelDequantizeFunctorV2 { - void operator()(const phi::CPUContext &dev_ctx, - const phi::DenseTensor *in, - const phi::DenseTensor *scale, - T max_range, - const int quant_axis, - phi::DenseTensor *out) { - // Dequant op is before quantized op - // Dequantize the weight of quantized op - auto in_dims = in->dims(); - const int64_t channel = in_dims[quant_axis]; - const T *scale_factor = scale->data(); - if (quant_axis == 0) { - for (int64_t i = 0; i < channel; i++) { - T s = scale_factor[i]; - phi::DenseTensor one_channel_in = in->Slice(i, i + 1); - phi::DenseTensor one_channel_out = out->Slice(i, i + 1); - auto in_e = phi::EigenVector::Flatten(one_channel_in); - auto out_e = phi::EigenVector::Flatten(one_channel_out); - auto &dev = *dev_ctx.eigen_device(); - out_e.device(dev) = in_e * s / max_range; - } - } else if (quant_axis == 1) { - int64_t out_iter = 1; - for (int i = 0; i 
< quant_axis; i++) { - out_iter *= in_dims[i]; - } - int64_t step_i = in->numel() / out_iter; - int64_t step_j = in->numel() / (out_iter * channel); - auto *in_data = in->data(); - auto *out_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - for (int64_t i = 0; i < out_iter; i++) { - for (int64_t j = 0; j < channel; j++) { - auto *cur_in = in_data + i * step_i + j * step_j; - auto *cur_out = out_data + i * step_i + j * step_j; - T s = scale_factor[j]; - for (int64_t k = 0; k < step_j; k++) { - *cur_out = (*cur_in) * s / max_range; - ++cur_in; - ++cur_out; - } - } - } - } - } -}; - -template struct DequantizeFunctor; -template struct DequantizeFunctor; -template struct DequantizeFunctor; -template struct ChannelDequantizeFunctorV2; -template struct ChannelDequantizeFunctorV2; -template struct ChannelDequantizeFunctorV2; - class QuantizeLinearOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear"); - OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear"); - OP_INOUT_CHECK( - ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear"); - OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear"); - ctx->SetOutputDim("Y", ctx->GetInputDim("X")); - int quant_axis = ctx->Attrs().Get("quant_axis"); - if (ctx->HasOutput("OutScale")) { - if (quant_axis < 0) { - ctx->SetOutputDim("OutScale", {1}); - } else { - ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]}); - } - } - if (ctx->HasOutput("OutState")) { - ctx->SetOutputDim("OutState", {1}); - } - if (ctx->HasOutput("OutAccum")) { - ctx->SetOutputDim("OutAccum", {1}); - } - ctx->ShareLoD("X", /*->*/ "Y"); - } protected: phi::KernelKey GetExpectedKernelType( @@ -130,7 +39,6 @@ class QuantizeLinearOp : public framework::OperatorWithKernel { ctx.GetPlace()); } }; - class 
QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -159,42 +67,16 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker { "(int, default 0) The axis for quantization. " "For conv2d, depthwise_conv2d, conv2d_transpose " "and mul, the quant_axis is equal to the cout axis.") - .SetDefault(0) - .AddCustomChecker([](const int &quant_axis) { - PADDLE_ENFORCE_EQ( - quant_axis == 0 || quant_axis == 1 || quant_axis == -1, - true, - phi::errors::InvalidArgument("'quant_axis' should be 0 or 1, but " - "the received is %d", - quant_axis)); - }); - AddAttr("bit_length", "(int, default 8)") - .SetDefault(8) - .AddCustomChecker([](const int &bit_length) { - PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, - true, - phi::errors::InvalidArgument( - "'bit_length' should be between 1 and 16, but " - "the received is %d", - bit_length)); - }); + .SetDefault(0); + AddAttr("bit_length", "(int, default 8)").SetDefault(8); AddAttr( "round_type", "(int, default 0) The round type of fp32 to int." - "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2" + "0: rounding to nearest ties to even. Eg: round(1.5)=2, " + "round(2.5)=2" "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, " "round(2.5)=3") - .SetDefault(0) - .AddCustomChecker([](const int &round_type) { - PADDLE_ENFORCE_EQ( - round_type == 0 || round_type == 1, - true, - phi::errors::InvalidArgument( - "'round_type' should be 0 or 1, 0 rounding to " - "nearest ties to even and 1 is rounding to nearest " - "ties away from zero.but the received is %d", - round_type)); - }); + .SetDefault(0); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. 
Some layers may run faster when this is true.") @@ -216,25 +98,30 @@ In above three formulas, the range value of c is as follow: )DOC"); } }; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(quantize_linear, + QuantizeLinearInferShapeFunctor, + PD_INFER_META(phi::QuantizeLinearInferMeta)); + +DECLARE_INFER_SHAPE_FUNCTOR(dequantize_linear, + DeQuantizeLinearInferShapeFunctor, + PD_INFER_META(phi::QuantizeLinearInferMeta)); REGISTER_OPERATOR( quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - quantize_linear, CPU, ALL_LAYOUT, ops::QuantizeLinearKernel, float) {} + paddle::framework::EmptyGradOpMaker, + QuantizeLinearInferShapeFunctor); REGISTER_OPERATOR( dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + DeQuantizeLinearInferShapeFunctor); diff --git a/paddle/fluid/operators/quantize_linear_op.cu b/paddle/fluid/operators/quantize_linear_op.cu deleted file mode 100644 index d9aa1a860f4057..00000000000000 --- a/paddle/fluid/operators/quantize_linear_op.cu +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include - -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/fake_quantize_op.cu.h" -#include "paddle/fluid/operators/quantize_linear_op.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" - -using float16 = phi::dtype::float16; - -namespace paddle { -namespace operators { - -template -__global__ void KeDequantize( - const T* in, const T* scale, T max_range, int64_t num, T* out) { - int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - out[i] = in[i] * scale[0] / max_range; - } -} - -template -__global__ void DequantizeOneScaleQuantAxisN(const T* in, - const T* scale, - const T max_range, - const int64_t num, - const int n_scales, - const int quant_stride, - T* out) { - int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; - for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % n_scales]; - out[i] = in[i] * s / max_range; - } -} - -template -struct DequantizeFunctor { - void operator()(const phi::GPUContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor* scale, - T max_range, - phi::DenseTensor* out) { - const T* in_data = in->data(); - const T* scale_factor = scale->data(); - T* out_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - - int64_t num = in->numel(); - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = - std::max(((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - KeDequantize<<>>( - in_data, scale_factor, max_range, num, out_data); - } -}; - -template -struct ChannelDequantizeFunctorV2 { - void operator()(const phi::GPUContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor* scale, - T max_range, - const int quant_axis, - 
phi::DenseTensor* out) { - auto in_dims = in->dims(); - const T* in_data = in->data(); - T* out_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - int64_t num = in->numel(); - const T* scale_factor = scale->data(); - int64_t block_size = std::min( - num, static_cast(dev_ctx.GetMaxThreadsPerBlock() / 4)); - int64_t max_threads = - dev_ctx.GetMaxPhysicalThreadCount(); // SM * block_per_SM - const int64_t max_blocks = - std::max(((max_threads - 1) / block_size + 1), static_cast(1)); - const int64_t grid_size = - std::min(max_blocks, (num + block_size - 1) / block_size); - - int quant_stride = 1; - for (int i = quant_axis + 1; i < in_dims.size(); i++) { - quant_stride *= in_dims[i]; - } - - DequantizeOneScaleQuantAxisN - <<>>(in_data, - scale_factor, - max_range, - num, - in_dims[quant_axis], - quant_stride, - out_data); - } -}; - -template struct DequantizeFunctor; -template struct DequantizeFunctor; -template struct DequantizeFunctor; -template struct ChannelDequantizeFunctorV2; -template struct ChannelDequantizeFunctorV2; -template struct ChannelDequantizeFunctorV2; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(quantize_linear, - GPU, - ALL_LAYOUT, - ops::QuantizeLinearKernel, - float, - float16) {} diff --git a/paddle/fluid/operators/quantize_linear_op.h b/paddle/fluid/operators/quantize_linear_op.h deleted file mode 100644 index cd30ab6186c3a5..00000000000000 --- a/paddle/fluid/operators/quantize_linear_op.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include "paddle/common/ddim.h" -#include "paddle/common/hostdevice.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/fake_quantize_op.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/transform.h" -#include "paddle/phi/kernels/cast_kernel.h" - -namespace paddle { -namespace operators { - -template -struct DequantizeFunctor { - void operator()(const DeviceContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor* scale, - T max_range, - phi::DenseTensor* out); -}; - -template -struct ChannelDequantizeFunctorV2 { - void operator()(const DeviceContext& dev_ctx, - const phi::DenseTensor* in, - const phi::DenseTensor** scales, - const int scale_num, - T max_range, - const int quant_axis, - phi::DenseTensor* out); -}; - -template -class QuantizeLinearKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* in_scale = context.Input("Scale"); - - auto* out = context.Output("Y"); - out->mutable_data(context.GetPlace()); - int bit_length = context.Attr("bit_length"); - int round_type = context.Attr("round_type"); - int bin_cnt = std::pow(2, bit_length - 1) - 1; - int quant_axis = context.Attr("quant_axis"); - bool is_test = context.Attr("is_test"); - bool only_observer = context.Attr("only_observer"); - auto& dev_ctx = context.template device_context(); - - if (quant_axis < 
0) { - if (!is_test) { - // training - auto* in_accum = context.Input("InAccum"); - auto* in_state = context.Input("InState"); - phi::DenseTensor tmp_scale; - tmp_scale.Resize(common::make_dim(1)); - T* cur_scale_data = dev_ctx.template Alloc(&tmp_scale); - - phi::funcs::FindAbsMaxFunctor()( - dev_ctx, in->data(), in->numel(), cur_scale_data); - - auto* out_state = context.Output("OutState"); - auto* out_accum = context.Output("OutAccum"); - auto* out_scale = context.Output("OutScale"); - out_state->mutable_data(context.GetPlace()); - out_accum->mutable_data(context.GetPlace()); - out_scale->mutable_data(context.GetPlace()); - float moving_rate = context.Attr("moving_rate"); - - phi::funcs::FindMovingAverageAbsMaxFunctor()( - dev_ctx, - *in_accum, - *in_state, - cur_scale_data, - moving_rate, - out_state, - out_accum, - out_scale); - if (only_observer) { - framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); - } else { - phi::funcs::ClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, round_type, out); - } - } else { - if (only_observer) { - framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); - } else { - phi::funcs::ClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, round_type, out); - } - } - } else { - if (!is_test) { - auto* out_scale = context.Output("OutScale"); - T* out_scale_data = out_scale->mutable_data(context.GetPlace()); - phi::funcs::FindChannelAbsMaxFunctor()( - dev_ctx, *in, quant_axis, out_scale_data); - if (only_observer) { - framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); - } else { - phi::funcs::ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); - } - } else { - if (only_observer) { - framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); - } else { - phi::funcs::ChannelClipAndFakeQuantFunctor()( - dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); - } - } - } - } -}; - -} // namespace operators -} // namespace paddle 
diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 19d326736fec43..8d653ae09c139c 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -83,10 +83,6 @@ const std::unordered_set LegacyOpList = { LrnGradOp::name(), MovingAverageAbsMaxScaleOp::name(), MovingAverageAbsMaxScale_Op::name(), - QuantizeLinearOp::name(), - QuantizeLinear_Op::name(), - DequantizeLinearOp::name(), - DequantizeLinear_Op::name(), #ifdef PADDLE_WITH_DNNL paddle::onednn::dialect::LrnOp::name(), paddle::onednn::dialect::LrnGradOp::name(), diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index aa32697afbc78c..88c8612e432e86 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3844,13 +3844,37 @@ void PsroiPoolInferMeta(const MetaTensor& x, void QuantizeLinearInferMeta(const MetaTensor& x, const MetaTensor& scale, + const MetaTensor& zero_point, const MetaTensor& in_accum, const MetaTensor& in_state, int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, MetaTensor* y, - MetaTensor* out_scale, + MetaTensor* out_state, MetaTensor* out_accum, - MetaTensor* out_state) { + MetaTensor* out_scale) { + PADDLE_ENFORCE_EQ( + quant_axis == 0 || quant_axis == 1 || quant_axis == -1, + true, + phi::errors::InvalidArgument("'quant_axis' should be 0, 1 or -1, but " + "the received is %d", + quant_axis)); + PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, + true, + phi::errors::InvalidArgument( + "'bit_length' should be between 1 and 16, but " + "the received is %d", + bit_length)); + PADDLE_ENFORCE_EQ(round_type == 0 || round_type == 1, + true, + phi::errors::InvalidArgument( + "'round_type' should be 0 or 1, 0 rounding to " + "nearest ties to even and 1 is rounding to nearest " + "ties away from zero.but the received is %d", + round_type)); y->set_dims(x.dims()); 
y->share_lod(x); if (out_scale) { diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index be50b08fe56e2f..212a87954c2bd8 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -711,13 +711,18 @@ void PsroiPoolInferMeta(const MetaTensor& x, void QuantizeLinearInferMeta(const MetaTensor& x, const MetaTensor& scale, + const MetaTensor& zero_point, const MetaTensor& in_accum, const MetaTensor& in_state, int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, MetaTensor* y, - MetaTensor* out_scale, + MetaTensor* out_state, MetaTensor* out_accum, - MetaTensor* out_state); + MetaTensor* out_scale); void RAdamInferMeta(const MetaTensor& param, const MetaTensor& grad, diff --git a/paddle/phi/kernels/cpu/quantize_linear_kernel.cc b/paddle/phi/kernels/cpu/quantize_linear_kernel.cc index a7f3954407a526..16d4eab1c37b50 100644 --- a/paddle/phi/kernels/cpu/quantize_linear_kernel.cc +++ b/paddle/phi/kernels/cpu/quantize_linear_kernel.cc @@ -98,6 +98,9 @@ template struct ChannelDequantizeFunctorV2; } // namespace phi +PD_REGISTER_KERNEL( + quantize_linear, CPU, ALL_LAYOUT, phi::QuantizeLinearKernel, float) {} + PD_REGISTER_KERNEL(dequantize_linear, CPU, ALL_LAYOUT, @@ -107,3 +110,19 @@ PD_REGISTER_KERNEL(dequantize_linear, double) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } + +PD_REGISTER_KERNEL(quantize_linear_deprecated, + CPU, + ALL_LAYOUT, + phi::QuantizeLinearDeprecatedKernel, + float) {} + +PD_REGISTER_KERNEL(dequantize_linear_deprecated, + CPU, + ALL_LAYOUT, + phi::DeQuantizeLinearDeprecatedKernel, + float, + int8_t, + double) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/gpu/quantize_linear_kernel.cu b/paddle/phi/kernels/gpu/quantize_linear_kernel.cu index 11c043e76f464e..b8782efce04eea 100644 --- a/paddle/phi/kernels/gpu/quantize_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/quantize_linear_kernel.cu @@ -128,3 
+128,32 @@ PD_REGISTER_KERNEL(dequantize_linear, phi::dtype::float16) { kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } + +PD_REGISTER_KERNEL(quantize_linear, + GPU, + ALL_LAYOUT, + phi::QuantizeLinearKernel, + float, + phi::dtype::float16) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} + +PD_REGISTER_KERNEL(dequantize_linear_deprecated, + GPU, + ALL_LAYOUT, + phi::DeQuantizeLinearDeprecatedKernel, + float, + int8_t, + double, + phi::dtype::float16) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} + +PD_REGISTER_KERNEL(quantize_linear_deprecated, + GPU, + ALL_LAYOUT, + phi::QuantizeLinearDeprecatedKernel, + float, + phi::dtype::float16) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} diff --git a/paddle/phi/kernels/impl/quantize_linear_impl.h b/paddle/phi/kernels/impl/quantize_linear_impl.h index a454023d859d8a..0ec445618f98bc 100644 --- a/paddle/phi/kernels/impl/quantize_linear_impl.h +++ b/paddle/phi/kernels/impl/quantize_linear_impl.h @@ -23,6 +23,7 @@ #include "paddle/phi/common/place.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/cast_kernel.h" +#include "paddle/phi/kernels/funcs/fake_quantize_functor.h" namespace phi { @@ -89,7 +90,7 @@ void DeQuantizeLinearImpl(const Context& dev_ctx, template void DeQuantizeLinearKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, + const paddle::optional& in_scale, const DenseTensor& zero_point, const paddle::optional& in_accum, const paddle::optional& in_state, @@ -102,6 +103,11 @@ void DeQuantizeLinearKernel(const Context& dev_ctx, DenseTensor* out_state, DenseTensor* out_accum, DenseTensor* out_scale) { + PADDLE_ENFORCE_NE(in_scale.get_ptr(), + nullptr, + phi::errors::PreconditionNotMet( + "in_scale can't be nullptr in DeQuantizeLinearKernel")); + auto scale = in_scale.get(); switch (scale.dtype()) { case phi::DataType::FLOAT64: DeQuantizeLinearImpl( @@ -124,4 +130,159 @@ void 
DeQuantizeLinearKernel(const Context& dev_ctx, } } +template +void QuantizeLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const DenseTensor& zero_point, + const paddle::optional& in_accum, + const paddle::optional& in_state, + int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, + DenseTensor* out, + DenseTensor* out_state, + DenseTensor* out_accum, + DenseTensor* out_scale) { + PADDLE_ENFORCE_NE(scale.get_ptr(), + nullptr, + phi::errors::PreconditionNotMet( + "scale can't be nullptr in QuantizeLinearKernel")); + auto* in = &x; + auto* in_scale = scale.get_ptr(); + dev_ctx.template Alloc(out); + int bin_cnt = std::pow(2, bit_length - 1) - 1; + + if (quant_axis < 0) { + if (!is_test) { + // training + phi::DenseTensor tmp_scale; + tmp_scale.Resize(common::make_dim(1)); + T* cur_scale_data = dev_ctx.template Alloc(&tmp_scale); + + phi::funcs::FindAbsMaxFunctor()( + dev_ctx, in->data(), in->numel(), cur_scale_data); + + dev_ctx.template Alloc(out_state); + dev_ctx.template Alloc(out_accum); + dev_ctx.template Alloc(out_scale); + + phi::funcs::FindMovingAverageAbsMaxFunctor()(dev_ctx, + in_accum.get(), + in_state.get(), + cur_scale_data, + 0.9, + out_state, + out_accum, + out_scale); + if (only_observer) { + phi::Copy(dev_ctx, *in, dev_ctx.GetPlace(), false, out); + } else { + phi::funcs::ClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, round_type, out); + } + } else { + if (only_observer) { + phi::Copy(dev_ctx, *in, dev_ctx.GetPlace(), false, out); + } else { + phi::funcs::ClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, round_type, out); + } + } + } else { + if (!is_test) { + T* out_scale_data = dev_ctx.template Alloc(out_scale); + phi::funcs::FindChannelAbsMaxFunctor()( + dev_ctx, *in, quant_axis, out_scale_data); + if (only_observer) { + phi::Copy(dev_ctx, *in, dev_ctx.GetPlace(), false, out); + } else { + 
phi::funcs::ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out); + } + } else { + if (only_observer) { + phi::Copy(dev_ctx, *in, dev_ctx.GetPlace(), false, out); + } else { + phi::funcs::ChannelClipAndFakeQuantFunctor()( + dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out); + } + } + } +} + +template +void QuantizeLinearDeprecatedKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& in_scale, + const DenseTensor& zero_point, + const paddle::optional& in_accum, + const paddle::optional& in_state, + int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, + DenseTensor* out, + DenseTensor* out_state, + DenseTensor* out_accum, + DenseTensor* out_scale) { + paddle::optional scale = + paddle::make_optional(in_scale); + QuantizeLinearKernel(dev_ctx, + x, + scale, + zero_point, + in_accum, + in_state, + quant_axis, + bit_length, + round_type, + is_test, + only_observer, + out, + out_state, + out_accum, + out_scale); +} + +template +void DeQuantizeLinearDeprecatedKernel( + const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& in_scale, + const DenseTensor& zero_point, + const paddle::optional& in_accum, + const paddle::optional& in_state, + int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, + DenseTensor* out, + DenseTensor* out_state, + DenseTensor* out_accum, + DenseTensor* out_scale) { + paddle::optional scale = + paddle::make_optional(in_scale); + DeQuantizeLinearKernel(dev_ctx, + x, + scale, + zero_point, + in_accum, + in_state, + quant_axis, + bit_length, + round_type, + is_test, + only_observer, + out, + out_state, + out_accum, + out_scale); +} + } // namespace phi diff --git a/paddle/phi/kernels/quantize_linear_kernel.h b/paddle/phi/kernels/quantize_linear_kernel.h index c10a67f51e6030..ea487f74999081 100644 --- a/paddle/phi/kernels/quantize_linear_kernel.h +++ 
b/paddle/phi/kernels/quantize_linear_kernel.h @@ -20,10 +20,27 @@ namespace phi { +template +void QuantizeLinearKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& scale, + const DenseTensor& zero_point, + const paddle::optional& in_accum, + const paddle::optional& in_state, + int quant_axis, + int bit_length, + int round_type, + bool is_test, + bool only_observer, + DenseTensor* out, + DenseTensor* out_state, + DenseTensor* out_accum, + DenseTensor* out_scale); + template void DeQuantizeLinearKernel(const Context& dev_ctx, const DenseTensor& x, - const DenseTensor& scale, + const paddle::optional& scale, const DenseTensor& zero_point, const paddle::optional& in_accum, const paddle::optional& in_state, diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 26661ea9db34e7..0ba713d5314c6d 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -455,16 +455,16 @@ backward : depthwise_conv2d_transpose_grad - op : dequantize_linear - args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int round_type = 0, bool is_test = true, bool only_observer = false, float moving_rate=0.9f) - output : Tensor(y), Tensor(out_scale), Tensor(out_accum), Tensor(out_state) + args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int round_type = 0, bool is_test = true, bool only_observer = false) + output : Tensor(y), Tensor(out_state), Tensor(out_accum), Tensor(out_scale) infer_meta : func : QuantizeLinearInferMeta - param : [x, scale, in_accum, in_state, quant_axis] + param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer] kernel : func : quantize_linear - param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, 
round_type, is_test, only_observer, moving_rate] + param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer] data_type : x - optional : in_accum, in_state, out_scale, out_accum, out_state + optional : scale, in_accum, in_state, out_state, out_accum, out_scale inplace : (scale -> out_scale, in_accum -> out_accum, in_state -> out_state) - op : dequantize_log @@ -1357,19 +1357,21 @@ data_type : out_grad_in inplace: (out_grad_in -> out_grad_out) +# Note: dequantize_linear and quantize_linear share one op maker in Fluid. out_scale can't be used in dequantize_linear, +# so out_scale is optional. Currently we can't modify the op definition of dequantize_linear/quantize_linear because that would cause incompatibility problems. +# We need to modify the dequantize_linear/quantize_linear yaml and make it more reasonable once the Fluid op is abandoned. - op : quantize_linear - args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int round_type = 0, bool is_test = true, bool only_observer = false, float moving_rate=0.9f) - output : Tensor(y), Tensor(out_scale), Tensor(out_accum), Tensor(out_state) + args : (Tensor x, Tensor scale, Tensor zero_point, Tensor in_accum, Tensor in_state, int quant_axis = 0, int bit_length = 8, int round_type = 0, bool is_test = true, bool only_observer = false) + output : Tensor(y), Tensor(out_state), Tensor(out_accum), Tensor(out_scale) infer_meta : func : QuantizeLinearInferMeta - param : [x, scale, in_accum, in_state, quant_axis] + param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer] kernel : func : quantize_linear - param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer, moving_rate] + param : [x, scale, zero_point, in_accum, in_state, quant_axis, bit_length, round_type, is_test, only_observer] data_type : x - 
optional : in_accum, in_state, out_scale, out_accum, out_state + optional : scale, in_accum, in_state, out_state, out_accum, out_scale inplace : (scale -> out_scale, in_accum -> out_accum, in_state -> out_state) - - op : randint args : (int low, int high, IntArray shape, DataType dtype=DataType::INT64, Place place={}) output : Tensor(out)