Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 0 additions & 114 deletions paddle/fluid/operators/fake_quantize_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,37 +285,6 @@ struct ChannelClipFakeQuantDequantFunctor<phi::CPUContext, T> {
}
};

template struct ChannelClipFakeQuantDequantFunctor<phi::CPUContext, float>;
// CPU specialization of the range-abs-max update.
//
// Stores the current scale into the slot selected by `iter % window_size`,
// then writes the resulting maximum into element 0 of `out_scale`.
template <typename T>
struct FindRangeAbsMaxFunctor<phi::CPUContext, T> {
  // ctx         : CPU device context; supplies the allocation place.
  // cur_scale   : element 0 is read as the scale of the current step.
  // last_scale  : element 0 is read as the previously reported maximum.
  // iter        : element 0 is read as an int64 step counter.
  // window_size : modulus used to pick the slot that gets overwritten.
  // scales_arr  : buffer of stored scales; one slot is overwritten per call.
  // out_scale   : element 0 receives the resulting maximum.
  void operator()(const phi::CPUContext &ctx,
                  const phi::DenseTensor &cur_scale,
                  const phi::DenseTensor &last_scale,
                  const phi::DenseTensor &iter,
                  const int window_size,
                  phi::DenseTensor *scales_arr,
                  phi::DenseTensor *out_scale) {
    T *window = scales_arr->mutable_data<T>(ctx.GetPlace());
    const int64_t step = iter.data<int64_t>()[0];
    const int slot = static_cast<int>(step % window_size);

    // Keep the value that is about to be overwritten; it decides below
    // whether the stored scales have to be rescanned.
    const T evicted = window[slot];
    const T current = cur_scale.data<T>()[0];
    window[slot] = current;

    T running_max = last_scale.data<T>()[0];
    if (running_max < current) {
      // The new scale exceeds the previous maximum, so it becomes the maximum.
      running_max = current;
    } else if (fabs(evicted - running_max) < 1e-6) {
      // The overwritten value was (within 1e-6 of) the previous maximum, so
      // the maximum is recomputed over the first min(step, window_size) slots.
      const int valid =
          static_cast<int>((step <= window_size) ? step : window_size);
      phi::funcs::FindAbsMaxFunctor<phi::CPUContext, T>()(
          ctx, window, valid, &running_max);
    }
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = running_max;
  }
};

template struct FindRangeAbsMaxFunctor<phi::CPUContext, float>;

class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
public:
FakeQuantOrWithDequantAbsMaxOp(const std::string &type,
Expand Down Expand Up @@ -539,77 +508,6 @@ In above three formulas, the range value of c is as follow:
)DOC");
}
};

// Operator definition for FakeQuantizeRangeAbsMax: validates the required
// slots, infers output shapes and selects the kernel key.
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeRangeAbsMaxOp(const std::string &type,
                            const framework::VariableNameMap &inputs,
                            const framework::VariableNameMap &outputs,
                            const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  // Requires X, Out and OutScale. Out takes the dims and LoD of X, OutScale is
  // {1}, and OutScales (only when present) is {window_size}.
  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(
        ctx->HasOutput("Out"), "Output", "Out", "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"),
                   "Output",
                   "OutScale",
                   "FakeQuantizeRangeAbsMax");

    // The scale buffer is an optional output; size it only when it exists.
    if (ctx->HasOutput("OutScales")) {
      int buffer_len = ctx->Attrs().Get<int>("window_size");
      ctx->SetOutputDim("OutScales", {buffer_len});
    }

    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel key is built from the data type of X and the place of the
  // current device context.
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.device_context().GetPlace());
  }
};

// Declares the inputs, outputs, attributes and documentation of the
// FakeQuantizeRangeAbsMax operator.
class FakeQuantizeRangeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Rejects bit lengths outside [1, 16].
    auto bit_length_checker = [](const int &bit_length) {
      PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                        true,
                        phi::errors::InvalidArgument(
                            "'bit_length' should be between 1 and 16, but "
                            "the received is %d",
                            bit_length));
    };

    // Inputs. Iter is dispensable.
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("Iter", "Global step iteration.").AsDispensable();

    // Outputs. OutScales is dispensable.
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();

    // Attributes.
    AddAttr<int>("window_size", "(int, default 10000) window range size.")
        .SetDefault(10000);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker(bit_length_checker);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);

    AddComment(R"DOC(
FakeQuantize operator is used in static quantization.

$$scale = max(max(abs(x)), history_abs_max)$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$

)DOC");
  }
};

