quantization aware training pass #3817

Merged: 92 commits merged on Jan 13, 2021
Changes from 1 commit

Commits (92)
4fcb53a
init qat pass
daquexian Oct 28, 2020
c2f7c24
fix bugs
daquexian Oct 28, 2020
b20f7be
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 9, 2020
8369e5b
add calculate weight scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
1719c1e
clear batch axis of scale and zero_point
Ldpe2G Nov 4, 2020
f89ac2d
add calculate activation scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
b6dd00f
add fake quantization ops & unit tests
Ldpe2G Nov 5, 2020
8c2b8b9
add sbp signature to fake quantization op & improve code style
Ldpe2G Nov 9, 2020
0944547
imporve unit test speed
Ldpe2G Nov 10, 2020
c5e3817
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 11, 2020
a8706de
Merge remote-tracking branch 'origin/dev_add_quantization_aware_train…
daquexian Nov 11, 2020
42ca848
update pass
daquexian Nov 12, 2020
4d21070
add QatConfig
daquexian Nov 13, 2020
de11311
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 13, 2020
c00107b
format
daquexian Nov 13, 2020
d2a9409
code clean
daquexian Nov 13, 2020
5e7b278
add calculate weight scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
6a95254
clear batch axis of scale and zero_point
Ldpe2G Nov 4, 2020
b43323f
add calculate activation scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
b9bc05f
add fake quantization ops & unit tests
Ldpe2G Nov 5, 2020
e2a1550
add sbp signature to fake quantization op & improve code style
Ldpe2G Nov 9, 2020
4c4ee3d
imporve unit test speed
Ldpe2G Nov 10, 2020
7e6624c
make changes according to review comments
Ldpe2G Nov 11, 2020
5c76514
rename quantize ops following the pytorch's naming scheme
Ldpe2G Nov 13, 2020
ea4625f
change the input zero_point of fake_quantize op to optional
Ldpe2G Nov 13, 2020
ac14861
stop updating moving_min and moving_max after training iteration reac…
Ldpe2G Nov 14, 2020
3431455
Merge remote-tracking branch 'origin/dev_add_quantization_aware_train…
daquexian Nov 17, 2020
8e1507f
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 19, 2020
c50f81d
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 21, 2020
80ec414
add calculate weight scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
7a9e88f
clear batch axis of scale and zero_point
Ldpe2G Nov 4, 2020
ec1afc2
add calculate activation scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
eabe774
add fake quantization ops & unit tests
Ldpe2G Nov 5, 2020
9476248
add sbp signature to fake quantization op & improve code style
Ldpe2G Nov 9, 2020
fad0f9c
imporve unit test speed
Ldpe2G Nov 10, 2020
cb450e8
make changes according to review comments
Ldpe2G Nov 11, 2020
373a76b
rename quantize ops following the pytorch's naming scheme
Ldpe2G Nov 13, 2020
c925592
change the input zero_point of fake_quantize op to optional
Ldpe2G Nov 13, 2020
9e4e6bb
stop updating moving_min and moving_max after training iteration reac…
Ldpe2G Nov 14, 2020
9223dcf
align with latest fake quant ops
daquexian Nov 23, 2020
5576395
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 23, 2020
1b01366
optimize CHECK
daquexian Nov 23, 2020
2f108a2
add multiple devices tests && fix sbp infer error
Ldpe2G Nov 23, 2020
965b3df
fix bugs on mobilenetv2
daquexian Nov 24, 2020
f5ff864
Merge remote-tracking branch 'origin/dev_add_quantization_aware_train…
daquexian Nov 24, 2020
02d8c37
align with latest fake quant ops
daquexian Nov 24, 2020
b2e5226
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Nov 24, 2020
738792a
format
daquexian Nov 27, 2020
d1e9498
add calculate weight scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
1d0b8fb
clear batch axis of scale and zero_point
Ldpe2G Nov 4, 2020
085f855
add calculate activation scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
a4e5b03
add fake quantization ops & unit tests
Ldpe2G Nov 5, 2020
05c825b
add sbp signature to fake quantization op & improve code style
Ldpe2G Nov 9, 2020
4fbb99e
imporve unit test speed
Ldpe2G Nov 10, 2020
e8a0e7b
make changes according to review comments
Ldpe2G Nov 11, 2020
749d53d
rename quantize ops following the pytorch's naming scheme
Ldpe2G Nov 13, 2020
ea9ab58
change the input zero_point of fake_quantize op to optional
Ldpe2G Nov 13, 2020
2614823
stop updating moving_min and moving_max after training iteration reac…
Ldpe2G Nov 14, 2020
2ae0cd3
add multiple devices tests && fix sbp infer error
Ldpe2G Nov 23, 2020
d27e554
stop udpating moving max and min during the prediction mode
Ldpe2G Nov 27, 2020
c082983
imporve ReduceMaxMinPerChannel cuda kernel slightly
Ldpe2G Dec 1, 2020
6eb5ab1
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Dec 15, 2020
5c7565e
align with cfg job_conf
daquexian Dec 16, 2020
708cd78
amp_lsit -> op_list
daquexian Dec 16, 2020
dbaba52
support conv op with bias input
daquexian Dec 16, 2020
a272ebb
Merge remote-tracking branch 'origin/dev_add_quantization_aware_train…
daquexian Dec 16, 2020
8ea0063
add calculate weight scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
f9b34a8
clear batch axis of scale and zero_point
Ldpe2G Nov 4, 2020
0acf4f8
add calculate activation scale and zero_point op & unit tests
Ldpe2G Nov 4, 2020
467631d
add fake quantization ops & unit tests
Ldpe2G Nov 5, 2020
417585d
add sbp signature to fake quantization op & improve code style
Ldpe2G Nov 9, 2020
1d89982
imporve unit test speed
Ldpe2G Nov 10, 2020
656e1be
make changes according to review comments
Ldpe2G Nov 11, 2020
c2e7732
rename quantize ops following the pytorch's naming scheme
Ldpe2G Nov 13, 2020
fc130d9
change the input zero_point of fake_quantize op to optional
Ldpe2G Nov 13, 2020
3ae1e17
stop updating moving_min and moving_max after training iteration reac…
Ldpe2G Nov 14, 2020
9e02ed2
add multiple devices tests && fix sbp infer error
Ldpe2G Nov 23, 2020
503e8af
stop udpating moving max and min during the prediction mode
Ldpe2G Nov 27, 2020
9e2abd6
imporve ReduceMaxMinPerChannel cuda kernel slightly
Ldpe2G Dec 1, 2020
218f92f
change quantize_to_bit to quantization_bit
Ldpe2G Dec 16, 2020
9395b16
change quantize to quantization
Ldpe2G Dec 16, 2020
df6d9cc
format
daquexian Dec 21, 2020
8da2bbb
Merge remote-tracking branch 'origin/dev_add_quantization_aware_train…
daquexian Dec 21, 2020
45d813c
Merge remote-tracking branch 'origin/master' into quant_aware_trainin…
daquexian Dec 21, 2020
01769b0
fix bias zero point shape, add tests
daquexian Dec 21, 2020
5447627
set 'training' attr according to job desc
daquexian Dec 23, 2020
272b1a5
refine tests
daquexian Dec 23, 2020
0cfa8ad
Merge branch 'master' into quant_aware_training_dqx
daquexian Jan 13, 2021
3832433
polish
daquexian Jan 13, 2021
48061cb
reformat
daquexian Jan 13, 2021
5b4f005
Merge branch 'master' into quant_aware_training_dqx
oneflow-ci-bot Jan 13, 2021
22bf725
fix cpu test
daquexian Jan 13, 2021
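
Taken together, the commits above add ops that derive a quantization scale and zero_point from weights and activations, fake-quantization ops that simulate integer rounding during training, and a graph pass that inserts them. As a rough point of reference for what fake quantization does to a tensor, here is a minimal NumPy sketch of the symmetric case (an illustration only, not the OneFlow kernels; the clamp range assumes signed, symmetric 8-bit style quantization):

import numpy as np

def fake_quantize_symmetric(x, bits=8):
    # Symmetric scale as in the helpers added by this PR:
    # scale = max|x| / (2^(bits-1) - 1), zero_point = 0.
    q_max = 2.0 ** (bits - 1) - 1
    scale = np.max(np.abs(x)) / q_max
    # Round onto the integer grid, clamp, then dequantize back to float.
    q = np.clip(np.round(x / scale), -q_max, q_max)
    return q * scale

x = np.random.randn(4, 4).astype(np.float32)
print(np.max(np.abs(x - fake_quantize_symmetric(x))))  # error is at most about scale / 2
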
add calculate weight scale and zero_point op & unit tests
Ldpe2G committed Nov 17, 2020
commit 5e7b2786bd1244ceb80cbaf25f86e1396420c9e4
52 changes: 52 additions & 0 deletions oneflow/python/ops/quantize_ops.py
@@ -0,0 +1,52 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
from typing import Tuple, Optional
from oneflow.python.oneflow_export import oneflow_export

import oneflow as flow
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.remote_blob as remote_blob_util


@oneflow_export("nn.generate_quantize_scale_for_weight")
def generate_quantize_scale_for_weight(
weight: remote_blob_util.BlobDef,
quantize_to_bit: int = 8,
quantizer_type: str = "symmetric",
per_layer_quantization: bool = True,
name: Optional[str] = None,
) -> Tuple[remote_blob_util.BlobDef, remote_blob_util.BlobDef]:

scale, zero_point = (
flow.user_op_builder(
name
if name is not None
else id_util.UniqueStr("Generate_Quantize_Scale_For_Weight_")
)
.Op("generate_quantize_scale_for_weight")
.Input("weight", [weight])
.Output("scale")
.Output("zero_point")
.Attr("quantize_to_bit", quantize_to_bit)
.Attr("quantizer_type", quantizer_type)
.Attr("per_layer_quantization", per_layer_quantization)
.Build()
.InferAndTryRun()
.RemoteBlobList()
)

return scale, zero_point
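
A minimal call sketch for the wrapper above, following the pattern of the test job added below (the shapes, placement, and function_config here are illustrative assumptions, not part of this commit):

import numpy as np
import oneflow as flow
import oneflow.typing as oft

@flow.global_function(type="predict", function_config=flow.FunctionConfig())
def weight_scale_job(
    weight: oft.Numpy.Placeholder((16, 3, 3, 3), dtype=flow.float32)
):
    with flow.scope.placement("cpu", "0:0"):
        # per_layer_quantization=False asks for one (scale, zero_point) pair per output channel.
        return flow.nn.generate_quantize_scale_for_weight(
            weight, quantize_to_bit=8, quantizer_type="symmetric",
            per_layer_quantization=False,
        )

scale, zero_point = weight_scale_job(
    np.random.randn(16, 3, 3, 3).astype(np.float32)
).get()
print(scale.numpy().shape)  # (16,) per-channel; (1,) when per_layer_quantization=True
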
165 changes: 165 additions & 0 deletions oneflow/python/test/ops/test_quantize_op.py
@@ -0,0 +1,165 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
from collections import OrderedDict

import numpy as np
import oneflow as flow
from test_util import GenArgList, type_name_to_flow_type, type_name_to_np_type
import oneflow.typing as oft


def gen_quant_scale_per_layer_symmetric(weight, quantize_to_bit):
weight_max = np.max(np.abs(weight))
denominator = 2.0 ** (quantize_to_bit - 1) - 1
return weight_max / denominator, 0


def gen_quant_scale_per_layer_affine(weight, quantize_to_bit):
weight_max = np.max(weight)
weight_min = np.min(weight)
denominator = 2.0 ** (quantize_to_bit) - 1
scale = (weight_max - weight_min) / denominator
zero_point = -weight_min / scale
return scale, zero_point


def product(tu):
p = 1
for t in tu:
p = p * t
return p


def _check(
test_case,
weight,
scale_of,
zero_point_of,
quantize_to_bit,
quantizer_type,
per_layer_quantization,
):
if per_layer_quantization:
outer_num = 1
inner_num = product(weight.shape[0:])
else:
outer_num = weight.shape[0]
inner_num = product(weight.shape[1:])

scale_np = np.zeros((outer_num,))
zero_point_np = np.zeros((outer_num,))

weight_flatten = weight.flatten()

if quantizer_type == "symmetric":
for c in range(outer_num):
scale_np[c], zero_point_np[c] = gen_quant_scale_per_layer_symmetric(
weight_flatten[c * inner_num : (c + 1) * inner_num], quantize_to_bit
)
else: # "affine"
for c in range(outer_num):
scale_np[c], zero_point_np[c] = gen_quant_scale_per_layer_affine(
weight_flatten[c * inner_num : (c + 1) * inner_num], quantize_to_bit
)

# print(weight)
print(scale_of, zero_point_of)
print(scale_np, zero_point_np)

test_case.assertTrue(np.allclose(scale_of, scale_np, rtol=1e-3))
test_case.assertTrue(
np.allclose(
zero_point_of.astype(np.int), zero_point_np.astype(np.int), rtol=1e-3
)
)


def _run_test(
test_case,
device_type,
dtype,
weight_shape,
quantize_to_bit,
quantizer_type,
per_layer_quantization,
):
assert device_type in ["gpu", "cpu"]
flow.clear_default_session()
flow.config.enable_debug_mode(True)

@flow.global_function(type="predict", function_config=flow.FunctionConfig())
def QuantizeJob(
weight: oft.Numpy.Placeholder(weight_shape, dtype=type_name_to_flow_type[dtype])
):
with flow.scope.placement(device_type, "0:0"):
scale, zero_point = flow.nn.generate_quantize_scale_for_weight(
weight, quantize_to_bit, quantizer_type, per_layer_quantization
)
return scale, zero_point

check_point = flow.train.CheckPoint()
check_point.init()
weight = (np.random.random(weight_shape) - 1).astype(type_name_to_np_type[dtype])
scale, zero_point = QuantizeJob(weight).get()

_check(
test_case,
weight,
scale.numpy(),
zero_point.numpy(),
quantize_to_bit,
quantizer_type,
per_layer_quantization,
)


# @flow.unittest.skip_unless_1n1d()
# class TestGenQuantScaleForWeight(flow.unittest.TestCase):
# def test_gen_quant_scale_for_weight(test_case):
# arg_dict = OrderedDict()
# arg_dict["test_case"] = [test_case]
# arg_dict["device_type"] = ["cpu"] # ["gpu", "cpu"]
# arg_dict["dtype"] = ["float32", "double"]
# arg_dict["weight_shape"] = [(10, 10, 20, 20)]
# arg_dict["quantize_to_bit"] = [8, 7, 6, 5, 4, 3, 2]
# arg_dict["quantizer_type"] = ["symmetric", "affine"]
# arg_dict["per_layer_quantization"] = [True, False]

# for arg in GenArgList(arg_dict):
# _run_test(*arg)


@flow.unittest.skip_unless_1n1d()
class TestGenQuantScaleForWeight(flow.unittest.TestCase):
def test_gen_quant_scale_for_weight(test_case):
arg_dict = OrderedDict()
arg_dict["test_case"] = [test_case]
arg_dict["device_type"] = ["gpu", "cpu"]
arg_dict["dtype"] = ["float32", "double"]
arg_dict["weight_shape"] = [(10, 10, 20, 20), (10, 3, 3, 3), (9, 10, 20, 20)]
arg_dict["quantize_to_bit"] = [8, 7, 6, 5, 4, 3, 2]
arg_dict["quantizer_type"] = ["symmetric", "affine"]
arg_dict["per_layer_quantization"] = [True, False]

for arg in GenArgList(arg_dict):
print(arg)
_run_test(*arg)


if __name__ == "__main__":
unittest.main()
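
For reference, the two scale derivations these tests check (and that the kernels below implement), with b = quantize_to_bit:

symmetric: s = max_i |w_i| / (2^(b-1) - 1),  z = 0
affine:    s = (max_i w_i - min_i w_i) / (2^b - 1),  z = -min_i w_i / s

For example, an 8-bit symmetric weight tensor with max|w| = 1.0 gets s = 1/127 ≈ 0.00787 and z = 0.
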
107 changes: 107 additions & 0 deletions oneflow/user/kernels/generate_quantize_scale_for_weight_kernel.cpp
@@ -0,0 +1,107 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"

#include <algorithm>

namespace oneflow {

template<typename T>
void gen_quant_scale_per_layer_symmetric(const int32_t quantize_to_bit, const int64_t num_elements,
const T *weight_ptr, T *weight_scale, T *zero_point) {
T weight_max = *std::max_element(weight_ptr, weight_ptr + num_elements);
T weight_min = *std::min_element(weight_ptr, weight_ptr + num_elements);

weight_max = std::max(std::abs(weight_max), std::abs(weight_min));

T denominator = T(pow(2.0, quantize_to_bit - 1)) - 1;

weight_scale[0] = weight_max / denominator;
zero_point[0] = 0;
}

template<typename T>
void gen_quant_scale_per_layer_affine(const int32_t quantize_to_bit, const int64_t num_elements,
const T *weight_ptr, T *weight_scale, T *zero_point) {
T weight_max = *std::max_element(weight_ptr, weight_ptr + num_elements);
T weight_min = *std::min_element(weight_ptr, weight_ptr + num_elements);

T denominator = T(pow(2.0, quantize_to_bit)) - 1;

weight_scale[0] = (weight_max - weight_min) / denominator;
zero_point[0] = -weight_min / weight_scale[0];
}

template<typename T>
class CpuGenerateQuantizeScaleForWeightKernel final : public user_op::OpKernel {
public:
CpuGenerateQuantizeScaleForWeightKernel() = default;
~CpuGenerateQuantizeScaleForWeightKernel() = default;

private:
void Compute(user_op::KernelComputeContext *ctx) const override {
const user_op::Tensor *weight = ctx->Tensor4ArgNameAndIndex("weight", 0);
user_op::Tensor *weight_scale = ctx->Tensor4ArgNameAndIndex("scale", 0);
user_op::Tensor *zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0);

const std::string quantizer_type = ctx->Attr<std::string>("quantizer_type");
const int32_t quantize_to_bit = ctx->Attr<int32_t>("quantize_to_bit");
const bool per_layer_quantization = ctx->Attr<bool>("per_layer_quantization");

const T *weight_ptr = weight->dptr<T>();
T *weight_scale_ptr = weight_scale->mut_dptr<T>();
T *zero_point_ptr = zero_point->mut_dptr<T>();

// NOTE(Liang Depeng): default is per layer quantization
int64_t outer_num = 1;
int64_t inner_num = weight->shape().elem_cnt();
if (!per_layer_quantization) { // per-channel quantization
outer_num = weight->shape().At(0);
inner_num = weight->shape().Count(1);
}

if (quantizer_type == "symmetric") {
FOR_RANGE(int64_t, c, 0, outer_num) {
gen_quant_scale_per_layer_symmetric(quantize_to_bit, inner_num, weight_ptr,
weight_scale_ptr, zero_point_ptr);
weight_ptr += inner_num;
weight_scale_ptr += 1;
zero_point_ptr += 1;
}
} else { // quantizer_type == "affine"
FOR_RANGE(int64_t, c, 0, outer_num) {
gen_quant_scale_per_layer_affine(quantize_to_bit, inner_num, weight_ptr, weight_scale_ptr,
zero_point_ptr);
weight_ptr += inner_num;
weight_scale_ptr += 1;
zero_point_ptr += 1;
}
}
};

bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

#define REGISTER_GENERATE_QUANTIZE_SCALE_FOR_WEIGHT_KERNEL(dtype) \
REGISTER_USER_KERNEL("generate_quantize_scale_for_weight") \
.SetCreateFn<CpuGenerateQuantizeScaleForWeightKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == DeviceType::kCPU) \
& (user_op::HobDataType("weight", 0) == GetDataType<dtype>::value))

REGISTER_GENERATE_QUANTIZE_SCALE_FOR_WEIGHT_KERNEL(float);
REGISTER_GENERATE_QUANTIZE_SCALE_FOR_WEIGHT_KERNEL(double);

} // namespace oneflow
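
In the per-channel branch above, a (Cout, Cin, K, K) weight is treated as Cout contiguous slices of Cin*K*K elements, each producing its own (scale, zero_point) pair. An equivalent NumPy view of that loop, mirroring the _check helper in the tests (illustrative only):

# assumes w is a float32 array of shape (Cout, Cin, K, K) and bits = quantize_to_bit
flat = w.reshape(w.shape[0], -1)                                  # outer_num rows of inner_num elements
scale = np.max(np.abs(flat), axis=1) / (2.0 ** (bits - 1) - 1)    # symmetric: one scale per row
zero_point = np.zeros(w.shape[0])
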
217 changes: 217 additions & 0 deletions oneflow/user/kernels/generate_quantize_scale_for_weight_kernel.cu
@@ -0,0 +1,217 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/kernel_util.cuh"

#include <float.h>

namespace oneflow {

namespace {

// NOTE(Liang Depeng): refer to
// https://stackoverflow.com/questions/17371275/implementing-max-reduce-in-cuda
template<typename T>
__global__ void ReduceMaxMinPerLayer(const T *input_ptr, const int64_t elements, T *max_ptr,
T *min_ptr) {
extern __shared__ unsigned char shared_max_min_memory[];
T *shared_max = reinterpret_cast<T *>(shared_max_min_memory);
T *shared_min = shared_max + blockDim.x;

int64_t tid = threadIdx.x;
int64_t gid = (blockDim.x * blockIdx.x) + tid;
shared_max[tid] = -FLT_MAX;
shared_min[tid] = -FLT_MAX;

while (gid < elements) {
shared_max[tid] = max(shared_max[tid], input_ptr[gid]);
shared_min[tid] = max(shared_min[tid], -input_ptr[gid]);
gid += gridDim.x * blockDim.x;
}
__syncthreads();
gid = (blockDim.x * blockIdx.x) + tid;
for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s && gid < elements) {
shared_max[tid] = max(shared_max[tid], shared_max[tid + s]);
shared_min[tid] = max(shared_min[tid], shared_min[tid + s]);
}
__syncthreads();
}

if (tid == 0) {
gpu_atomic_max(max_ptr, shared_max[0]);
gpu_atomic_max(min_ptr, shared_min[0]);
}
}

template<typename T>
__global__ void ReduceMaxMinPerChannel(const T *input_ptr, const int64_t elements,
const int64_t num_channels, const int64_t panel_size,
T *max_ptr, T *min_ptr) {
extern __shared__ unsigned char shared_max_min_memory[];
T *shared_max = reinterpret_cast<T *>(shared_max_min_memory);
T *shared_min = shared_max + blockDim.x;

int64_t cur_channel = blockIdx.x;
int64_t tid = threadIdx.x;

while (cur_channel < num_channels) {
shared_max[tid] = -FLT_MAX;
shared_min[tid] = -FLT_MAX;

int64_t index = (panel_size * cur_channel) + tid;
int64_t end = panel_size * (cur_channel + 1);

while (index < end && index < elements) {
shared_max[tid] = max(shared_max[tid], input_ptr[index]);
shared_min[tid] = max(shared_min[tid], -input_ptr[index]);
index += blockDim.x;
}
__syncthreads();

for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
if (tid < s) {
shared_max[tid] = max(shared_max[tid], shared_max[tid + s]);
shared_min[tid] = max(shared_min[tid], shared_min[tid + s]);
}
__syncthreads();
}

if (tid == 0) {
gpu_atomic_max(&max_ptr[cur_channel], shared_max[0]);
gpu_atomic_max(&min_ptr[cur_channel], shared_min[0]);
}

__syncthreads();
cur_channel += gridDim.x;
}
}

template<typename T>
__global__ void InitMaxMin(const int64_t elements, T *max_ptr, T *min_ptr) {
int64_t tid = threadIdx.x;
int64_t gid = (blockDim.x * blockIdx.x) + tid;

if (gid < elements) {
max_ptr[gid] = -FLT_MAX;
min_ptr[gid] = -FLT_MAX;
}
}

template<typename T>
__global__ void CalScaleZeroPointSymmetric(const T *max_ptr, const T *min_ptr,
const int64_t elements, const double quantize_to_bit,
T *scale, T *zero_point) {
int64_t tid = threadIdx.x;
int64_t gid = (blockDim.x * blockIdx.x) + tid;

if (gid < elements) {
T weight_max = max(fabs(max_ptr[gid]), fabs(min_ptr[gid]));
T denominator = T(pow(2.0, quantize_to_bit - 1)) - 1;
scale[gid] = weight_max / denominator;
zero_point[gid] = 0;
}
}

template<typename T>
__global__ void CalScaleZeroPointAffine(const T *max_ptr, const T *min_ptr, const int64_t elements,
const double quantize_to_bit, T *scale, T *zero_point) {
int64_t tid = threadIdx.x;
int64_t gid = (blockDim.x * blockIdx.x) + tid;

if (gid < elements) {
T denominator = T(pow(2.0, quantize_to_bit)) - 1;
T min = -min_ptr[gid];
T s = (max_ptr[gid] - min) / denominator;
scale[gid] = s;
zero_point[gid] = -min / s;
}
}

} // namespace

#define LAUNCH_CUDA_KERNEL(func, device_ctx_ptr, thread_num, shared_mem_size, ...) \
func<<<SMBlocksNum4ThreadsNum(thread_num), kCudaThreadsNumPerBlock, shared_mem_size, \
(device_ctx_ptr)->cuda_stream()>>>(__VA_ARGS__)
template<typename T>
class GpuGenerateQuantizeScaleForWeightKernel final : public user_op::OpKernel {
public:
GpuGenerateQuantizeScaleForWeightKernel() = default;
~GpuGenerateQuantizeScaleForWeightKernel() = default;
private:
void Compute(user_op::KernelComputeContext *ctx) const override {
const user_op::Tensor *weight = ctx->Tensor4ArgNameAndIndex("weight", 0);
user_op::Tensor *weight_scale = ctx->Tensor4ArgNameAndIndex("scale", 0);
user_op::Tensor *zero_point = ctx->Tensor4ArgNameAndIndex("zero_point", 0);
user_op::Tensor *tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
const std::string quantizer_type = ctx->Attr<std::string>("quantizer_type");
const int32_t quantize_to_bit = ctx->Attr<int32_t>("quantize_to_bit");
const bool per_layer_quantization = ctx->Attr<bool>("per_layer_quantization");
int64_t elements = weight->shape().elem_cnt();
int64_t channel = weight_scale->shape().At(0);
int64_t panel_size = elements / channel;
T *max_ptr = tmp_buffer->mut_dptr<T>();
T *min_ptr = max_ptr + channel;
LAUNCH_CUDA_KERNEL((InitMaxMin<T>), ctx->device_ctx(), channel, 0, channel, max_ptr, min_ptr);
if (per_layer_quantization) {
LAUNCH_CUDA_KERNEL((ReduceMaxMinPerLayer<T>), ctx->device_ctx(), elements,
kCudaThreadsNumPerBlock * 2 * sizeof(T), weight->dptr<T>(), elements,
max_ptr, min_ptr);
} else { // per-channel quantization
LAUNCH_CUDA_KERNEL((ReduceMaxMinPerChannel<T>), ctx->device_ctx(), channel,
kCudaThreadsNumPerBlock * 2 * sizeof(T), weight->dptr<T>(), elements,
channel, panel_size, max_ptr, min_ptr);
}
if (quantizer_type == "symmetric") {
LAUNCH_CUDA_KERNEL((CalScaleZeroPointSymmetric<T>), ctx->device_ctx(), channel, 0, max_ptr,
min_ptr, channel, static_cast<double>(quantize_to_bit),
weight_scale->mut_dptr<T>(), zero_point->mut_dptr<T>());
} else { // quantizer_type == "affine"
LAUNCH_CUDA_KERNEL((CalScaleZeroPointAffine<T>), ctx->device_ctx(), channel, 0, max_ptr,
min_ptr, channel, static_cast<double>(quantize_to_bit),
weight_scale->mut_dptr<T>(), zero_point->mut_dptr<T>());
}
};
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_GENERATE_QUANTIZE_SCALE_FOR_WEIGHT_KERNEL(dtype) \
REGISTER_USER_KERNEL("generate_quantize_scale_for_weight") \
.SetCreateFn<GpuGenerateQuantizeScaleForWeightKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == DeviceType::kGPU) \
& (user_op::HobDataType("weight", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext *ctx) -> size_t { \
size_t tmp_buffer_size = 1; \
Shape *weight_shape = ctx->Shape4ArgNameAndIndex("weight", 0); \
if (ctx->Attr<bool>("per_layer_quantization") == false) { \
tmp_buffer_size = weight_shape->At(0); \
} \
return 2 * tmp_buffer_size * sizeof(dtype); \
})
REGISTER_GENERATE_QUANTIZE_SCALE_FOR_WEIGHT_KERNEL(float);
REGISTER_GENERATE_QUANTIZE_SCALE_FOR_WEIGHT_KERNEL(double);
} // namespace oneflow
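
One detail of the GPU path worth calling out: only gpu_atomic_max is used, so both reduction buffers hold maxima — max_ptr accumulates max(x) while min_ptr accumulates max(-x) — and CalScaleZeroPointAffine recovers the true minimum by negating min_ptr. The same trick in NumPy terms (illustrative):

max_val = np.max(x)       # what max_ptr ends up holding
neg_min = np.max(-x)      # what min_ptr ends up holding (note the negation)
true_min = -neg_min       # recovered as `T min = -min_ptr[gid];` in CalScaleZeroPointAffine
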
74 changes: 74 additions & 0 deletions oneflow/user/ops/generate_quantize_scale_for_weight_op.cpp
@@ -0,0 +1,74 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"

namespace oneflow {

namespace {

REGISTER_USER_OP("generate_quantize_scale_for_weight")
.Input("weight")
.Output("scale")
.Output("zero_point")
// NOTE(Liang Depeng): quantize from float32 to "quantize_to_bit" bit signed or unsigned integer
.Attr<int32_t>("quantize_to_bit", 8)
// NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer
.Attr<std::string>("quantizer_type", "symmetric")
// NOTE(Liang Depeng): "true" or "false": per-layer or per-channel quantization
.Attr<bool>("per_layer_quantization", true)
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
Shape* weight_shape = ctx->Shape4ArgNameAndIndex("weight", 0);
// NOTE(Liang Depeng): only support weights for 2D convolution and fully-connected layers.
// And assume weight shape is (Cout, Cin, K, K) or (Cout, Cin) for 2D
// convolution or fully-connected.
CHECK_OR_RETURN(weight_shape->NumAxes() == 4 || weight_shape->NumAxes() == 2);

if (ctx->Attr<bool>("per_layer_quantization") == true) {
*ctx->Shape4ArgNameAndIndex("scale", 0) = Shape({1});
*ctx->Shape4ArgNameAndIndex("zero_point", 0) = Shape({1});
} else {
*ctx->Shape4ArgNameAndIndex("scale", 0) = Shape({weight_shape->At(0)});
*ctx->Shape4ArgNameAndIndex("zero_point", 0) = Shape({weight_shape->At(0)});
}

*ctx->Dtype4ArgNameAndIndex("scale", 0) = *ctx->Dtype4ArgNameAndIndex("weight", 0);
*ctx->Dtype4ArgNameAndIndex("zero_point", 0) = *ctx->Dtype4ArgNameAndIndex("weight", 0);
return Maybe<void>::Ok();
})
.SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
const user_op::UserOpConfWrapper&) {
user_op::InputArgModifier* weight = GetInputArgModifierFn("weight", 0);
weight->set_requires_grad(false);
})
.SetBatchAxisInferFn(user_op::BatchAxisInferFnUtil::DefaultAsFirstHasValueInput)
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
// TODO(Liang Depeng): refer to reduce_max op
return Maybe<void>::Ok();
})
.SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
int32_t quantize_to_bit = op_conf.attr<int32_t>("quantize_to_bit");
CHECK_GT_OR_RETURN(quantize_to_bit, 0);
CHECK_LE_OR_RETURN(quantize_to_bit, 8);

std::string quantizer_type = op_conf.attr<std::string>("quantizer_type");
CHECK_OR_RETURN(quantizer_type == "symmetric" || quantizer_type == "affine");
return Maybe<void>::Ok();
});

} // namespace

} // namespace oneflow