This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Test/mkl dnn act #11026

Merged
merged 59 commits on May 30, 2018
Changes from all commits
Commits
59 commits
85887db
add mkl act unit test
azai91 May 18, 2018
9e5bcf3
fix operator name
azai91 May 18, 2018
c604579
use custom ndarray init
azai91 May 18, 2018
3b6a952
func missing param
azai91 May 19, 2018
c36ffa4
add init fn for act operator test
azai91 May 22, 2018
b791209
remove extra white space
azai91 May 22, 2018
7253dc4
fix fetch relu operator
azai91 May 22, 2018
8f616c0
fix get relu operator name
azai91 May 22, 2018
8ed4074
add assert abs in verify fn
azai91 May 22, 2018
a3ace0d
remove unused operator
azai91 May 22, 2018
a8521f6
cast blob ptr to float
azai91 May 22, 2018
a164093
use parsed param
azai91 May 23, 2018
ff2908c
use attr_parser
azai91 May 23, 2018
60dc60e
fix header order
azai91 May 23, 2018
0a00f4b
update test fn name
azai91 May 23, 2018
a30f479
use relu fn
azai91 May 23, 2018
944bc7c
add kFComputeEx dispatch
azai91 May 23, 2018
bbdaf80
init posneg mklarray
azai91 May 23, 2018
739be73
fix generating rnd pos neg ints
azai91 May 24, 2018
5667e2b
output arrays are rnd generated
azai91 May 24, 2018
cffd5f6
test that getinputarrays creates view and mkldnn arrays
azai91 May 24, 2018
5db995e
add more output types
azai91 May 24, 2018
c2dcf82
fix typo
azai91 May 24, 2018
3b6aa73
fix gettestput test
azai91 May 24, 2018
e553ea4
create arrattr struct to display arr info
azai91 May 24, 2018
5955fa5
refactor initarray
azai91 May 24, 2018
bb1de85
print arr description in verify fn
azai91 May 24, 2018
43b9b8a
use long int string interpolation
azai91 May 24, 2018
1168d2e
fix alias params
azai91 May 24, 2018
75e2248
iterate over dims
azai91 May 24, 2018
b056bd4
print c_str
azai91 May 24, 2018
18a73af
print output info
azai91 May 24, 2018
90be0f9
improve print message
azai91 May 24, 2018
29b1a04
improve print
azai91 May 24, 2018
2493e83
fix new lines in output
azai91 May 24, 2018
3c23152
refactor print messages
azai91 May 24, 2018
4b4008e
fix typos
azai91 May 24, 2018
50f4f29
fix lint issues
azai91 May 24, 2018
54207a8
fix rebase
azai91 May 25, 2018
7a79e49
pass ndarray as ptr
azai91 May 25, 2018
12ed71c
store copy of ndarray in attrs obj
azai91 May 25, 2018
7435692
fix rem inits
azai91 May 25, 2018
75cb1e0
fix dispatch size
azai91 May 25, 2018
5396619
move print earlier
azai91 May 25, 2018
3c71dd8
use createmkldnnmem helper fun
azai91 May 25, 2018
897a687
fix lint
azai91 May 25, 2018
ca6afc8
refactor if else statement
azai91 May 27, 2018
6dbcd43
use buffer ndarray
azai91 May 27, 2018
6fcebae
fix spacing
azai91 May 27, 2018
08d3a74
fix refactor
azai91 May 27, 2018
66bc279
revert sum refactor
azai91 May 28, 2018
1df1cbb
use fallback compute
azai91 May 28, 2018
9fd5d1f
fix typo
azai91 May 28, 2018
d1126d6
fix lint
azai91 May 28, 2018
e0ce845
use fallbackcompute fn for act operator
azai91 May 28, 2018
7a28188
convert activation impl funcs to fxcompute std
azai91 May 28, 2018
15fef2b
remove unused var
azai91 May 28, 2018
55f4e88
move unused variable
azai91 May 28, 2018
3b97056
fix indent
azai91 May 29, 2018
44 changes: 23 additions & 21 deletions src/operator/nn/activation-inl.h
@@ -120,59 +120,62 @@ void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
}

template<typename xpu>
void ActivationComputeImpl(const ActivationParam &param, const OpContext &ctx,
const TBlob &input, OpReqType req, const TBlob &output) {
void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
switch (param.act_type) {
case activation::kReLU:
ActivationForward<xpu, mshadow_op::relu, mshadow_op::relu_grad>(
ctx, input, req, output);
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kSigmoid:
ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, input, req, output);
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, input, req, output);
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kSoftReLU:
ActivationForward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, input, req, output);
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kSoftSign:
ActivationForward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
ctx, input, req, output);
break;
ctx, inputs[0], req[0], outputs[0]);
break;
default:
LOG(FATAL) << "unknown activation type";
}
}

template<typename xpu>
void ActivationGradComputeImpl(const ActivationParam &param, const OpContext &ctx,
const TBlob &out_grad, const TBlob &out_data,
OpReqType req, const TBlob &output) {
void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
switch (param.act_type) {
case activation::kReLU:
ActivationBackward<xpu, mshadow_op::relu, mshadow_op::relu_grad>(
ctx, out_grad, out_data, req, output);
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kSigmoid:
ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, out_grad, out_data, req, output);
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, out_grad, out_data, req, output);
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kSoftReLU:
ActivationBackward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, out_grad, out_data, req, output);
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kSoftSign:
ActivationBackward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
ctx, out_grad, out_data, req, output);
break;
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
default:
LOG(FATAL) << "unknown activation type";
}
@@ -186,8 +189,7 @@ void ActivationCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
ActivationComputeImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
ActivationComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
}

template<typename xpu>
@@ -196,16 +198,16 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
#if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
bool relu = param.act_type == activation::kReLU;
CHECK_EQ(inputs.size(), relu ? 2U : 3U);
#else
CHECK_EQ(inputs.size(), 2U);
#endif
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
ActivationGradComputeImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], outputs[0]);
ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
}

} // namespace op
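Note on this header: the central change is that ActivationComputeImpl and ActivationGradComputeImpl now read ActivationParam from attrs.parsed themselves and take the full inputs/req/outputs vectors, so their parameter list matches MXNet's FCompute convention and they can be passed directly to helpers such as FallBackCompute. A minimal, self-contained sketch of that signature-unification pattern follows; the types and names (Attrs, Blob, RunThroughDispatch, ...) are invented stand-ins, not MXNet's real classes.

#include <functional>
#include <vector>

// Hypothetical stand-ins for NodeAttrs, OpContext, TBlob and OpReqType;
// the real MXNet types are richer, this only illustrates the signature.
struct Attrs { int act_type = 0; };
struct Ctx {};
using Blob = std::vector<float>;
enum ReqType { kWriteTo };

// Dispatch type with the same shape as MXNet's FCompute.
using ComputeFn = std::function<void(const Attrs&, const Ctx&,
                                     const std::vector<Blob>&,
                                     const std::vector<ReqType>&,
                                     const std::vector<Blob>&)>;

// The impl reads its own parameters from attrs, just as the refactored
// ActivationComputeImpl now reads ActivationParam from attrs.parsed.
void ActivationImplSketch(const Attrs& attrs, const Ctx&,
                          const std::vector<Blob>& inputs,
                          const std::vector<ReqType>& req,
                          const std::vector<Blob>& outputs) {
  (void)attrs; (void)inputs; (void)req; (void)outputs;  // kernel body elided
}

// A generic helper can now accept the impl wherever a ComputeFn is expected.
void RunThroughDispatch(ComputeFn fn, const Attrs& attrs, const Ctx& ctx,
                        const std::vector<Blob>& in,
                        const std::vector<ReqType>& req,
                        const std::vector<Blob>& out) {
  fn(attrs, ctx, in, req, out);
}

int main() {
  Attrs attrs;
  Ctx ctx;
  std::vector<Blob> in(1), out(1);
  std::vector<ReqType> req{kWriteTo};
  RunThroughDispatch(ActivationImplSketch, attrs, ctx, in, req, out);
  return 0;
}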
6 changes: 2 additions & 4 deletions src/operator/nn/activation.cc
@@ -62,7 +62,6 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
if (SupportMKLDNN(inputs[0])) {
@@ -71,7 +70,7 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
ActivationComputeImpl<cpu>(param, ctx, inputs[0].data(), req[0], outputs[0].data());
FallBackCompute(ActivationComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
}

void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -90,8 +89,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
ActivationGradComputeImpl<cpu>(param, ctx, inputs[0].data(), inputs[1].data(),
req[0], outputs[0].data());
FallBackCompute(ActivationGradComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif

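For context, ActivationComputeExCPU and ActivationGradComputeExCPU now share the same shape: try the MKL-DNN kernel when SupportMKLDNN accepts the input, otherwise delegate to FallBackCompute with the FCompute-style impl instead of unpacking .data() by hand. Below is a rough, dependency-free sketch of that dispatch shape; the helper names (SupportsOptimizedPath, ReorderToDefault, GenericRelu) are invented for illustration and are not MXNet APIs.

#include <vector>

// Invented stand-in for an NDArray that may carry an opaque optimized layout.
struct Array {
  bool optimized_layout = false;
  std::vector<float> data;
};

bool SupportsOptimizedPath(const Array& a) {   // stand-in for SupportMKLDNN()
  return a.optimized_layout;
}

Array ReorderToDefault(const Array& a) {       // stand-in for layout conversion
  Array plain = a;
  plain.optimized_layout = false;
  return plain;
}

void OptimizedRelu(const Array& in, Array* out) {   // MKL-DNN-like fast path
  out->data = in.data;
  for (float& v : out->data) v = v > 0.f ? v : 0.f;
}

void GenericRelu(const Array& in, Array* out) {     // FCompute-style fallback
  out->data = in.data;
  for (float& v : out->data) v = v > 0.f ? v : 0.f;
}

// Mirrors the ComputeExCPU structure: fast path when supported, otherwise
// reorder to the default layout and reuse the generic implementation.
void ComputeEx(const Array& in, Array* out) {
  if (SupportsOptimizedPath(in)) {
    OptimizedRelu(in, out);
    return;
  }
  Array plain = ReorderToDefault(in);   // roughly what FallBackCompute arranges
  GenericRelu(plain, out);
}

int main() {
  Array a;
  a.data = {1.f, -2.f, 3.f};
  Array out;
  ComputeEx(a, &out);                   // takes the fallback branch here
  return out.data[1] == 0.f ? 0 : 1;
}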
16 changes: 11 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -159,13 +159,19 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data) {
const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
auto input_mem = in_data.GetMKLDNNData();
MKLDNNActForward &fwd = GetActForward(param, ctx, in_data, *input_mem);
auto out_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
fwd.fwd_pd.dst_primitive_desc());
fwd.SetNewMem(*input_mem, *out_mem);

NDArray in_buffer = in_data;
if (in_data.IsView() && in_data.IsMKLDNNData())
in_buffer = in_data.Reorder2Default();

auto input_mem = in_buffer.GetMKLDNNData();
MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
auto out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(),
req);
fwd.SetNewMem(*input_mem, *out_mem.second);
MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrim(fwd.GetFwd());
CommitOutput(out_data, out_mem);
stream->Submit();
}

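Two behavioral points in this hunk are worth calling out: a view of an MKL-DNN-layout array is first reordered to the default layout (Reorder2Default) so the primitive never reads through the view, and the output memory now comes from CreateMKLDNNMem/CommitOutput, which honors the request type instead of always creating a fresh destination. The following is a dependency-free sketch of that request-aware output pattern; the names (CreateOutputMem, KernelTarget, ...) are invented for illustration and only approximate what the real helpers do.

#include <cassert>
#include <cstddef>
#include <vector>

enum OpReq { kWriteTo, kAddTo };

// The kernel writes into a buffer chosen according to the request; committing
// is a no-op for kWriteTo and an accumulation into the target for kAddTo.
struct Output {
  std::vector<float>* target;
  std::vector<float> temp;
  OpReq req;
};

Output CreateOutputMem(std::vector<float>* dst, std::size_t n, OpReq req) {
  Output o{dst, {}, req};
  if (req == kAddTo) o.temp.assign(n, 0.f);  // scratch buffer for accumulation
  return o;
}

std::vector<float>* KernelTarget(Output& o) {
  return o.req == kAddTo ? &o.temp : o.target;
}

void CommitOutput(Output& o) {
  if (o.req == kAddTo)
    for (std::size_t i = 0; i < o.temp.size(); ++i) (*o.target)[i] += o.temp[i];
}

int main() {
  std::vector<float> in{1.f, -2.f, 3.f}, out(3, 10.f);
  Output o = CreateOutputMem(&out, in.size(), kAddTo);
  std::vector<float>* dst = KernelTarget(o);
  for (std::size_t i = 0; i < in.size(); ++i)
    (*dst)[i] = in[i] > 0.f ? in[i] : 0.f;   // ReLU written into the chosen buffer
  CommitOutput(o);
  assert(out[0] == 11.f && out[1] == 10.f && out[2] == 13.f);  // kAddTo accumulated
  return 0;
}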