This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[mkldnn-v1.0] Add MKL-DNN Pooling #16272
Merged
Merged
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
9d2a810
add mkldnn pooling
rongzha1 a22810a
add workaround for mkldnn v1.0 pooling fwd && bwd workspace mismatch
rongzha1 5a960b0
code clean
rongzha1 bbcbbd6
fix lint error
rongzha1 89f4679
trigger CI
rongzha1 96ab0d8
trigger CI
rongzha1 693d8ef
add extra work_space check and fix some typo
rongzha1 22b7fcd
trigger CI
rongzha1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ | |
* \author Tao Lv | ||
*/ | ||
|
||
#if MXNET_USE_MKLDNN == 1 | ||
#if MXNET_USE_MKLDNN == 100 | ||
|
||
#include "./mkldnn_pooling-inl.h" | ||
|
||
|
@@ -34,18 +34,17 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o | |
const int kernel_h, const int kernel_w, | ||
const int stride_h, const int stride_w, | ||
const int padding_t, const int padding_b, | ||
const int padding_l, const int padding_r) { | ||
// mkldnn::memory::desc | ||
auto src_md = input.GetMKLDNNData()->get_primitive_desc().desc(); | ||
const int padding_l, const int padding_r, | ||
const bool is_train, const mkldnn::algorithm alg_kind) { | ||
auto src_md = input.GetMKLDNNData()->get_desc(); | ||
mkldnn::memory::dims dims = {src_md.data.dims[0], | ||
src_md.data.dims[1], | ||
static_cast<int>(output.shape()[2]), | ||
static_cast<int>(output.shape()[3])}; | ||
auto dst_md = mkldnn::memory::desc({dims}, | ||
static_cast<mkldnn::memory::data_type>(src_md.data.data_type), | ||
static_cast<mkldnn::memory::format>(src_md.data.format)); | ||
mkldnn::memory::format_tag::any); | ||
const mkldnn::engine engine = CpuEngine::Get()->get_engine(); | ||
const mkldnn::algorithm alg_kind = this->alg_kind_; | ||
if (alg_kind != mkldnn::algorithm::pooling_max && | ||
alg_kind != mkldnn::algorithm::pooling_avg && | ||
alg_kind != mkldnn::algorithm::pooling_avg_include_padding && | ||
|
@@ -54,10 +53,10 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o | |
} | ||
|
||
mkldnn::prop_kind prop = mkldnn::prop_kind::forward_scoring; | ||
if (this->is_train_ && alg_kind != mkldnn::algorithm::pooling_avg) { | ||
if (is_train && alg_kind != mkldnn::algorithm::pooling_avg) { | ||
prop = mkldnn::prop_kind::forward_training; | ||
} | ||
if (this->is_train_ && prop == mkldnn::prop_kind::forward_scoring) { | ||
if (is_train && prop == mkldnn::prop_kind::forward_scoring) { | ||
LOG(INFO) << "MKLDNN Pooling: training with prop_kind is forward_scoring"; | ||
} | ||
|
||
|
@@ -67,49 +66,38 @@ void MKLDNNPoolingFwd::Init(const mxnet::NDArray &input, const mxnet::NDArray &o | |
const mkldnn::memory::dims kernel = {kernel_h, kernel_w }; | ||
// mkldnn::pooling_forward::desc | ||
const auto fwd_desc = mkldnn::pooling_forward::desc(prop, alg_kind, src_md, dst_md, | ||
strides, kernel, pad_l, pad_r, | ||
mkldnn::padding_kind::zero); | ||
strides, kernel, pad_l, pad_r); | ||
this->fwd_pd_.reset(new mkldnn::pooling_forward::primitive_desc(fwd_desc, engine)); | ||
this->data_.reset(new mkldnn::memory(input.GetMKLDNNData()->get_primitive_desc())); | ||
this->out_.reset(new mkldnn::memory(this->fwd_pd_->dst_primitive_desc())); | ||
if (this->with_workspace_) { | ||
this->workspace_.reset(new mkldnn::memory(this->fwd_pd_->workspace_primitive_desc())); | ||
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), | ||
mkldnn::primitive::at(*(this->data_)), | ||
*(this->out_), | ||
*(this->workspace_))); | ||
} else { | ||
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_), | ||
mkldnn::primitive::at(*(this->data_)), | ||
*(this->out_))); | ||
} | ||
this->fwd_.reset(new mkldnn::pooling_forward(*(this->fwd_pd_))); | ||
|
||
return; | ||
} | ||
|
||
void MKLDNNPoolingFwd::SetNewMem(const NDArray& in_data, | ||
const NDArray& out_data, | ||
const OpReqType& req, | ||
const mxnet::NDArray *workspace) { | ||
auto input_mem = in_data.GetMKLDNNData(); | ||
output_mem_t_ = CreateMKLDNNMem(out_data, fwd_pd_->dst_primitive_desc(), req); | ||
// mkldnn::memory | ||
this->data_->set_data_handle(input_mem->get_data_handle()); | ||
this->out_->set_data_handle(output_mem_t_.second->get_data_handle()); | ||
if (this->with_workspace_ && workspace == nullptr) { | ||
LOG(FATAL) << "MKLDNN Pooling: incorrect workspace input"; | ||
} | ||
void MKLDNNPoolingFwd::Execute(const NDArray &in_data, | ||
const OpReqType req, | ||
const NDArray& out_data, | ||
const NDArray *workspace) { | ||
NDArray in_buffer = in_data; | ||
if (in_data.IsView() && in_data.IsMKLDNNData()) | ||
in_buffer = in_data.Reorder2Default(); | ||
|
||
auto input_mem = in_buffer.GetMKLDNNData(); | ||
auto output_mem_t_ = CreateMKLDNNMem(out_data, this->fwd_pd_->dst_desc(), req); | ||
|
||
mkldnn_args_map_t args = { | ||
{MKLDNN_ARG_SRC, *input_mem }, | ||
{MKLDNN_ARG_DST, *(output_mem_t_.second) }, | ||
}; | ||
|
||
if (this->with_workspace_) { | ||
// mkldnn::memory | ||
auto ws_mem = workspace->GetMKLDNNData(); | ||
this->workspace_->set_data_handle(ws_mem->get_data_handle()); | ||
auto engine = CpuEngine::Get()->get_engine(); | ||
auto ws = std::make_shared<mkldnn::memory>((*(this->fwd_pd_)).workspace_desc(), | ||
engine, workspace->GetMKLDNNData()->get_data_handle()); | ||
args[MKLDNN_ARG_WORKSPACE] = *ws; | ||
} | ||
} | ||
|
||
void MKLDNNPoolingFwd::Execute(const NDArray& out_data) { | ||
if (this->fwd_) { | ||
MKLDNNStream::Get()->RegisterPrim(*(this->fwd_)); | ||
CommitOutput(out_data, this->output_mem_t_); | ||
MKLDNNStream::Get()->RegisterPrimArgs(*(this->fwd_), args); | ||
CommitOutput(out_data, output_mem_t_); | ||
MKLDNNStream::Get()->Submit(); | ||
} else { | ||
LOG(FATAL) << "MKLDNN Pooling: forward primitive is nullptr"; | ||
|
@@ -143,8 +131,8 @@ static inline int GetPaddingSizeFull(int x, int padl, int padr, int k, int s) { | |
} | ||
|
||
mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( | ||
const PoolingParam ¶m, const bool is_train, const memory::desc &data_md, | ||
const memory::desc &out_md) { | ||
const PoolingParam ¶m, const bool is_train, const mkldnn::memory::desc &data_md, | ||
const mkldnn::memory::desc &out_md) { | ||
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; | ||
int kernel_h_, kernel_w_; | ||
if (param.global_pool) { | ||
|
@@ -183,19 +171,18 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( | |
|
||
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); | ||
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring; | ||
if (is_train && alg != algorithm::pooling_avg) { | ||
if (is_train && alg != mkldnn::algorithm::pooling_avg) { | ||
kind = mkldnn::prop_kind::forward_training; | ||
} | ||
|
||
const pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, | ||
const mkldnn::pooling_forward::desc poolingFwd_desc(kind, alg, data_md, out_md, | ||
{static_cast<int>(stride_h_), | ||
static_cast<int>(stride_w_)}, | ||
{kernel_h_, kernel_w_}, | ||
{static_cast<int>(pad_t_), | ||
static_cast<int>(pad_l_)}, | ||
{static_cast<int>(pad_b_), | ||
static_cast<int>(pad_r_)}, | ||
padding_kind::zero); | ||
static_cast<int>(pad_r_)}); | ||
return mkldnn::pooling_forward::primitive_desc(poolingFwd_desc, engine); | ||
} | ||
|
||
|
@@ -223,7 +210,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, | |
auto it = pooling_fwds.find(key); | ||
if (it == pooling_fwds.end()) { | ||
CHECK_EQ(param.kernel.ndim(), 2) << "Not Implemented"; | ||
auto data_md = data.GetMKLDNNData()->get_primitive_desc().desc(); | ||
auto data_md = data.GetMKLDNNData()->get_desc(); | ||
int kernel_h_, kernel_w_; | ||
if (param.global_pool) { | ||
kernel_h_ = data_md.data.dims[2]; | ||
|
@@ -270,42 +257,14 @@ void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam ¶m, | |
const NDArray &in_data, const OpReqType req, | ||
const NDArray &out_data, const NDArray *workspace) { | ||
auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data); | ||
fwd.SetNewMem(in_data, out_data, req, workspace); | ||
fwd.Execute(out_data); | ||
fwd.Execute(in_data, req, out_data, workspace); | ||
} | ||
|
||
MKLDNNPoolingBwd::MKLDNNPoolingBwd( | ||
const pooling_backward::primitive_desc &pdesc, bool with_ws) | ||
: with_workspace(with_ws), pd(pdesc) {} | ||
|
||
void MKLDNNPoolingBwd::SetNewMem(const mxnet::NDArray *workspace, | ||
const mxnet::NDArray &out_grad, | ||
const mkldnn::memory *diff_src_mem) { | ||
if (bwd == nullptr) { | ||
diff_dst.reset( | ||
new mkldnn::memory(out_grad.GetMKLDNNData()->get_primitive_desc(), | ||
out_grad.GetMKLDNNData()->get_data_handle())); | ||
diff_src.reset(new mkldnn::memory(pd.diff_src_primitive_desc(), | ||
diff_src_mem->get_data_handle())); | ||
if (with_workspace) { | ||
CHECK(workspace != nullptr); | ||
ws.reset( | ||
new mkldnn::memory(workspace->GetMKLDNNData()->get_primitive_desc(), | ||
workspace->GetMKLDNNData()->get_data_handle())); | ||
bwd.reset( | ||
new pooling_backward(pd, *diff_dst, primitive::at(*ws), *diff_src)); | ||
} else { | ||
bwd.reset(new pooling_backward(pd, *diff_dst, *diff_src)); | ||
} | ||
} else { | ||
diff_dst->set_data_handle(out_grad.GetMKLDNNData()->get_data_handle()); | ||
diff_src->set_data_handle(diff_src_mem->get_data_handle()); | ||
if (with_workspace) { | ||
CHECK(workspace != nullptr); | ||
ws->set_data_handle(workspace->GetMKLDNNData()->get_data_handle()); | ||
const mkldnn::pooling_backward::primitive_desc &pdesc, bool with_ws) | ||
: with_workspace(with_ws), pd(pdesc) { | ||
bwd = std::make_shared<mkldnn::pooling_backward>(pd); | ||
} | ||
} | ||
} | ||
|
||
const mkldnn::pooling_backward &MKLDNNPoolingBwd::GetBwd() { | ||
return *this->bwd; | ||
|
@@ -333,27 +292,31 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, | |
|
||
auto it = pooling_bwds.find(key); | ||
if (it == pooling_bwds.end()) { | ||
auto diff_dst_mem = out_grad.GetMKLDNNData(); | ||
// mkldnn v1.0 adds a reorder to work around testcase: test_make_subgraph; | ||
// already fixed in v1.1; will remove after v1.1 is integrated. | ||
NDArray diff_dst_buff = out_grad; | ||
if (in_data.IsMKLDNNData() == false && diff_dst_buff.IsMKLDNNData() == true) { | ||
diff_dst_buff = out_grad.Reorder2Default(); | ||
} | ||
auto diff_dst_mem = diff_dst_buff.GetMKLDNNData(); | ||
auto input_mem = in_data.GetMKLDNNData(); | ||
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc(); | ||
const mkldnn::memory::desc data_md = data_mpd.desc(); | ||
const memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], | ||
const mkldnn::memory::desc data_md = input_mem->get_desc(); | ||
const mkldnn::memory::dims dims = {data_md.data.dims[0], data_md.data.dims[1], | ||
static_cast<int>(out_grad.shape()[2]), | ||
static_cast<int>(out_grad.shape()[3])}; | ||
const memory::desc out_md( | ||
{dims}, static_cast<memory::data_type>(data_md.data.data_type), | ||
static_cast<memory::format>(data_md.data.format)); | ||
const mkldnn::memory::desc out_md( | ||
{dims}, static_cast<mkldnn::memory::data_type>(data_md.data.data_type), | ||
mkldnn::memory::format_tag::any); | ||
auto fwd_pd = GetPoolingFwdPdesc(param, true, data_md, out_md); | ||
|
||
const mkldnn::memory::desc diff_md = | ||
diff_dst_mem->get_primitive_desc().desc(); | ||
const memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], | ||
diff_dst_mem->get_desc(); | ||
const mkldnn::memory::dims dims1 = {diff_md.data.dims[0], diff_md.data.dims[1], | ||
static_cast<int>(in_grad.shape()[2]), | ||
static_cast<int>(in_grad.shape()[3])}; | ||
const memory::desc diff_in_md( | ||
{dims1}, static_cast<memory::data_type>(diff_md.data.data_type), | ||
static_cast<memory::format>(diff_md.data.format)); | ||
const mkldnn::engine cpu_engine = data_mpd.get_engine(); | ||
const mkldnn::memory::desc diff_in_md( | ||
{dims1}, static_cast<mkldnn::memory::data_type>(diff_md.data.data_type), | ||
mkldnn::memory::format_tag::any); | ||
const mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();; | ||
const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); | ||
|
||
int kernel_h_, kernel_w_; | ||
|
@@ -379,11 +342,10 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, | |
stride_h_ = stride_w_ = 1; | ||
} | ||
|
||
const pooling_backward::desc desc( | ||
const mkldnn::pooling_backward::desc desc( | ||
alg, diff_in_md, diff_md, {stride_h_, stride_w_}, | ||
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}, | ||
mkldnn::padding_kind::zero); | ||
const auto pdesc = pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); | ||
{kernel_h_, kernel_w_}, {pad_t_, pad_l_}, {pad_b_, pad_r_}); | ||
const auto pdesc = mkldnn::pooling_backward::primitive_desc(desc, cpu_engine, fwd_pd); | ||
MKLDNNPoolingBwd bwd(pdesc, with_workspace); | ||
it = AddToCache(&pooling_bwds, key, bwd); | ||
} | ||
|
@@ -401,14 +363,21 @@ void MKLDNNPoolingGradCompute(const OpContext &ctx, const PoolingParam ¶m, | |
|
||
auto &bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); | ||
auto diff_src_mem = | ||
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req); | ||
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req); | ||
|
||
mkldnn_args_map_t args = { | ||
{MKLDNN_ARG_DIFF_DST, *(out_grad.GetMKLDNNData())}, | ||
{MKLDNN_ARG_DIFF_SRC, *diff_src_mem.second }, | ||
}; | ||
if (workspace != nullptr) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also check There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK. Done |
||
args[MKLDNN_ARG_WORKSPACE] = *(workspace->GetMKLDNNData()); | ||
} | ||
|
||
bwd.SetNewMem(workspace, out_grad, diff_src_mem.second); | ||
MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd()); | ||
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), args); | ||
CommitOutput(in_grad, diff_src_mem); | ||
MKLDNNStream::Get()->Submit(); | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
#endif // MXNET_USE_MKLDNN == 1 | ||
#endif  // MXNET_USE_MKLDNN == 100 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. typo |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove the comments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done