MKLDNN Backward op cache #11301
Merged

Commits (27 total; the diff below shows the changes from 1 commit):
7c9e8e6  Merge pull request #1 from apache/master  (ZhennanQin)
a35bc9a  Merge remote-tracking branch 'upstream/master'  (ZhennanQin)
ad90147  Merge remote-tracking branch 'upstream/master'  (ZhennanQin)
f3f09b7  Enable primitive allocation cache for _backward_LRN.  (ZhennanQin)
d6dc8a8  Enable primitive allocation cache for _backward_Pooling.  (ZhennanQin)
9e107d2  Enable primitive allocation cache for _backward_Activation.  (ZhennanQin)
b2b71e1  Enable primitive allocation cache for _backward_Deconvolution.  (ZhennanQin)
a58ad33  Enable primitive allocation cache for _backward_BatchNorm.  (ZhennanQin)
97e0d34  Enable primitive allocation cache for _backward_Convolution
f7b9d30  Enable primitive allocation cache for _backward_Fully_Connected
09ab93a  Merge branch 'master' into backward_op_cache  (ZhennanQin)
e9f6a33  remove fc forward and fix indent problem  (huangzhiyuan)
2f3f436  remove fc forward and fix convolution indent problem  (huangzhiyuan)
315abb8  Change log level to FATAL for unreachable code in mkldnn_act.cc  (ZhennanQin)
21b1a68  remove fc forward and fix convolution indent problem  (huangzhiyuan)
dea6f91  remove useless hint in fc  (huangzhiyuan)
dee9bd6  Merge branch 'master' into backward_op_cache  (ZhennanQin)
dd07d9f  Merge branch 'master' into backward_op_cache  (ZhennanQin)
f160c11  Merge branch 'master' into backward_op_cache  (huangzhiyuan)
89bafa8  Merge branch 'master' into backward_op_cache  (ZhennanQin)
913a143  Empty commit to retrigger the CI.  (ZhennanQin)
75039e1  Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc  (ZhennanQin)
c8e976f  Merge branch 'master' into backward_op_cache  (ZhennanQin)
d92915b  Fix build issue after code merge.  (ZhennanQin)
e0805c8  Merge remote-tracking branch 'upstream/master' into backward_op_cache  (ZhennanQin)
ae4a749  Fix lint after merge  (ZhennanQin)
c34c603  Fix mkldnn act.  (ZhennanQin)
Commit 97e0d34da384c3e245dd97c68e23696a59c0aa70
Enable primitive allocation cache for _backward_Convolution
Change-Id: I0496fa2394ee036d05c58f3abc1d74af544c7bca
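
The diff below caches MKL-DNN backward-convolution primitives in a thread-local map keyed by an operator signature: a primitive is constructed once per thread, and later calls only rebind its data handles. The snippet that follows is a minimal, self-contained sketch of that create-once, rebind-on-reuse pattern; every name in it (BackwardPrimitive, OpKey, GetCachedBackward) is an illustrative stand-in rather than the MXNet or MKL-DNN API.

// Sketch of the caching pattern applied in this PR (stand-in types, not MKL-DNN).
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Stand-in for an expensive-to-create backward primitive.
struct BackwardPrimitive {
  void *data_handle = nullptr;               // rebound on every call
  void SetDataHandle(void *p) { data_handle = p; }
  void Execute() { /* run the kernel against data_handle */ }
};

// The cache key holds everything that determines the primitive
// (shapes and dtypes) but never the buffer addresses.
struct OpKey {
  std::vector<int64_t> shape;
  int dtype;
  bool operator==(const OpKey &o) const {
    return shape == o.shape && dtype == o.dtype;
  }
};

struct OpKeyHash {
  size_t operator()(const OpKey &k) const {
    size_t h = std::hash<int>()(k.dtype);
    for (int64_t d : k.shape) h = h * 31 + std::hash<int64_t>()(d);
    return h;
  }
};

// Create the primitive only on a cache miss; afterwards reuse it.
BackwardPrimitive &GetCachedBackward(const OpKey &key) {
  static thread_local std::unordered_map<OpKey, BackwardPrimitive, OpKeyHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, BackwardPrimitive()).first;
  return it->second;
}

int main() {
  float buf[16] = {0};
  OpKey key{{4, 4}, /*dtype=*/0};
  BackwardPrimitive &bwd = GetCachedBackward(key);  // created once per thread
  bwd.SetDataHandle(buf);                           // rebound every iteration
  bwd.Execute();
  return 0;
}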
@@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   MKLDNNStream::Get()->Submit();
 }
 
+class MKLDNNConvBackward {
+  std::shared_ptr<mkldnn::convolution_backward_data> bwd_data;
+  std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight;
+  // conv::kData
+  std::shared_ptr<mkldnn::memory> out_grad;
+  std::shared_ptr<mkldnn::memory> in_grad;
+  std::shared_ptr<mkldnn::memory> weight;
+  // conv::kWeight
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> output;
+  std::shared_ptr<mkldnn::memory> in_grad_weight;
+  std::shared_ptr<mkldnn::memory> in_grad_bias;
+
+ public:
+  mkldnn::convolution_backward_data::primitive_desc bwdData_pd;
+  mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd;
+
+  MKLDNNConvBackward(
+      const ConvolutionParam &param, const NDArray &data,
+      const NDArray &weights, const NDArray *bias, const NDArray &output,
+      const mkldnn::convolution_forward::primitive_desc &fwd_pd):
+      bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)),
+      bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) {
+  }
+
+  void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight,
+                     const mkldnn::memory &in_grad) {
+    if (this->out_grad == nullptr)
+      this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->out_grad->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad == nullptr)
+      this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
+    else
+      this->in_grad->set_data_handle(in_grad.get_data_handle());
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+    if (this->bwd_data == nullptr)
+      this->bwd_data = std::shared_ptr<mkldnn::convolution_backward_data>(
+          new mkldnn::convolution_backward_data(
+              this->bwdData_pd, mkldnn::primitive::at(*this->out_grad),
+              mkldnn::primitive::at(*this->weight), *this->in_grad));
+  }
+
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight));
+  }
+
+  void SetWeightNewMem(const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight,
+                       const mkldnn::memory &in_grad_bias) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+    if (this->output == nullptr)
+      this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->output->set_data_handle(out_grad.get_data_handle());
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (this->in_grad_bias == nullptr)
+      this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(),
+                             in_grad_bias.get_data_handle()));
+    else
+      this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());
+    if (this->bwd_weight == nullptr)
+      this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
+          new mkldnn::convolution_backward_weights(
+              this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->output), *this->in_grad_weight,
+              *this->in_grad_bias));
+  }
+
+  const mkldnn::convolution_backward_data &GetBwdData() const {
+    return *bwd_data;
+  }
+
+  const mkldnn::convolution_backward_weights &GetBwdWeights() const {
+    return *bwd_weight;
+  }
+};
+
+static inline MKLDNNConvBackward &GetConvBwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights,
+    const NDArray *bias, const NDArray &output,
+    const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
+#endif
+  const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
+  MKLDNNConvSignature key(param);
+  // Here we can sign the conv op with NDArray because the conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = bwds.find(key);
+  if (it == bwds.end()) {
+    MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
+    auto ins_ret = bwds.insert(
+        std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
@@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]);
 
   CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
-  mkldnn::convolution_backward_data::primitive_desc bwdData_pd
-    = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-                     inputs[conv::kOut], fwd_pd);
+  MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
+      inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
   auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
-      bwdData_pd.diff_dst_primitive_desc());
+      convBwd.bwdData_pd.diff_dst_primitive_desc());
   if (req[conv::kData]) {
     auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
-        bwdData_pd.weights_primitive_desc(), param.num_group);
+        convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
     auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
-        bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
-        *out_grad_mem, *weight_mem, *in_grad_mem.second));
+        convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
+    convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+    MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
     CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight]) {
-    mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
-      = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
-                          param.no_bias ? nullptr : &inputs[conv::kBias + 1],
-                          inputs[conv::kOut], fwd_pd);
-    if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc())
+    MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
+        inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
+        inputs[conv::kOut], fwd_pd);
+    if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
+        convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
       out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
-          bwdWeights_pd.diff_dst_primitive_desc());
+          convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
     auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
-        bwdWeights_pd.src_primitive_desc());
-    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight],
-        bwdWeights_pd.diff_weights_primitive_desc(),
-        req[conv::kWeight]);
+        convBwdWeight.bwdWeights_pd.src_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(
+        in_grad[conv::kWeight],
+        convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
+        req[conv::kWeight]);
     mkldnn_output_t in_grad_bias;
     if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                                    *in_grad_weight.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
     } else {
-      in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
-          bwdWeights_pd.diff_bias_primitive_desc(),
-          req[conv::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-          *in_grad_bias.second));
+      in_grad_bias = CreateMKLDNNMem(
+          in_grad[conv::kBias],
+          convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
+      convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+                                    *in_grad_weight.second, *in_grad_bias.second);
+      MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
       CommitOutput(in_grad[conv::kBias], in_grad_bias);
     }
     CommitOutput(in_grad[conv::kWeight], in_grad_weight);
@@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
 
 }  // namespace op
 }  // namespace mxnet
 
 #endif  // MXNET_USE_MKLDNN == 1
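
Why rebinding is enough: SetDataNewMem and SetWeightNewMem create the mkldnn::memory objects the cached primitive is built against on the first call, and on later calls they only update those objects' data handles, so the next execution of the same primitive reads from and writes to the new buffers. Below is a self-contained sketch of that mechanism; Memory and BackwardData are illustrative stand-ins, not the MKL-DNN classes.

// Stand-in illustration of handle rebinding on a cached backward primitive.
#include <cassert>
#include <memory>

struct Memory {                        // stand-in for mkldnn::memory
  void *handle = nullptr;
  void set_data_handle(void *p) { handle = p; }
};

struct BackwardData {                  // stand-in for a cached backward primitive
  std::shared_ptr<Memory> out_grad, weight, in_grad;
  void Execute() {
    // Reads out_grad->handle and weight->handle, writes in_grad->handle.
    assert(out_grad->handle && weight->handle && in_grad->handle);
  }
};

int main() {
  BackwardData bwd{std::make_shared<Memory>(), std::make_shared<Memory>(),
                   std::make_shared<Memory>()};
  float og[4], w[4], ig_a[4], ig_b[4];
  // Iteration 1: bind this iteration's buffers, then run the cached primitive.
  bwd.out_grad->set_data_handle(og);
  bwd.weight->set_data_handle(w);
  bwd.in_grad->set_data_handle(ig_a);
  bwd.Execute();
  // Iteration 2: only the gradient buffer moved; rebind that one handle and reuse.
  bwd.in_grad->set_data_handle(ig_b);
  bwd.Execute();
  return 0;
}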
Review comment (on the SetWeightNewMem block above): indent