class FakeQuantOrWithDequantMovingAverageAbsMaxOp
: public framework::OperatorWithKernel {
public:
Expand Down Expand Up @@ -820,18 +718,6 @@ PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel,
float) {}

// Registers the fake_quantize_range_abs_max operator with its op class and
// proto maker. Both gradient-maker slots (OpDesc and imperative OpBase) are
// filled with EmptyGradOpMaker.
REGISTER_OPERATOR(
fake_quantize_range_abs_max,
ops::FakeQuantizeRangeAbsMaxOp,
ops::FakeQuantizeRangeAbsMaxOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// Registers the CPU kernel of fake_quantize_range_abs_max for float only.
PD_REGISTER_STRUCT_KERNEL(fake_quantize_range_abs_max,
CPU,
ALL_LAYOUT,
ops::FakeQuantizeRangeAbsMaxKernel,
float) {}

REGISTER_OPERATOR(
fake_quantize_dequantize_moving_average_abs_max,
ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
Expand Down
6 changes: 0 additions & 6 deletions paddle/fluid/operators/fake_quantize_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel,
float,
float16) {}
// Registers the GPU kernel of fake_quantize_range_abs_max for float and
// float16.
PD_REGISTER_STRUCT_KERNEL(fake_quantize_range_abs_max,
GPU,
ALL_LAYOUT,
ops::FakeQuantizeRangeAbsMaxKernel,
float,
float16) {}
PD_REGISTER_STRUCT_KERNEL(moving_average_abs_max_scale,
GPU,
ALL_LAYOUT,
Expand Down
78 changes: 0 additions & 78 deletions paddle/fluid/operators/fake_quantize_op.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,84 +418,6 @@ struct ChannelClipAndFakeQuantFunctor<phi::GPUContext, T> {

template struct ChannelClipAndFakeQuantFunctor<phi::GPUContext, float>;

// Device-side step of the range-abs-max update.
//
// Stores the current scale into slot `iter[0] % window_size` of `scale_arr`,
// writes max(last_scale[0], cur_scale[0]) to `out_scale[0]`, and reports via
// `need_find_max[0]` whether the caller has to rescan the stored scales.
//
// cur_scale     : element 0 is the scale of the current step.
// last_scale    : element 0 is the previously reported maximum.
// iter          : element 0 is the int64 step counter.
// window_size   : modulus used to pick the slot that gets overwritten.
// scale_arr     : buffer of stored scales; one slot is overwritten.
// out_scale     : element 0 receives max(last_scale[0], cur_scale[0]).
// need_find_max : element 0 is set to 1 when the overwritten value was within
//                 1e-6 of last_scale[0], otherwise to 0.
// out_size      : element 0 is set to min(iter[0], window_size); it is written
//                 only when need_find_max[0] is set to 1.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T *cur_scale,
                                            const T *last_scale,
                                            const int64_t *iter,
                                            const int window_size,
                                            T *scale_arr,
                                            T *out_scale,
                                            int *need_find_max,
                                            int *out_size) {
  // Keep the counter at its full 64-bit width. Narrowing it to int before the
  // modulo could wrap to a negative value for large step counts and index
  // scale_arr out of bounds. Only the bounded results are cast to int.
  const int64_t it = iter[0];
  const int idx = static_cast<int>(it % window_size);
  T removed = scale_arr[idx];
  T cur = cur_scale[0];
  scale_arr[idx] = cur;
  T max = last_scale[0];
  out_scale[0] = max < cur ? cur : max;
  if (fabs(static_cast<typename QuantizeDataType<T>::type>(removed - max)) <
      1e-6) {
    need_find_max[0] = 1;
    // The clamped value never exceeds window_size, so it fits in an int.
    out_size[0] = static_cast<int>(it > window_size ? window_size : it);
  } else {
    need_find_max[0] = 0;
  }
}

// GPU specialization of the range-abs-max update.
//
// Launches FindRangeAbsMaxAndFillArray with a single block of a single thread
// on the context's stream, then reads the "rescan needed" flag back to the
// host. When the flag is set, the rescan length is read back as well and
// FindAbsMaxFunctor recomputes the maximum into `out_scale`.
template <typename T>
struct FindRangeAbsMaxFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const phi::DenseTensor &cur_scale,
                  const phi::DenseTensor &last_scale,
                  const phi::DenseTensor &iter,
                  const int window_size,
                  phi::DenseTensor *scales_arr,
                  phi::DenseTensor *out_scale) {
    const auto gpu_place = ctx.GetPlace();

    T *window = scales_arr->mutable_data<T>(gpu_place);
    T *result = out_scale->mutable_data<T>(gpu_place);

    // One-element scratch tensors the kernel uses to report back.
    phi::DenseTensor rescan_flag, rescan_len;
    int *rescan_flag_dev = rescan_flag.mutable_data<int>({1}, gpu_place);
    int *rescan_len_dev = rescan_len.mutable_data<int>({1}, gpu_place);

    FindRangeAbsMaxAndFillArray<T>
        <<<1, 1, 0, ctx.stream()>>>(cur_scale.data<T>(),
                                    last_scale.data<T>(),
                                    iter.data<int64_t>(),
                                    window_size,
                                    window,
                                    result,
                                    rescan_flag_dev,
                                    rescan_len_dev);

    // Copies one int from the device on the context's stream and waits on the
    // context before handing the value to the caller.
    auto fetch_int = [&](const int *device_src) {
      int host_value;
      memory::Copy(platform::CPUPlace(),
                   &host_value,
                   gpu_place,
                   device_src,
                   sizeof(int),
                   ctx.stream());
      ctx.Wait();
      return host_value;
    };

    if (fetch_int(rescan_flag_dev)) {
      // The length is fetched only when a rescan was requested.
      const int len = fetch_int(rescan_len_dev);
      phi::funcs::FindAbsMaxFunctor<phi::GPUContext, T>()(
          ctx, window, len, result);
    }
  }
};

