This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MXNET-290] MKLDNN support for model quantization (#10433)
* mkldnn support for quantization * fix output number in graph * update licsence * modify Jenkinsfile * modify Jenkinsfile * mkldnn has no int8 fc api, excluded_sym_names includes fc for cpu * add mkldnn uint8 pass for quantization graph * update ut * retrig ic * remove no mkldnn quantization test temp * seperate mkldnn quantization ut from gpu quantization ut * rm dev_id check for cpu * add mkl tests dictionary * resolve review comments * simplify DequantizeStorageType() logic * simplify quantize/quantized_conv storage type logic * Add mkldnn_OIhw4i16o4i type case (needed by int8) * INT8 conv/pooling: share with FP32 convolution/pooling class/function * minor indent changes * Remove unnecessary mkldnn_quantized_pooling-inl.h * Fix minor issue * Fix lint * delete duplicated data type * fix bugs and convert requantize data to NDArray * fix lint * fix requantize storgetype * fix requantize storge type * Fix coding style comments * Fix compile issue * Change to use quantized_dtype option to support uint8/int8 scenarios * fix gpu test quantization failure * Fix indent * fix quantized pooling param parser * Fix imagenet_gen_qsym.py option style * retrigger jenkins * retrigger again * trigger jenkins * Resolve further comments * share test code * remove unnecessary test code * add test_quantize_model for cpu * add comments in quantize_graph_pass.cc * jenkins * jenkins * improve coding style * improve coding style * Add naive CPU quantization test back and share quantization code between naive-CPU/MKLDNN/GPU * rename test_quantization_cpu.py to test_quantization_mkldnn.py * code style * trigger * Adjust variable naming for test quantization * add qdtype for quantized op test case to test/bypass all cases explicitly * change expressions to be consistent * revert unnecessary change
- Loading branch information
1 parent
eb95d7b
commit d79e1ad
Showing
27 changed files
with
1,185 additions
and
326 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file mkldnn_convolution-inl.h | ||
* \brief | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_ | ||
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_ | ||
|
||
#if MXNET_USE_MKLDNN == 1 | ||
|
||
#include <utility> | ||
#include "../convolution-inl.h" | ||
#include "./mkldnn_ops-inl.h" | ||
#include "./mkldnn_base-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
mkldnn::convolution_forward::primitive_desc GetConvFwdImpl( | ||
const ConvolutionParam& param, const bool is_train, const NDArray &data, | ||
const NDArray &weights, const NDArray *bias, const NDArray &output); | ||
|
||
class MKLDNNConvForward { | ||
public: | ||
mkldnn::convolution_forward::primitive_desc fwd_pd; | ||
|
||
MKLDNNConvForward(const ConvolutionParam& param, const bool is_train, | ||
const NDArray &data, const NDArray &weights, | ||
const NDArray *bias, const NDArray &output): fwd_pd( | ||
GetConvFwdImpl(param, is_train, data, weights, bias, output)) { | ||
} | ||
|
||
void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight, | ||
const mkldnn::memory *bias, const mkldnn::memory &output); | ||
|
||
const mkldnn::convolution_forward &GetFwd() const { | ||
return *fwd_; | ||
} | ||
|
||
private: | ||
std::shared_ptr<mkldnn::convolution_forward> fwd_; | ||
std::shared_ptr<mkldnn::memory> data_; | ||
std::shared_ptr<mkldnn::memory> weight_; | ||
std::shared_ptr<mkldnn::memory> bias_; | ||
std::shared_ptr<mkldnn::memory> out_; | ||
}; | ||
|
||
typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature; | ||
|
||
MKLDNNConvForward &GetConvFwd(const nnvm::NodeAttrs& attrs, | ||
const bool is_train, const NDArray &data, const NDArray &weights, | ||
const NDArray *bias, const NDArray &output); | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_USE_MKLDNN == 1 | ||
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONVOLUTION_INL_H_ |
Oops, something went wrong.