Dynamic quantization for bias. (pytorch#26057)
Summary:
Pull Request resolved: pytorch#26057

The bias is now unquantized (i.e., a floating-point tensor) for qconv and qlinear; fbgemm quantizes it dynamically at run time.

TODO: Add some performance numbers.
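
For reference, a minimal Python sketch of the new calling convention. Shapes, scales, and the binding of the test's `qlinear_prepack` alias to `torch.ops.quantized.linear_prepack` are assumptions for illustration; `torch.quantize_linear` is the quantization API name at the time of this commit. The point is that the bias handed to the prepack op stays a float tensor instead of being pre-quantized to qint32.
```python
import torch

# Sketch only: op alias as used in test_quantized.py; the binding below is
# an assumption for illustration.
qlinear_prepack = torch.ops.quantized.linear_prepack

W = torch.randn(16, 8)   # (output_channels, input_channels)
b = torch.randn(16)      # float bias, one value per output channel

W_q = torch.quantize_linear(W, scale=0.1, zero_point=0, dtype=torch.qint8)

# Old convention (removed by this PR): the bias had to be pre-quantized with
# scale = X_scale * W_scale, e.g.
#   b_q = torch.quantize_linear(b, scale=x_scale * 0.1, zero_point=0, dtype=torch.qint32)
#   W_prepack = qlinear_prepack(W_q, b_q)

# New convention: pass the float bias (or None); fbgemm quantizes it
# dynamically at run time using the observed activation scale.
W_prepack = qlinear_prepack(W_q, b)
```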

Tests:

test:quantization
```
Summary (total time 8.41s):
  PASS: 24
  FAIL: 0
  SKIP: 0
  FATAL: 0
  TIMEOUT: 0
  OMIT: 0
More details at https://our.intern.facebook.com/intern/buck/build/74d5f6f7-55c9-4350-a618-2013042fffd8
```

test:quantized
```
Summary (total time 13.21s):
  PASS: 43
  FAIL: 0
  SKIP: 5
    caffe2/test:quantized - test_qnnpack_maxpool2d (test_quantized.TestQNNPackOps)
    caffe2/test:quantized - test_compare_tensor_scalar (test_quantized.TestComparatorOps)
    caffe2/test:quantized - test_qnnpack_linear (test_quantized.TestQNNPackOps)
    caffe2/test:quantized - test_qnnpack_relu (test_quantized.TestQNNPackOps)
    caffe2/test:quantized - test_qnnpack_add (test_quantized.TestQNNPackOps)
  FATAL: 0
  TIMEOUT: 0
  OMIT: 0
```
ghstack-source-id: 90166254

Test Plan:
buck test mode/dev caffe2/test:quantization

buck test mode/dev caffe2/test:quantized

Differential Revision: D17328028

fbshipit-source-id: d4a163d730d0f4a03e8e0faf7420710cf36eec09
dskhudia authored and facebook-github-bot committed Sep 16, 2019
1 parent 4a947b6 commit 2b52c1d
Showing 7 changed files with 111 additions and 156 deletions.
89 changes: 42 additions & 47 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -133,53 +133,41 @@ class QConv2dInt8 final : public c10::OperatorKernel {
     float act_scale = act.q_scale();
     int32_t act_zero_point = act.q_zero_point();
 
-    const int32_t* bias_ptr = nullptr;
-    at::Tensor qbias;
+    const float* bias_ptr = nullptr;
+    at::Tensor bias;
     if (pack_ptr.bias.has_value()) {
-      at::Tensor bias = pack_ptr.bias.value();
-      // Temporary: Quantize bias
-      if (pack_ptr.q_scheme == kPerTensorAffine) {
-        qbias = at::quantize_linear(
-            at::dequantize(bias), pack_ptr.w_scale[0] * act_scale, 0, kQInt32);
-      } else if (pack_ptr.q_scheme == kPerChannelAffine) {
-        std::array<int64_t, 1> arr{0};
-        IntArrayRef axis(arr.data(), 1);
-        at::Tensor bias_scale = at::ones({K}, at::dtype(at::kDouble));
-        at::Tensor bias_zp = at::zeros({K}, at::dtype(at::kLong));
-        for (int i = 0; i < K; ++i) {
-          bias_scale.data_ptr<double>()[i] = pack_ptr.w_scale[i] * act_scale;
-        }
-        qbias = quantize_linear_per_channel_cpu(
-            at::dequantize(bias), bias_scale, bias_zp, axis, kQInt32);
-      } else {
-        qbias = bias;
-        TORCH_CHECK(false, "Unsupported quantization scheme.")
-      }
-      TORCH_CHECK(qbias.dim() == 1, "bias should be a vector (1D Tensor)");
+      bias = pack_ptr.bias.value();
+      TORCH_CHECK(
+          bias.dtype() == at::kFloat,
+          "[QConv2D] The 'bias' tensor must have 'torch.float' dtype");
+      bias = bias.contiguous();
+      TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)");
       TORCH_CHECK(
-          qbias.size(0) == K,
+          bias.size(0) == K,
           "bias should have K elements: " + std::to_string(K));
-      auto bias_contig = qbias.contiguous();
-      bias_ptr =
-          reinterpret_cast<int32_t*>(bias_contig.data_ptr<c10::qint32>());
+      bias_ptr = bias.data_ptr<float>();
     }
 
     std::vector<float> output_multiplier_float(1, 0.0);
+    std::vector<float> act_times_w_scale(1, 1.0);
     TORCH_CHECK(
         pack_ptr.w_scale.size() == pack_ptr.w_zp.size(),
         "Weight scales and zero points vectors should have the same size.");
-    // quantization scheme is PerTensorAffine if the number of scales is 1 and
-    // it's kPerChannelAffine if the number of scales is equal to K (output
-    // channels)
-    if (pack_ptr.w_scale.size() == 1) {
+
+    if (pack_ptr.q_scheme == kPerTensorAffine) {
+      act_times_w_scale[0] = (act_scale * pack_ptr.w_scale[0]);
       output_multiplier_float[0] =
-          (act_scale * pack_ptr.w_scale[0]) / static_cast<float>(output_scale);
-    } else if (pack_ptr.w_scale.size() == K) {
+          act_times_w_scale[0] / static_cast<float>(output_scale);
+    } else if (pack_ptr.q_scheme == kPerChannelAffine) {
       output_multiplier_float.resize(K, 0.0);
+      act_times_w_scale.resize(K, 1.0);
       for (int i = 0; i < K; ++i) {
-        output_multiplier_float[i] = (act_scale * pack_ptr.w_scale[i]) /
-            static_cast<float>(output_scale);
+        act_times_w_scale[i] = (act_scale * pack_ptr.w_scale[i]);
+        output_multiplier_float[i] =
+            act_times_w_scale[i] / static_cast<float>(output_scale);
       }
+    } else {
+      TORCH_CHECK(false, "[QConv2D] Unknown quantization scheme");
     }
 
     auto outShape =
@@ -194,17 +182,22 @@ class QConv2dInt8 final : public c10::OperatorKernel {
     auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));
 
     if (pack_ptr.q_scheme == kPerTensorAffine) {
-      fbgemm::ReQuantizeOutput<ReluFused> outputProcObj(
-          NoOpObj,
-          output_multiplier_float.data(),
-          output_zero_point,
-          act_zero_point,
-          pack_ptr.w_zp.data(),
-          nullptr, /* row offset buffer */
-          col_offsets.data(),
-          bias_ptr,
-          K,
-          groups);
+      fbgemm::ReQuantizeOutput<
+          ReluFused,
+          fbgemm::QuantizationGranularity::TENSOR,
+          float>
+          outputProcObj(
+              NoOpObj,
+              output_multiplier_float.data(),
+              output_zero_point,
+              act_zero_point,
+              pack_ptr.w_zp.data(),
+              nullptr, /* row offset buffer */
+              col_offsets.data(),
+              bias_ptr,
+              K,
+              groups,
+              act_times_w_scale.data());
       fbgemm::fbgemmConv(
           conv_p,
           act_ptr,
@@ -218,7 +211,8 @@ class QConv2dInt8 final : public c10::OperatorKernel {
     } else if (pack_ptr.q_scheme == kPerChannelAffine) {
       fbgemm::ReQuantizeOutput<
           ReluFused,
-          fbgemm::QuantizationGranularity::OUT_CHANNEL>
+          fbgemm::QuantizationGranularity::OUT_CHANNEL,
+          float>
           outputProcObj(
               NoOpObj,
               output_multiplier_float.data(),
@@ -229,7 +223,8 @@
               col_offsets.data(),
               bias_ptr,
               K,
-              groups);
+              groups,
+              act_times_w_scale.data());
 
       fbgemm::fbgemmConv(
           conv_p,
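
For intuition, here is a conceptual, pure-Python sketch (not the fbgemm kernel; all values invented) of the requantization that `ReQuantizeOutput` performs once it receives a float bias together with `act_times_w_scale`, as in the per-tensor branch above:
```python
# Conceptual sketch of requantization with a dynamically quantized bias.
# Not fbgemm's implementation; values are made up.
act_scale, w_scale = 0.5, 0.1
output_scale, output_zero_point = 0.25, 3

acc = 37    # int32 accumulator from the int8 GEMM (offsets already applied)
bias = 0.8  # unquantized float bias for this output channel

act_times_w = act_scale * w_scale               # per-channel would use w_scale[i]
output_multiplier = act_times_w / output_scale  # == output_multiplier_float[0]

acc += round(bias / act_times_w)                # dynamic bias quantization
out_q = output_zero_point + round(acc * output_multiplier)
out_q = max(0, min(255, out_q))                 # clamp to the quint8 range
```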
102 changes: 44 additions & 58 deletions aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -57,22 +57,23 @@ class QLinearInt8 final : public torch::OperatorKernel {
     int32_t input_zero_point_int32 = input.q_zero_point();
 
     std::vector<float> output_multiplier_float(1, 0.0);
+    std::vector<float> act_times_w_scale(1, 0.0);
     TORCH_CHECK(
         pack_ptr.w_scale.size() == pack_ptr.w_zp.size(),
         "Weight scales and zero points vectors should have the same size.");
-    // quantization scheme is PerTensorAffine if the number of scales is
-    // 1 and it's kPerChannelAffine if the number of scales is equal to
-    // N (output channels)
     if (pack_ptr.q_scheme == kPerTensorAffine) {
       // Process the per tensor quantization.
-      output_multiplier_float[0] = (input_scale_float * pack_ptr.w_scale[0]) /
-          static_cast<float>(output_scale);
+      act_times_w_scale[0] = (input_scale_float * pack_ptr.w_scale[0]);
+      output_multiplier_float[0] =
+          act_times_w_scale[0] / static_cast<float>(output_scale);
     } else if (pack_ptr.q_scheme == kPerChannelAffine) {
       // Process the per channel quantization.
       output_multiplier_float.resize(N, 0.0);
+      act_times_w_scale.resize(N, 1.0f);
       for (int i = 0; i < N; ++i) {
-        output_multiplier_float[i] = (input_scale_float * pack_ptr.w_scale[i]) /
-            static_cast<float>(output_scale);
+        act_times_w_scale[i] = (input_scale_float * pack_ptr.w_scale[i]);
+        output_multiplier_float[i] =
+            act_times_w_scale[i] / static_cast<float>(output_scale);
       }
     }
     int32_t output_zero_point_int32 = static_cast<int32_t>(output_zero_point);
@@ -106,40 +107,16 @@ class QLinearInt8 final : public torch::OperatorKernel {
     // This is the end of the pipeline, pass the resulting matrix through.
     fbgemm::DoNothing<> doNothingObj{};
 
-    const int32_t* bias_ptr = nullptr;
-    at::Tensor qbias;
+    const float* bias_ptr = nullptr;
+    at::Tensor bias;
     if (pack_ptr.bias.has_value()) {
-      at::Tensor bias = pack_ptr.bias.value();
-      // Temporary: Quantize bias
-      if (pack_ptr.q_scheme == kPerTensorAffine) {
-        qbias = at::quantize_linear(
-            at::dequantize(bias),
-            pack_ptr.w_scale[0] * input_scale_float,
-            0,
-            kQInt32);
-      } else if (pack_ptr.q_scheme == kPerChannelAffine) {
-        std::array<int64_t, 1> arr{0};
-        IntArrayRef axis(arr.data(), 1);
-        at::Tensor bias_scale = at::ones({N}, at::dtype(at::kDouble));
-        at::Tensor bias_zp = at::zeros({N}, at::dtype(at::kLong));
-        for (int i = 0; i < N; ++i) {
-          bias_scale.data_ptr<double>()[i] =
-              pack_ptr.w_scale[i] * input_scale_float;
-        }
-        qbias = quantize_linear_per_channel_cpu(
-            at::dequantize(bias), bias_scale, bias_zp, axis, kQInt32);
-      } else {
-        qbias = bias;
-        TORCH_CHECK(false, "Unsupported quantization scheme.")
-      }
-
-      TORCH_CHECK(qbias.dim() == 1, "bias should be a vector (1D Tensor)");
+      bias = pack_ptr.bias.value();
+      bias = bias.contiguous();
+      TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)");
       TORCH_CHECK(
-          qbias.size(0) == N,
+          bias.size(0) == N,
           "bias should have N elements: " + std::to_string(N));
-      auto bias_contig = qbias.contiguous();
-      bias_ptr =
-          reinterpret_cast<int32_t*>(bias_contig.data_ptr<c10::qint32>());
+      bias_ptr = reinterpret_cast<float*>(bias.data_ptr<float>());
     }
 
     // The resulting matrix here is 2-D, let's view it with the original
@@ -165,16 +142,22 @@ class QLinearInt8 final : public torch::OperatorKernel {
       // 1) Add in row and column offsets to the rows and columns,
       // respectively.
       // 2) Add in the bias term.
-      fbgemm::ReQuantizeOutput<ReluFused> outputProcObj(
-          /*nextop=*/doNothingObj,
-          /*C_multiplier=*/output_multiplier_float.data(),
-          /*C_zero_point=*/output_zero_point_int32,
-          /*Aq_zero_point=*/input_zero_point_int32,
-          /*Bq_zero_point=*/pack_ptr.w_zp.data(),
-          /*row_offsets=*/packA.getRowOffsetBuffer(),
-          /*col_offsets=*/col_offsets.data(),
-          /*bias=*/bias_ptr,
-          /*nCol=*/N);
+      fbgemm::ReQuantizeOutput<
+          ReluFused,
+          fbgemm::QuantizationGranularity::TENSOR,
+          float>
+          outputProcObj(
+              doNothingObj,
+              output_multiplier_float.data(),
+              output_zero_point_int32,
+              input_zero_point_int32,
+              pack_ptr.w_zp.data(),
+              packA.getRowOffsetBuffer(),
+              col_offsets.data(),
+              bias_ptr,
+              N, /* nCol */
+              1 /* groups */,
+              act_times_w_scale.data());
 
       // Do the GEMM
       fbgemm::fbgemmPacked(
@@ -196,17 +179,20 @@ class QLinearInt8 final : public torch::OperatorKernel {
       // 2) Add in the bias term.
       fbgemm::ReQuantizeOutput<
          ReluFused,
-          fbgemm::QuantizationGranularity::OUT_CHANNEL>
+          fbgemm::QuantizationGranularity::OUT_CHANNEL,
+          float>
          outputProcObj(
-              /*nextop=*/doNothingObj,
-              /*C_multiplier=*/output_multiplier_float.data(),
-              /*C_zero_point=*/output_zero_point_int32,
-              /*Aq_zero_point=*/input_zero_point_int32,
-              /*Bq_zero_point=*/pack_ptr.w_zp.data(),
-              /*row_offsets=*/packA.getRowOffsetBuffer(),
-              /*col_offsets=*/col_offsets.data(),
-              /*bias=*/bias_ptr,
-              /*nCol=*/N);
+              doNothingObj,
+              output_multiplier_float.data(),
+              output_zero_point_int32,
+              input_zero_point_int32,
+              pack_ptr.w_zp.data(),
+              packA.getRowOffsetBuffer(),
+              col_offsets.data(),
+              bias_ptr,
+              N, /*nCol=*/
+              1, /* groups*/
+              act_times_w_scale.data());
 
       // Do the GEMM
       fbgemm::fbgemmPacked(
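
A small illustrative sketch (invented values) of how the scale vectors handed to `ReQuantizeOutput` differ between the two granularities used above: per-tensor quantization carries a single `act_times_w_scale`/multiplier, per-channel carries one per output channel.
```python
# Illustrative only: quantization granularity and the vectors passed to
# ReQuantizeOutput (values invented; N = number of output channels).
input_scale, output_scale = 0.02, 0.5

def multipliers(w_scales):
    act_times_w = [input_scale * s for s in w_scales]
    return act_times_w, [s / output_scale for s in act_times_w]

# kPerTensorAffine (QuantizationGranularity::TENSOR): length-1 vectors.
print(multipliers([0.1]))
# kPerChannelAffine (QuantizationGranularity::OUT_CHANNEL): length-N vectors.
print(multipliers([0.1, 0.2, 0.05]))
```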
12 changes: 4 additions & 8 deletions test/test_quantized.py
@@ -874,7 +874,8 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias,
         Y_zp = 5
 
         # Weight prepacking operator for quantized Linear
-        W_prepack = qlinear_prepack(W_q, b_q)
+        float_bias = b if use_bias else None
+        W_prepack = qlinear_prepack(W_q, float_bias)
 
         if use_multi_dim_input:
             X_q = X_q.view(3, int(batch_size / 3), input_channels)
@@ -1104,16 +1105,11 @@ def test_qconv(
                 W_zero_points_tensor.to(dtype=torch.long),
                 [0],
                 dtype=torch.qint8)
-            b_q = torch.quantize_linear_per_channel(b,
-                                                    X_scale * W_scales_tensor.to(dtype=torch.double),
-                                                    torch.zeros(output_channels, dtype=torch.long),
-                                                    [0],
-                                                    dtype=torch.qint32) if use_bias else None
         else:
             W_q = torch.quantize_linear(W_KRSC, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)
-            b_q = torch.quantize_linear(b, scale=X_scale * W_scale[0], zero_point=0, dtype=torch.qint32) if use_bias else None
 
-        W_prepack = qconv_prepack(W_q, b_q, stride, pad, dilation, groups)
+        bias_float = b if use_bias else None
+        W_prepack = qconv_prepack(W_q, bias_float, stride, pad, dilation, groups)
 
         Y_q = qconv(
             X_q,
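
A hypothetical usage sketch mirroring the test change above. The binding of the test's `qconv_prepack` alias to `torch.ops.quantized.conv_prepack`, the KRSC weight layout, and all shapes and scales are assumptions for illustration, using the `torch.quantize_linear` API of PyTorch at the time of this commit.
```python
import torch

# Sketch only: alias binding, layout, shapes, and scales are assumed.
qconv_prepack = torch.ops.quantized.conv_prepack

output_channels, kernel_h, kernel_w, input_channels = 8, 3, 3, 4
W = torch.randn(output_channels, kernel_h, kernel_w, input_channels)  # assumed KRSC layout
b = torch.randn(output_channels)

W_q = torch.quantize_linear(W, scale=0.1, zero_point=0, dtype=torch.qint8)
stride, pad, dilation, groups = [1, 1], [0, 0], [1, 1], 1

# Bias is passed as a plain float tensor (or None); no qint32 pre-quantization.
W_prepack = qconv_prepack(W_q, b, stride, pad, dilation, groups)
```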