template struct FindRangeAbsMaxFunctor<phi::GPUContext, float>;

// ChannelClipAndQuantDequantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
Expand Down
58 changes: 0 additions & 58 deletions paddle/fluid/operators/fake_quantize_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,6 @@ struct ClipAndFakeQuantDequantFunctor {
phi::DenseTensor *out);
};

// Primary template of the range-abs-max functor; only operator() is declared
// here, so each device context provides its own specialization.
//
// ctx         : device context the computation runs on.
// cur_scale   : scale of the current step.
// last_scale  : previously reported scale.
// iter        : step counter tensor.
// window_size : length of the scale window.
// scales_arr  : scale buffer, passed by mutable pointer.
// out_scale   : destination of the resulting scale.
template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor {
void operator()(const DeviceContext &ctx,
const phi::DenseTensor &cur_scale,
const phi::DenseTensor &last_scale,
const phi::DenseTensor &iter,
const int window_size,
phi::DenseTensor *scales_arr,
phi::DenseTensor *out_scale);
};

template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext &ctx,
Expand Down Expand Up @@ -176,53 +165,6 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
}
};

// Kernel for fake_quantize_range_abs_max.
//
// In test mode the input is clipped and quantized with the supplied InScale.
// In training mode the abs-max of X is computed, merged into the scale window
// through FindRangeAbsMaxFunctor, and X is then clipped and quantized with
// the resulting OutScale.
template <typename T, typename DeviceContext>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<phi::DenseTensor>("X");
    auto *in_scale = context.Input<phi::DenseTensor>("InScale");

    auto *out = context.Output<phi::DenseTensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    bool is_test = context.Attr<bool>("is_test");
    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    // Largest quantized magnitude: 2^(bit_length - 1) - 1.
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // testing
    if (is_test) {
      phi::funcs::ClipAndFakeQuantFunctor<DeviceContext, T>()(
          dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
      return;
    }

    // training
    auto *out_scale = context.Output<phi::DenseTensor>("OutScale");
    auto *out_scales = context.Output<phi::DenseTensor>("OutScales");
    auto *iter = context.Input<phi::DenseTensor>("Iter");

    // Iter and OutScales are dispensable in the op definition but are
    // dereferenced on the training path, so fail with a clear error instead
    // of dereferencing a null pointer.
    PADDLE_ENFORCE_NOT_NULL(
        iter,
        phi::errors::InvalidArgument(
            "Input(Iter) of FakeQuantizeRangeAbsMax must be provided when "
            "'is_test' is false."));
    PADDLE_ENFORCE_NOT_NULL(
        out_scales,
        phi::errors::InvalidArgument(
            "Output(OutScales) of FakeQuantizeRangeAbsMax must be provided "
            "when 'is_test' is false."));

    int window_size = context.Attr<int>("window_size");
    out_scale->mutable_data<T>(context.GetPlace());

    // Abs-max of the current input, held in a temporary one-element tensor.
    phi::DenseTensor cur_scale;
    T *cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
    phi::funcs::FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
    FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                               cur_scale,
                                               *in_scale,
                                               *iter,
                                               window_size,
                                               out_scales,
                                               out_scale);
    phi::funcs::ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }
};

