Commit 44fa823

Merge pull request #9949 from mozga-intel/mozga-intel/Mul_mkldnn
Initial implementation of multiplication operator for MKLDNN
2 parents: d67b9ce + 171471e

7 files changed: +347 −59 lines changed
paddle/fluid/operators/mul_mkldnn_op.cc

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "mkldnn.hpp"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/mul_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;

// Shorthand for building an f32 MKLDNN memory descriptor in a given format.
template <typename Format = mkldnn::memory::format>
mkldnn::memory::desc type(const std::vector<int>& dims, Format&& f) {
  return platform::MKLDNNMemDesc(dims, mkldnn::memory::data_type::f32, f);
}

template <typename T>
class MulMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    auto mkldnn_engine = dev_ctx.GetEngine();

    auto input = ctx.Input<Tensor>("X");
    auto weight = ctx.Input<Tensor>("Y");

    PADDLE_ENFORCE(input->dims().size() == 2 || input->dims().size() == 4,
                   "Input must be with 2 or 4 dimensions, i.e. NC or NCHW");
    PADDLE_ENFORCE(weight->dims().size() == 2 || weight->dims().size() == 4,
                   "Weights must be with 2 or 4 dimensions, i.e. OI or OIHW");

    std::vector<int> w_tz = paddle::framework::vectorize2int(weight->dims());
    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());

    auto src_md =
        src_tz.size() != 2
            ? type(src_tz, mkldnn::memory::format::nchw)
            : type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);

    auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);

    auto weights_md =
        src_tz.size() != 2
            ? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
                   mkldnn::memory::format::oihw)
            : type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);

    auto output = ctx.Output<Tensor>("Out");
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    const std::string key = ctx.op().Output("Out");
    const std::string key_fc_pd = key + "@mul_pd";

    const T* input_data = input->data<T>();
    const T* w_data = weight->data<T>();

    auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);

    auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
                                     platform::to_void_cast(input_data));

    auto weights_memory = mkldnn::memory({weights_md, mkldnn_engine},
                                         platform::to_void_cast(w_data));

    // Build the inner_product forward primitive descriptor and cache it
    // under the output name so the grad kernel can reuse it.
    auto pd = platform::MKLDNNFwdPrimitiveDesc<mkldnn::inner_product_forward>(
        mkldnn_engine, src_md, weights_md, dst_md);

    dev_ctx.SetBlob(key_fc_pd, pd);

    auto forward = mkldnn::inner_product_forward(*pd, src_memory,
                                                 weights_memory, dst_memory);

    std::vector<mkldnn::primitive> pipeline = {forward};
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
  }
};

template <typename T>
class MulMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                   "It must use CPUPlace.");

    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
    auto mkldnn_engine = dev_ctx.GetEngine();

    const Tensor* input = ctx.Input<Tensor>("X");
    const Tensor* w = ctx.Input<Tensor>("Y");

    const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    Tensor* w_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));

    const std::string key = ctx.op().Input("Out");
    const std::string key_fc_pd = key + "@mul_pd";

    const T* input_data = input->data<T>();
    const T* w_data = w->data<T>();
    const T* out_grad_data = out_grad->data<T>();
    T* input_grad_data = nullptr;
    T* w_grad_data = nullptr;

    if (input_grad) {
      input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    }
    if (w_grad) {
      w_grad_data = w_grad->mutable_data<T>(ctx.GetPlace());
    }

    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
    std::vector<int> w_tz = paddle::framework::vectorize2int(w->dims());

    auto src_md =
        src_tz.size() != 2
            ? type(src_tz, mkldnn::memory::format::nchw)
            : type({src_tz[0], src_tz[1]}, mkldnn::memory::format::nc);

    auto dst_md = type({src_tz[0], w_tz[1]}, mkldnn::memory::format::nc);

    auto weights_md =
        src_tz.size() != 2
            ? type({w_tz[1], src_tz[1], src_tz[2], src_tz[3]},
                   mkldnn::memory::format::oihw)
            : type({w_tz[1], src_tz[1]}, mkldnn::memory::format::oi);

    auto src_memory = mkldnn::memory({src_md, mkldnn_engine},
                                     platform::to_void_cast(input_data));

    auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine},
                                     platform::to_void_cast(out_grad_data));

    auto weight_memory = mkldnn::memory({weights_md, mkldnn_engine},
                                        platform::to_void_cast(w_data));

    // Recover the forward primitive descriptor cached by the forward kernel.
    auto pd =
        std::static_pointer_cast<mkldnn::inner_product_forward::primitive_desc>(
            dev_ctx.GetBlob(key_fc_pd));

    PADDLE_ENFORCE(pd != nullptr, "Fail to find pd in device context");

    if (w_grad) {
      auto weights_grad_memory = mkldnn::memory(
          {weights_md, mkldnn_engine}, platform::to_void_cast(w_grad_data));

      auto bwd_weight_pd = platform::MKLDNNBwdPrimitiveDesc<
          mkldnn::inner_product_backward_weights>(mkldnn_engine, *pd, src_md,
                                                  weights_md, dst_md);

      auto bwd_weights_prim = mkldnn::inner_product_backward_weights(
          bwd_weight_pd, src_memory, dst_memory, weights_grad_memory);

      std::vector<mkldnn::primitive> pipeline{bwd_weights_prim};
      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
    }

    if (input_grad) {
      auto src_grad_memory = mkldnn::memory(
          {src_md, mkldnn_engine}, platform::to_void_cast(input_grad_data));

      auto bwd_data_pd =
          platform::MKLDNNBwdPrimitiveDesc<mkldnn::inner_product_backward_data>(
              mkldnn_engine, *pd, src_md, weights_md, dst_md);

      auto bwd_data_prim = mkldnn::inner_product_backward_data(
          bwd_data_pd, dst_memory, weight_memory, src_grad_memory);

      std::vector<mkldnn::primitive> pipeline{bwd_data_prim};
      mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
    }
  }
};
}  // namespace operators
}  // namespace paddle

REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::MulMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL(mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
                   paddle::operators::MulMKLDNNGradOpKernel<float>);
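
A note on the shape bookkeeping in this kernel: mul computes Out = X * Y with X of shape {N, I} and Y of shape {I, O}, while MKLDNN's inner_product primitive stores weights in {O, I} order. That is why weights_md is built from {w_tz[1], src_tz[1]} rather than from w_tz directly. A minimal standalone sketch of that derivation, using hypothetical sizes (8, 16, 4 are illustrative only, not from this commit):

#include <cassert>
#include <vector>

int main() {
  // Hypothetical shapes: X = {N, I}, Y = {I, O}.
  std::vector<int> src_tz = {8, 16};  // N = 8, I = 16
  std::vector<int> w_tz = {16, 4};    // I = 16, O = 4

  // The kernel derives MKLDNN descriptors in {O, I} and {N, O} order:
  std::vector<int> weights_tz = {w_tz[1], src_tz[1]};  // {O, I} = {4, 16}
  std::vector<int> dst_tz = {src_tz[0], w_tz[1]};      // {N, O} = {8, 4}

  assert((weights_tz == std::vector<int>{4, 16}));
  assert((dst_tz == std::vector<int>{8, 4}));
  return 0;
}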

paddle/fluid/operators/mul_op.cc

Lines changed: 40 additions & 0 deletions
@@ -13,8 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/mul_op.h"
+#include <string>
 #include <vector>

+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {

@@ -71,6 +76,22 @@ class MulOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ private:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    framework::DataLayout layout{framework::DataLayout::kAnyLayout};
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout, library);
+  }
 };

 class MulOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -100,6 +121,9 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
 )DOC")
         .SetDefault(1)
         .EqualGreaterThan(1);
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddAttr<int>(
         "y_num_col_dims",
         R"DOC((int, default 1), The mul_op can take tensors with more than two,
@@ -154,6 +178,22 @@ class MulGradOp : public framework::OperatorWithKernel {
       ctx->SetOutputDim(y_grad_name, y_dims);
     }
   }
+
+ private:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_MKLDNN
+    if (library == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kMKLDNN;
+    }
+#endif
+    framework::DataLayout layout{framework::DataLayout::kAnyLayout};
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        layout, library);
+  }
 };

 }  // namespace operators
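
Both GetExpectedKernelType overrides lean on platform::CanMKLDNNBeUsed, whose body is not part of this diff. As a rough sketch, assumed rather than taken from this commit, the helper boils down to checking the new use_mkldnn attribute together with the placement:

// Assumed shape of the existing helper (not part of this commit): it gates
// MKLDNN dispatch on the "use_mkldnn" attribute plus a CPU placement check.
inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
  bool use_mkldnn = ctx.Attr<bool>("use_mkldnn");
  return use_mkldnn && paddle::platform::is_cpu_place(ctx.GetPlace());
}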

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 27 additions & 2 deletions
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once

+#include <mkldnn.h>
 #include <vector>
-
-#include "mkldnn/include/mkldnn.hpp"
 #include "paddle/fluid/framework/operator.h"

 namespace paddle {

@@ -34,6 +33,32 @@ typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
 typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
 typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;

+template <typename Type>
+void* to_void_cast(const Type* t) {
+  return static_cast<void*>(const_cast<Type*>(t));
+}
+
+template <class Type>
+using tf_desc = typename Type::desc;
+
+template <class Type>
+using tf_pd = typename Type::primitive_desc;
+
+template <typename Type, typename Engine, typename... Args>
+std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
+                                                    Args&&... args) {
+  auto desc = tf_desc<Type>(mkldnn::prop_kind::forward, (args)...);
+  auto pd = new tf_pd<Type>(desc, e);
+  return std::shared_ptr<tf_pd<Type>>(pd);
+}
+
+template <typename Type, typename Engine, typename Primitive, typename... Args>
+tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e, const Primitive& p,
+                                   Args&&... args) {
+  auto desc = tf_desc<Type>(args...);
+  return tf_pd<Type>(desc, e, p);
+}
+
 inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int>& dims,
                                           mkldnn::memory::data_type data_type,
                                           mkldnn::memory::format format) {
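
The two primitive-descriptor helpers above are exercised by the new mul kernel: the forward pd comes back as a shared_ptr so it can be cached via dev_ctx.SetBlob and reused as the hint for the backward pds. A condensed, self-contained sketch of the call pattern, assuming the MKL-DNN 0.x eager API used elsewhere in this diff and hypothetical dimensions:

#include <mkldnn.hpp>
#include "paddle/fluid/platform/mkldnn_helper.h"

int main() {
  auto engine = mkldnn::engine(mkldnn::engine::kind::cpu, 0);

  // Hypothetical 2-D descriptors: src {N=8, I=16}, weights {O=4, I=16},
  // dst {N=8, O=4}.
  auto src_md = paddle::platform::MKLDNNMemDesc(
      {8, 16}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nc);
  auto weights_md = paddle::platform::MKLDNNMemDesc(
      {4, 16}, mkldnn::memory::data_type::f32, mkldnn::memory::format::oi);
  auto dst_md = paddle::platform::MKLDNNMemDesc(
      {8, 4}, mkldnn::memory::data_type::f32, mkldnn::memory::format::nc);

  // Forward pd: the shared_ptr form is what lets the op kernel cache it
  // in the device context for the grad kernel to pick up.
  auto fwd_pd =
      paddle::platform::MKLDNNFwdPrimitiveDesc<mkldnn::inner_product_forward>(
          engine, src_md, weights_md, dst_md);

  // Backward pd is constructed and validated against the forward pd hint.
  auto bwd_data_pd = paddle::platform::MKLDNNBwdPrimitiveDesc<
      mkldnn::inner_product_backward_data>(engine, *fwd_pd, src_md,
                                           weights_md, dst_md);
  (void)bwd_data_pd;
  return 0;
}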
