Add ONEDNN quantization backend (pytorch#69820)
Summary:
This PR adds a new quantization backend, ONEDNN, with quantized conv and linear kernels in the same code path as the FBGEMM backend.

The ONEDNN backend is an alternative to the FBGEMM and QNNPACK backends. It takes advantage of features of the latest Intel® CPU products: it supports VNNI on Cascade Lake as well as the AMX instruction set on Sapphire Rapids, which offers 8x the int8 peak TOPS of VNNI.

ONEDNN demonstrates better performance than FBGEMM on the conv kernels of popular CNN models. It also supports more fused ops, such as convolution-add-ReLU, than FBGEMM and QNNPACK.
To use this backend, users only need to set the quantization engine to 'onednn' before any computation; no change to the model is required:
```python
torch.backends.quantized.engine = 'onednn'
```
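
For illustration, below is a minimal eager-mode post-training static quantization sketch that exercises the new backend end to end. It is not part of this PR: the model and shapes are made up, and the `'onednn'` argument to `get_default_qconfig` is assumed to be accepted after this PR's change to `torch/ao/quantization/qconfig.py`.

```python
import torch
from torch.ao.quantization import QuantStub, DeQuantStub, get_default_qconfig, prepare, convert

class TinyConvNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # quantizes the fp32 input
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)
        self.relu = torch.nn.ReLU()
        self.dequant = DeQuantStub()  # dequantizes the int8 output

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

torch.backends.quantized.engine = 'onednn'     # select the new backend
assert 'onednn' in torch.backends.quantized.supported_engines

model = TinyConvNet().eval()
model.qconfig = get_default_qconfig('onednn')  # assumption: 'onednn' qconfig string is available
prepared = prepare(model)
prepared(torch.randn(8, 3, 32, 32))            # calibration pass
quantized = convert(prepared)
out = quantized(torch.randn(8, 3, 32, 32))     # runs the ONEDNN int8 kernels
```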

## Design docs
pytorch#21120 (comment)
pytorch#67177 (comment)

## File changes
**Add ONEDNN to qengine list**
- aten/src/ATen/Context.cpp
- c10/core/QEngine.h
- torch/ao/quantization/qconfig.py
- torch/backends/quantized/__init__.py

**Implement qconv & qlinear for ONEDNN backend**
- aten/src/ATen/native/quantized/cpu/conv_serialization.h
- aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
- aten/src/ATen/native/quantized/cpu/onednn_utils.h
- aten/src/ATen/native/quantized/cpu/qconv.cpp
- aten/src/ATen/native/quantized/cpu/qconv_dynamic.cpp
- aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp
- aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp
- aten/src/ATen/native/quantized/cpu/qlinear.cpp
- aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
- aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
- aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp

**Skip tests that are not supported by ONEDNN**
- test/ao/sparsity/test_kernels.py
- test/quantization/core/test_quantized_module.py
- test/quantization/core/test_quantized_op.py

## Validation results
This PR has passed `test_quantization.py` and `test_mkldnn.py`.
Below are performance data for the int8 2d convolution and linear operators on a Cascade Lake Xeon® platform.
(Note: tested with a single instance on a single core, using the latest oneDNN library.)

**Table 1. Performance comparison of int8 2d convolution operator**
|No.|	Shape|	FBGEMM|	ONEDNN|	Gain|
|-|-|-|-|-|
|1|	IC=128, OC=128, kernel=3, stride=1, N=4, H=32, W=32, G=1, pad=0|	668.310us|	535.630us|	24.8%|
|2|	IC=128, OC=128, kernel=3, stride=2, N=4, H=32, W=32, G=1, pad=0|	290.630us|	281.810us|	3.1%|
|3|	IC=128, OC=256, kernel=3, stride=1, N=4, H=32, W=32, G=1, pad=0|	1.045ms|	893.010us|	17.0%|
|4|	IC=128, OC=256, kernel=3, stride=2, N=4, H=32, W=32, G=1, pad=0|	385.320us|	373.720us|	3.1%|
|5|	IC=256, OC=256, kernel=3, stride=1, N=4, H=32, W=32, G=1, pad=0|	1.876ms|	1.641ms|	14.3%|
|6|	IC=256, OC=256, kernel=3, stride=2, N=4, H=32, W=32, G=1, pad=0|	660.460us|	638.470us|	3.4%|

**Table 2. Performance comparison of int8 linear operator**
|No.|	Shape (m, n, k)|	FBGEMM|	ONEDNN|	Gap|
|-|-|-|-|-|
|1|	64, 800, 320|	80.550us|	96.770us|	20.10%|
|2|	64, 768, 512|	101.230us|	130.720us|	29.10%|
|3|	16, 256, 512|	30.230us|	51.450us|	70.20%|
|4|	128, 128, 128|	33.810us|	50.480us|	49.30%|
|5|	256, 512, 256|	154.490us|	195.050us|	26.30%|
|6|	1024, 1024, 1024|	3.134ms|	3.514ms|	12.10%|

ONEDNN showed advantages over FBGEMM for convolution. However, it has a performance gap relative to FBGEMM for linear ops. The gap is a known issue, and further optimization is in progress in the oneDNN library. On the latest platforms, ONEDNN achieves better performance for both conv and linear.
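
For context, here is a minimal sketch of how a single measurement like row 1 of Table 1 might be reproduced with `torch.utils.benchmark`. It is not part of this PR; the quantization scale/zero point values are arbitrary, and exact timings will differ with the machine and oneDNN version.

```python
import torch
from torch.utils import benchmark

torch.backends.quantized.engine = 'onednn'   # switch to 'fbgemm' to compare backends

# Row 1 of Table 1: IC=128, OC=128, kernel=3, stride=1, N=4, H=W=32, pad=0
conv = torch.nn.quantized.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)
x = torch.quantize_per_tensor(torch.randn(4, 128, 32, 32),
                              scale=0.1, zero_point=64, dtype=torch.quint8)

timer = benchmark.Timer(
    stmt="conv(x)",
    globals={"conv": conv, "x": x},
    num_threads=1,   # single instance on a single core, as in the tables
)
print(timer.blocked_autorange())
```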

Pull Request resolved: pytorch#69820

Reviewed By: HDCharles

Differential Revision: D33716039

Pulled By: jerryzh168

fbshipit-source-id: 6f7bb807e85798142dfcffccfca8b8bd652fb3dd
(cherry picked from commit 91526b3)
Xia-Weiwen authored and pytorchmergebot committed Mar 11, 2022
1 parent 3e556ef commit 989b248
Showing 19 changed files with 1,010 additions and 37 deletions.
4 changes: 4 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -236,6 +236,10 @@ const std::vector<at::QEngine>& Context::supportedQEngines() {
engines.push_back(at::kNoQEngine);
#endif // C10_MOBILE

#if AT_MKLDNN_ENABLED()
engines.push_back(at::kONEDNN);
#endif

#ifdef USE_FBGEMM
if (fbgemm::fbgemmSupportedCPU()) {
engines.push_back(at::kFBGEMM);
15 changes: 15 additions & 0 deletions aten/src/ATen/native/quantized/cpu/conv_serialization.h
@@ -4,6 +4,7 @@
#include <ATen/core/List.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <c10/util/irange.h>

#include <tuple>
@@ -358,6 +359,20 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
);
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (ctx.qEngine() == at::QEngine::ONEDNN) {
return PackedConvWeightsOnednn<kSpatialDim>::prepack(
weight.value(),
bias,
stride,
padding,
output_padding,
dilation,
groups,
transpose
);
}
#endif // AT_MKLDNN_ENABLED()
TORCH_CHECK(
false,
"Didn't find engine for when deserializing ConvPackedParams: ",
11 changes: 11 additions & 0 deletions aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp
@@ -4,6 +4,7 @@
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
#include <ATen/native/quantized/cpu/onednn_utils.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
@@ -470,6 +471,16 @@ int register_linear_params() {
std::move(weight), std::move(bias));
}
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
if (at::globalContext().qEngine() == at::QEngine::ONEDNN) {
TORCH_CHECK(
weight.scalar_type() == at::kQInt8,
"ONEDNN only supports INT8 bit width currently. Got ",
c10::toString(weight.scalar_type()));
return PackedLinearWeightsOnednn::prepack(
std::move(weight), std::move(bias));
}
#endif // #if AT_MKLDNN_ENABLED()
TORCH_CHECK(false, "Unknown qengine");
})
.def("bias", [](const c10::intrusive_ptr<LinearPackedParamsBase>& self) {
154 changes: 154 additions & 0 deletions aten/src/ATen/native/quantized/cpu/onednn_utils.h
@@ -0,0 +1,154 @@
#pragma once

#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/cpu/conv_packed_params.h>
#include <ATen/native/quantized/cpu/packed_params.h>
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
#include <c10/core/QScheme.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/mkldnn/Utils.h>

struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
PackedLinearWeightsOnednn(
std::unique_ptr<ideep::tensor> weight,
c10::optional<ideep::tensor> bias,
at::Tensor orig_weight,
c10::optional<at::Tensor> orig_bias)
: weight_(std::move(weight)),
bias_(std::move(bias)),
orig_weight_(std::move(orig_weight)),
orig_bias_(std::move(orig_bias)) {}
std::unique_ptr<ideep::tensor> weight_;
c10::optional<ideep::tensor> bias_;
at::Tensor orig_weight_;
c10::optional<at::Tensor> orig_bias_;

at::Tensor apply(
at::Tensor input,
double output_scale,
int64_t output_zero_point) override;
at::Tensor apply_relu(
at::Tensor input,
double output_scale,
int64_t output_zero_point) override;

at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

c10::optional<at::Tensor> bias() override {
return orig_bias_;
}

static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias);

private:
template <bool ReluFused>
at::Tensor apply_impl(
at::Tensor input,
double output_scale,
int64_t output_zero_point);

template <bool ReluFused>
at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);
};

template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
PackedConvWeightsOnednn(
std::unique_ptr<ideep::tensor> weight,
c10::optional<ideep::tensor> bias,
at::Tensor orig_weight,
c10::optional<at::Tensor> orig_bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
uint8_t transpose)
: weight_(std::move(weight)),
bias_(std::move(bias)),
orig_weight_(std::move(orig_weight)),
orig_bias_(std::move(orig_bias)),
stride_(std::move(stride)),
padding_(std::move(padding)),
output_padding_(std::move(output_padding)),
dilation_(std::move(dilation)),
groups_(groups),
transpose_(transpose) {}

std::unique_ptr<ideep::tensor> weight_;
c10::optional<ideep::tensor> bias_;
at::Tensor orig_weight_;
c10::optional<at::Tensor> orig_bias_;
torch::List<int64_t> stride_;
torch::List<int64_t> padding_;
torch::List<int64_t> output_padding_;
torch::List<int64_t> dilation_;
int64_t groups_;
uint8_t transpose_;

at::Tensor apply(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) override;

at::Tensor apply_relu(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) override;

at::Tensor apply_dynamic(
const at::Tensor& input,
bool reduce_range) override;

std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);

torch::List<int64_t> stride() const override {
return stride_;
}

torch::List<int64_t> padding() const override {
return padding_;
}

torch::List<int64_t> output_padding() const override {
return output_padding_;
}

torch::List<int64_t> dilation() const override {
return dilation_;
}

int64_t groups() const override {
return groups_;
}

bool transpose() const override {
return (bool)transpose_;
}

private:
template <bool ReluFused>
at::Tensor apply_impl(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point);
};

#endif // #if AT_MKLDNN_ENABLED()
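
As a rough illustration (not part of the diff), the `prepack`/`apply` methods declared above back the existing `quantized::conv2d_prepack` and `quantized::conv2d` ops, so the new packed-params class can be exercised from Python roughly as follows; the weight/input shapes, scales, and zero points are arbitrary example values.

```python
import torch

torch.backends.quantized.engine = 'onednn'

# qint8 weight and fp32 bias for a 16x3x3x3 conv
w = torch.quantize_per_tensor(torch.randn(16, 3, 3, 3), 0.02, 0, torch.qint8)
b = torch.zeros(16)

# Dispatches to PackedConvWeightsOnednn<2>::prepack when the engine is 'onednn'
packed = torch.ops.quantized.conv2d_prepack(w, b, [1, 1], [0, 0], [1, 1], 1)

x = torch.quantize_per_tensor(torch.randn(1, 3, 32, 32), 0.1, 64, torch.quint8)
# Dispatches to PackedConvWeightsOnednn<2>::apply
y = torch.ops.quantized.conv2d(x, packed, 0.25, 64)
```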