template <typename T, typename DeviceContext>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
public:
Expand Down
10 changes: 8 additions & 2 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1126,8 +1126,14 @@
out_accum : OutAccum

# Compatibility entry for fake_quantize_range_abs_max: each line under
# inputs/outputs pairs a lower-case argument name (left) with the capitalized
# operator slot name (right).
- op : fake_quantize_range_abs_max
extra :
attrs : [int round_type = 1]
inputs :
x : X
in_scale : InScale
iter : Iter
outputs :
out : Out
out_scale : OutScale
out_scales : OutScales

- op : fc
inputs :
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,16 @@
data_type : x
optional : in_accum, in_state, out_state, out_accum

# fake_quantize_range_abs_max: output metadata comes from
# FakeQuantizeRangeAbsMaxInferMeta, the kernel data type follows x, and
# iter / out_scales are optional.
- op : fake_quantize_range_abs_max
args : (Tensor x, Tensor in_scale, Tensor iter, int window_size = 10000, int bit_length = 8, bool is_test = false, int round_type = 1)
output : Tensor(out), Tensor(out_scale), Tensor(out_scales)
infer_meta :
func : FakeQuantizeRangeAbsMaxInferMeta
kernel :
func : fake_quantize_range_abs_max
data_type : x
optional : iter, out_scales

- op : fft_c2c
args : (Tensor x, int64_t[] axes, str normalization, bool forward)
output : Tensor
Expand Down
25 changes: 25 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,31 @@ void DpsgdInferMeta(const MetaTensor& param,
grad.dims()));
param_out->set_dims(param_dims);
}

// Infers output metadata for fake_quantize_range_abs_max.
//
// Raises InvalidArgument when bit_length is outside [1, 16]. Otherwise `out`
// takes the dims and LoD of `x`, `out_scale` gets dims {1}, and `out_scales`
// (only when non-null) gets dims {window_size}. `in_scale`, `iter`, `is_test`
// and `round_type` are not read here.
void FakeQuantizeRangeAbsMaxInferMeta(const MetaTensor& x,
                                      const MetaTensor& in_scale,
                                      const MetaTensor& iter,
                                      int window_size,
                                      int bit_length,
                                      bool is_test,
                                      int round_type,
                                      MetaTensor* out,
                                      MetaTensor* out_scale,
                                      MetaTensor* out_scales) {
  // Validate before touching any output so a bad attribute leaves them as-is.
  PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                    true,
                    phi::errors::InvalidArgument(
                        "'bit_length' should be between 1 and 16, but "
                        "the received is %d",
                        bit_length));

  // The quantized output mirrors the input.
  out->set_dims(x.dims());
  out->share_lod(x);

  // The scale is a single element.
  out_scale->set_dims({1});

  // The scale buffer is optional; size it only when it was requested.
  if (out_scales != nullptr) {
    out_scales->set_dims({window_size});
  }
}

void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
Expand Down
11 changes: 11 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ void DpsgdInferMeta(const MetaTensor& param,
int size,
MetaTensor* param_out);

// Declaration of the metadata-inference function for
// fake_quantize_range_abs_max.
// Inputs: x, in_scale, iter. Attributes: window_size, bit_length, is_test,
// round_type. Outputs (by pointer): out, out_scale, out_scales.
void FakeQuantizeRangeAbsMaxInferMeta(const MetaTensor& x,
const MetaTensor& in_scale,
const MetaTensor& iter,
int window_size,
int bit_length,
bool is_test,
int round_type,
MetaTensor* out,
MetaTensor* out_scale,
MetaTensor* out_scales);

void FlashAttnInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
Expand Down
Loading