This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MKLDNN Backward op cache #11301

Merged: 27 commits, Sep 13, 2018
Changes from 1 commit
Commits (27)
7c9e8e6  Merge pull request #1 from apache/master (ZhennanQin, Jun 21, 2018)
a35bc9a  Merge remote-tracking branch 'upstream/master' (ZhennanQin, Jun 23, 2018)
ad90147  Merge remote-tracking branch 'upstream/master' (ZhennanQin, Jun 29, 2018)
f3f09b7  Enable primitive allocation cache for _backward_LRN. (ZhennanQin, Jun 13, 2018)
d6dc8a8  Enable primitive allocation cache for _backward_Pooling. (ZhennanQin, Jun 13, 2018)
9e107d2  Enable primitive allocation cache for _backward_Activation. (ZhennanQin, Jun 12, 2018)
b2b71e1  Enable primitive allocation cache for _backward_Deconvolution. (ZhennanQin, Jun 13, 2018)
a58ad33  Enable primitive allocation cache for _backward_BatchNorm. (ZhennanQin, Jun 13, 2018)
97e0d34  Enable primitive allocation cache for _backward_Convolution (Jun 13, 2018)
f7b9d30  Enable primitive allocation cache for _backward_Fully_Connected (Jun 13, 2018)
09ab93a  Merge branch 'master' into backward_op_cache (ZhennanQin, Jul 2, 2018)
e9f6a33  remove fc forward and fix indent problem (huangzhiyuan, Jul 9, 2018)
2f3f436  remove fc forward and fix convolution indent problem (huangzhiyuan, Jul 9, 2018)
315abb8  Change log level to FATAL for unreachable code in mkldnn_act.cc (ZhennanQin, Jul 9, 2018)
21b1a68  remove fc forward and fix convolution indent problem (huangzhiyuan, Jul 11, 2018)
dea6f91  remove useless hint in fc (huangzhiyuan, Jul 11, 2018)
dee9bd6  Merge branch 'master' into backward_op_cache (ZhennanQin, Jul 12, 2018)
dd07d9f  Merge branch 'master' into backward_op_cache (ZhennanQin, Jul 13, 2018)
f160c11  Merge branch 'master' into backward_op_cache (huangzhiyuan, Jul 13, 2018)
89bafa8  Merge branch 'master' into backward_op_cache (ZhennanQin, Jul 16, 2018)
913a143  Empty commit to retrigger the CI. (ZhennanQin, Jul 16, 2018)
75039e1  Change LOG(INFO) to LOG(FATAL) for unreachable code in mkldnn_act.cc (ZhennanQin, Jul 16, 2018)
c8e976f  Merge branch 'master' into backward_op_cache (ZhennanQin, Jul 25, 2018)
d92915b  Fix build issue after code merge. (ZhennanQin, Jul 25, 2018)
e0805c8  Merge remote-tracking branch 'upstream/master' into backward_op_cache (ZhennanQin, Aug 27, 2018)
ae4a749  Fix lint after merge (ZhennanQin, Aug 27, 2018)
c34c603  Fix mkldnn act. (ZhennanQin, Aug 31, 2018)
Enable primitive allocation cache for _backward_Convolution
Change-Id: I0496fa2394ee036d05c58f3abc1d74af544c7bca
Huang, Zhiyuan authored and ZhennanQin committed Jun 29, 2018
commit 97e0d34da384c3e245dd97c68e23696a59c0aa70
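
Note on the approach: GetConvBwd() in the diff below memoizes backward primitives in a thread-local hash map keyed by the convolution parameters and the input shapes/dtypes, so the expensive primitive construction happens once per unique signature and is reused on later iterations. The following is only a minimal sketch of that lookup pattern; Signature, BwdPrimitive and GetCachedBwd are hypothetical stand-ins for MKLDNNConvSignature, MKLDNNConvBackward and GetConvBwd, not real MXNet types.

// Sketch only: a thread-local, signature-keyed primitive cache of the kind
// GetConvBwd() implements in the diff below. All names here are illustrative.
#include <string>
#include <unordered_map>

struct BwdPrimitive {          // stands in for MKLDNNConvBackward
  int placeholder = 0;         // real code holds primitives and cached memories
};

using Signature = std::string;  // real code: MKLDNNConvSignature (param + shapes + dtypes)

static BwdPrimitive &GetCachedBwd(const Signature &key) {
  // One cache per thread, mirroring the DMLC_CXX11_THREAD_LOCAL branch in the diff.
  static thread_local std::unordered_map<Signature, BwdPrimitive> bwds;
  auto it = bwds.find(key);
  if (it == bwds.end()) {
    BwdPrimitive bwd;                    // expensive construction happens once
    it = bwds.emplace(key, bwd).first;   // reused by every later call with this key
  }
  return it->second;
}

Because the storage is thread-local, each worker thread keeps its own cache and no locking is needed, which matches the DMLC_CXX11_THREAD_LOCAL / MX_THREAD_LOCAL branches in the diff.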
205 changes: 178 additions & 27 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -283,6 +283,157 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
MKLDNNStream::Get()->Submit();
}

class MKLDNNConvBackward {
std::shared_ptr<mkldnn::convolution_backward_data> bwd_data;
std::shared_ptr<mkldnn::convolution_backward_weights> bwd_weight;
// conv::kData
std::shared_ptr<mkldnn::memory> out_grad;
std::shared_ptr<mkldnn::memory> in_grad;
std::shared_ptr<mkldnn::memory> weight;
// conv::kWeight
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> output;
std::shared_ptr<mkldnn::memory> in_grad_weight;
std::shared_ptr<mkldnn::memory> in_grad_bias;

public:
mkldnn::convolution_backward_data::primitive_desc bwdData_pd;
mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd;

MKLDNNConvBackward(
const ConvolutionParam &param, const NDArray &data,
const NDArray &weights, const NDArray *bias, const NDArray &output,
const mkldnn::convolution_forward::primitive_desc &fwd_pd):
bwdData_pd(GetConvBwdData(param, data, weights, output, fwd_pd)),
bwdWeights_pd(GetConvBwdWeights(param, data, weights, bias, output, fwd_pd)) {
}

void SetDataNewMem(const mkldnn::memory &out_grad, const mkldnn::memory &weight,
const mkldnn::memory &in_grad) {
Contributor review comment: indent

if (this->out_grad == nullptr)
this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
else
this->out_grad->set_data_handle(out_grad.get_data_handle());
if (this->in_grad == nullptr)
this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
else
this->in_grad->set_data_handle(in_grad.get_data_handle());
if (this->weight == nullptr)
this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
else
this->weight->set_data_handle(weight.get_data_handle());
if (this->bwd_data == nullptr)
this->bwd_data = std::shared_ptr<mkldnn::convolution_backward_data>(
new mkldnn::convolution_backward_data(
this->bwdData_pd, mkldnn::primitive::at(*this->out_grad),
mkldnn::primitive::at(*this->weight), *this->in_grad));
}

void SetWeightNewMem(const mkldnn::memory &data,
const mkldnn::memory &out_grad,
const mkldnn::memory &in_grad_weight) {
Contributor review comment: indent

if (this->data == nullptr)
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
else
this->data->set_data_handle(data.get_data_handle());
if (this->output == nullptr)
this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
else
this->output->set_data_handle(out_grad.get_data_handle());
if (this->in_grad_weight == nullptr)
this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
in_grad_weight.get_data_handle()));
else
this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());

if (this->bwd_weight == nullptr)
this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
new mkldnn::convolution_backward_weights(
this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->output), *this->in_grad_weight));
}

void SetWeightNewMem(const mkldnn::memory &data,
const mkldnn::memory &out_grad,
const mkldnn::memory &in_grad_weight,
const mkldnn::memory &in_grad_bias) {
if (this->data == nullptr)
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
else
this->data->set_data_handle(data.get_data_handle());
if (this->output == nullptr)
this->output = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
bwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
else
this->output->set_data_handle(out_grad.get_data_handle());
if (this->in_grad_weight == nullptr)
this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(bwdWeights_pd.diff_weights_primitive_desc(),
in_grad_weight.get_data_handle()));
else
this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());

if (this->in_grad_bias == nullptr)
this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(bwdWeights_pd.diff_bias_primitive_desc(),
in_grad_bias.get_data_handle()));
else
this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());
if (this->bwd_weight == nullptr)
this->bwd_weight = std::shared_ptr<mkldnn::convolution_backward_weights>(
new mkldnn::convolution_backward_weights(
this->bwdWeights_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->output), *this->in_grad_weight,
*this->in_grad_bias));
}

const mkldnn::convolution_backward_data &GetBwdData() const {
return *bwd_data;
}

const mkldnn::convolution_backward_weights &GetBwdWeights() const {
return *bwd_weight;
}
};

static inline MKLDNNConvBackward &GetConvBwd(
const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weights,
const NDArray *bias, const NDArray &output,
const mkldnn::convolution_forward::primitive_desc &fwd_pd) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNConvSignature, MKLDNNConvBackward, OpHash> bwds;
#endif
const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
MKLDNNConvSignature key(param);
// Here we can sign the conv op with NDArray because the conv primitive will
// decide the right layout for them, so we only need to get the shape and the
// data type of the arrays.
key.AddSign(data);
key.AddSign(weights);
key.AddSign(output);
if (bias)
key.AddSign(*bias);

auto it = bwds.find(key);
if (it == bwds.end()) {
MKLDNNConvBackward bwd(param, data, weights, bias, output, fwd_pd);
auto ins_ret = bwds.insert(
std::pair<MKLDNNConvSignature, MKLDNNConvBackward>(key, bwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
@@ -295,44 +446,45 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
param.no_bias ? nullptr : &inputs[conv::kBias + 1], inputs[conv::kOut]);

CHECK_NE(req[conv::kWeight], kWriteInplace) << "cannot write weight inplace";
- mkldnn::convolution_backward_data::primitive_desc bwdData_pd
- = GetConvBwdData(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
- inputs[conv::kOut], fwd_pd);
+ MKLDNNConvBackward &convBwd = GetConvBwd(attrs, inputs[conv::kData + 1],
+ inputs[conv::kWeight + 1], nullptr, inputs[conv::kOut], fwd_pd);
auto out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
- bwdData_pd.diff_dst_primitive_desc());
+ convBwd.bwdData_pd.diff_dst_primitive_desc());
if (req[conv::kData]) {
auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
- bwdData_pd.weights_primitive_desc(), param.num_group);
+ convBwd.bwdData_pd.weights_primitive_desc(), param.num_group);
auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
- bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
- *out_grad_mem, *weight_mem, *in_grad_mem.second));
+ convBwd.bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
+ convBwd.SetDataNewMem(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+ MKLDNNStream::Get()->RegisterPrim(convBwd.GetBwdData());
CommitOutput(in_grad[conv::kData], in_grad_mem);
}
if (req[conv::kWeight]) {
- mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
- = GetConvBwdWeights(param, inputs[conv::kData + 1], inputs[conv::kWeight + 1],
- param.no_bias ? nullptr : &inputs[conv::kBias + 1],
- inputs[conv::kOut], fwd_pd);
- if (bwdData_pd.diff_dst_primitive_desc() != bwdWeights_pd.diff_dst_primitive_desc())
+ MKLDNNConvBackward &convBwdWeight = GetConvBwd(attrs, inputs[conv::kData + 1],
+ inputs[conv::kWeight + 1], param.no_bias ? nullptr : &inputs[conv::kBias + 1],
+ inputs[conv::kOut], fwd_pd);
+ if (convBwdWeight.bwdData_pd.diff_dst_primitive_desc() !=
+ convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc())
out_grad_mem = inputs[conv::kOut].GetMKLDNNDataReorder(
- bwdWeights_pd.diff_dst_primitive_desc());
+ convBwdWeight.bwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
- bwdWeights_pd.src_primitive_desc());
- auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[conv::kWeight],
- bwdWeights_pd.diff_weights_primitive_desc(),
- req[conv::kWeight]);
+ convBwdWeight.bwdWeights_pd.src_primitive_desc());
+ auto in_grad_weight = CreateMKLDNNWeightGrad(
+ in_grad[conv::kWeight],
+ convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
+ req[conv::kWeight]);
mkldnn_output_t in_grad_bias;
if (param.no_bias) {
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
- bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
+ convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+ *in_grad_weight.second);
+ MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
} else {
- in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
- bwdWeights_pd.diff_bias_primitive_desc(),
- req[conv::kBias]);
- MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_weights(
- bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
- *in_grad_bias.second));
+ in_grad_bias = CreateMKLDNNMem(
+ in_grad[conv::kBias],
+ convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
+ convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
+ *in_grad_weight.second, *in_grad_bias.second);
+ MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
CommitOutput(in_grad[conv::kBias], in_grad_bias);
}
CommitOutput(in_grad[conv::kWeight], in_grad_weight);
@@ -342,5 +494,4 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
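
A second piece of the change is how the cached MKLDNNConvBackward reuses its mkldnn::memory objects: SetDataNewMem()/SetWeightNewMem() allocate them only on the first call and afterwards just repoint them at the current buffers with set_data_handle(). Below is a minimal sketch of that allocate-once, swap-handles pattern; CachedBwd and Buffer are illustrative placeholders, not MKL-DNN or MXNet types.

// Sketch only: allocate heavy objects on the first call, then merely swap the
// data handles on later calls, as SetDataNewMem()/SetWeightNewMem() do above.
#include <memory>

class CachedBwd {
 public:
  void SetNewMem(void *out_grad, void *in_grad) {
    if (out_grad_ == nullptr)
      out_grad_ = std::make_shared<Buffer>(out_grad);  // first call: allocate
    else
      out_grad_->set_data_handle(out_grad);            // later calls: repoint only
    if (in_grad_ == nullptr)
      in_grad_ = std::make_shared<Buffer>(in_grad);
    else
      in_grad_->set_data_handle(in_grad);
  }

 private:
  struct Buffer {  // stands in for mkldnn::memory
    explicit Buffer(void *h) : handle(h) {}
    void set_data_handle(void *h) { handle = h; }
    void *handle;
  };
  std::shared_ptr<Buffer> out_grad_, in_grad_;
};

Because only the data handles change between iterations, the primitive and its memory descriptors can stay cached for as long as the input shapes and types (the cache key) stay the same.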