Test/mkl dnn act (apache#11026)
* add mkl act unit test

* fix operator name

* use custom ndarray init

* func missing param

* add init fn for act operator test

* remove extra white space

* fix fetch relu operator

* fix get relu operator name

* add assert abs in verify fn

* remove unused operator

* cast blob ptr to float

* use parsed param

* use attr_parser

* fix header order

* update test fn name

* use relu fn

* add kFComputeEx dispatch

* init posneg mklarray

* fix generating rnd pos neg ints

* output arrays are rnd generated

* test that getinputarrays creates view and mkldnn arrays

* add more output types

* fix typo

* fix gettestput test

* create arrattr struct to display arr info

* refactor initarray

* print arr description in verify fn

* use long int string interpolation

* fix alias params

* iterate over dims

* print c_str

* print output info

* improve print message

* improve print

* fix new lines in output

* refactor print messages

* fix typos

* fix lint issues

* fix rebase

* pass ndarray as ptr

* store copy of ndarray in attrs obj

* fix rem inits

* fix dispatch size

* move print earlier

* use createmkldnnmem helper fun

* fix lint

* refactor if else statement

* use buffer ndarray

* fix spacing

* fix refactor

* revert sum refactor

* use fallback compute

* fix typo

* fix lint

* use fallbackcompute fn for act operator

* convert activation impl funcs to fxcompute std

* remove unused var

* move unused variable

* fix indent
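The new unit test exercises the MKL-DNN relu path and, per the commits above ("add assert abs in verify fn", "cast blob ptr to float"), compares each output element against a reference computed on the raw float data. The following is a rough sketch of that element-wise check; ReferenceRelu and VerifyReluResult are illustrative names, not the helpers actually added by this commit.

// Hedged sketch (illustrative names, not the actual test helpers): the kind of
// element-wise check the "verify fn" performs for relu.
#include <cassert>
#include <cmath>
#include <cstddef>

// Reference relu applied to a raw float buffer.
inline void ReferenceRelu(const float *in, float *out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    out[i] = in[i] > 0.0f ? in[i] : 0.0f;
}

// Compare the operator's output blob (cast to float *) against the reference
// using an absolute-difference tolerance, mirroring the "assert abs" commit.
inline void VerifyReluResult(const float *expected, const float *actual,
                             std::size_t n, float tol = 1e-6f) {
  for (std::size_t i = 0; i < n; ++i)
    assert(std::fabs(expected[i] - actual[i]) < tol);
}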
azai91 authored and piiswrong committed May 30, 2018
1 parent 9514a1e commit 92286c9
Showing 4 changed files with 213 additions and 120 deletions.
44 changes: 23 additions & 21 deletions src/operator/nn/activation-inl.h
@@ -120,59 +120,62 @@ void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
 }
 
 template<typename xpu>
-void ActivationComputeImpl(const ActivationParam &param, const OpContext &ctx,
-                           const TBlob &input, OpReqType req, const TBlob &output) {
+void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                           const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   switch (param.act_type) {
     case activation::kReLU:
       ActivationForward<xpu, mshadow_op::relu, mshadow_op::relu_grad>(
-          ctx, input, req, output);
+          ctx, inputs[0], req[0], outputs[0]);
       break;
     case activation::kSigmoid:
       ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
-          ctx, input, req, output);
+          ctx, inputs[0], req[0], outputs[0]);
       break;
     case activation::kTanh:
       ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
-          ctx, input, req, output);
+          ctx, inputs[0], req[0], outputs[0]);
       break;
     case activation::kSoftReLU:
       ActivationForward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
-          ctx, input, req, output);
+          ctx, inputs[0], req[0], outputs[0]);
       break;
     case activation::kSoftSign:
       ActivationForward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
-          ctx, input, req, output);
-      break;
+          ctx, inputs[0], req[0], outputs[0]);
+      break;
     default:
       LOG(FATAL) << "unknown activation type";
   }
 }
 
 template<typename xpu>
-void ActivationGradComputeImpl(const ActivationParam &param, const OpContext &ctx,
-                               const TBlob &out_grad, const TBlob &out_data,
-                               OpReqType req, const TBlob &output) {
+void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                               const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   switch (param.act_type) {
     case activation::kReLU:
       ActivationBackward<xpu, mshadow_op::relu, mshadow_op::relu_grad>(
-          ctx, out_grad, out_data, req, output);
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
       break;
     case activation::kSigmoid:
       ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
-          ctx, out_grad, out_data, req, output);
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
       break;
     case activation::kTanh:
       ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
-          ctx, out_grad, out_data, req, output);
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
       break;
     case activation::kSoftReLU:
       ActivationBackward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
-          ctx, out_grad, out_data, req, output);
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
       break;
     case activation::kSoftSign:
       ActivationBackward<xpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
-          ctx, out_grad, out_data, req, output);
-      break;
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
+      break;
     default:
       LOG(FATAL) << "unknown activation type";
   }
@@ -186,8 +189,7 @@ void ActivationCompute(const nnvm::NodeAttrs& attrs,
                        const std::vector<TBlob>& outputs) {
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  ActivationComputeImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
+  ActivationComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
 }
 
 template<typename xpu>
@@ -196,16 +198,16 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
                            const std::vector<TBlob>& inputs,
                            const std::vector<OpReqType>& req,
                            const std::vector<TBlob>& outputs) {
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
 #if (MXNET_USE_CUDNN == 1 || MXNET_USE_MKLDNN == 1)
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   bool relu = param.act_type == activation::kReLU;
   CHECK_EQ(inputs.size(), relu ? 2U : 3U);
 #else
   CHECK_EQ(inputs.size(), 2U);
 #endif
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
-  ActivationGradComputeImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], outputs[0]);
+  ActivationGradComputeImpl<xpu>(attrs, ctx, inputs, req, outputs);
 }
 
 }  // namespace op
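The change above converts the impl functions to the standard FCompute argument list (attrs, ctx, inputs, req, outputs) and reads ActivationParam from attrs.parsed via nnvm::get, so the public wrappers can simply forward their arguments and the impls can be handed to the fallback path shown in the next file. A minimal sketch of that calling convention, with stand-in types (placeholders, not the MXNet/nnvm headers):

// Stand-in types; the real ones live in the nnvm and mxnet headers.
#include <functional>
#include <vector>

struct NodeAttrs {};   // placeholder for nnvm::NodeAttrs (carries attrs.parsed)
struct OpContext {};   // placeholder for mxnet::OpContext
struct TBlob {};       // placeholder for mxnet::TBlob
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

// Every compute entry point now shares this shape, so any of them can be
// stored in a single function type and invoked generically.
using FComputeFn = std::function<void(const NodeAttrs&, const OpContext&,
                                      const std::vector<TBlob>&,
                                      const std::vector<OpReqType>&,
                                      const std::vector<TBlob>&)>;

void Dispatch(const FComputeFn &fn, const NodeAttrs &attrs, const OpContext &ctx,
              const std::vector<TBlob> &inputs, const std::vector<OpReqType> &req,
              const std::vector<TBlob> &outputs) {
  fn(attrs, ctx, inputs, req, outputs);  // e.g. ActivationComputeImpl<cpu>
}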
6 changes: 2 additions & 4 deletions src/operator/nn/activation.cc
@@ -62,7 +62,6 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
-  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   if (SupportMKLDNN(inputs[0])) {
@@ -71,7 +70,7 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
     MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
-  ActivationComputeImpl<cpu>(param, ctx, inputs[0].data(), req[0], outputs[0].data());
+  FallBackCompute(ActivationComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
 }
 
 void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
@@ -90,8 +89,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
     MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
-  ActivationGradComputeImpl<cpu>(param, ctx, inputs[0].data(), inputs[1].data(),
-                                 req[0], outputs[0].data());
+  FallBackCompute(ActivationGradComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
 }
 #endif
 
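With the impls matching the FCompute shape, the MKL-DNN "ExCPU" entry points above can delegate to FallBackCompute whenever the input is not suitable for the MKL-DNN kernel. Conceptually, the fallback reorders any MKL-DNN-formatted NDArray to the default layout and hands dense TBlobs to the TBlob-based implementation. The sketch below is a simplified, hypothetical stand-in for that behavior, not MXNet's actual FallBackCompute:

// FallBackSketch: hypothetical, simplified illustration of the fallback path.
#include <vector>

template <typename TBlobT, typename FComputeFn, typename NDArrayT,
          typename AttrsT, typename CtxT, typename ReqT>
void FallBackSketch(FComputeFn fn, const AttrsT &attrs, const CtxT &ctx,
                    const std::vector<NDArrayT> &inputs,
                    const std::vector<ReqT> &req,
                    const std::vector<NDArrayT> &outputs) {
  std::vector<NDArrayT> bufs;  // keep reordered copies alive during the call
  std::vector<TBlobT> in_blobs, out_blobs;
  // Reorder MKL-DNN-formatted inputs to the default layout, then take the
  // dense TBlob view that the FCompute-style implementation expects.
  for (const auto &nd : inputs) {
    if (nd.IsMKLDNNData()) {
      bufs.push_back(nd.Reorder2Default());
      in_blobs.push_back(bufs.back().data());
    } else {
      in_blobs.push_back(nd.data());
    }
  }
  for (const auto &nd : outputs)
    out_blobs.push_back(nd.data());
  fn(attrs, ctx, in_blobs, req, out_blobs);  // TBlob-based implementation
}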
16 changes: 11 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -159,13 +159,19 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                              const NDArray &in_data, const OpReqType &req,
                              const NDArray &out_data) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
-  auto input_mem = in_data.GetMKLDNNData();
-  MKLDNNActForward &fwd = GetActForward(param, ctx, in_data, *input_mem);
-  auto out_mem = const_cast<NDArray &>(out_data).CreateMKLDNNData(
-      fwd.fwd_pd.dst_primitive_desc());
-  fwd.SetNewMem(*input_mem, *out_mem);
+
+  NDArray in_buffer = in_data;
+  if (in_data.IsView() && in_data.IsMKLDNNData())
+    in_buffer = in_data.Reorder2Default();
+
+  auto input_mem = in_buffer.GetMKLDNNData();
+  MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
+  auto out_mem = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(),
+                                 req);
+  fwd.SetNewMem(*input_mem, *out_mem.second);
   MKLDNNStream *stream = MKLDNNStream::Get();
   stream->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data, out_mem);
   stream->Submit();
 }
 
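In the MKL-DNN forward path above, the output is no longer created directly on the destination NDArray: CreateMKLDNNMem returns a (write-intent, memory) pair sized by the primitive's destination descriptor and the request type, and CommitOutput is what copies or accumulates a temporary back into out_data before the stream is submitted. The sketch below is a simplified, hypothetical illustration of that write-back decision, not the actual helpers in the MKL-DNN base code:

// CommitOutputSketch: simplified, hypothetical illustration of the write-back
// decision made after the primitive runs.
enum class WriteIntent { Direct, CopyBack, AddBack };

template <typename MemT, typename CopyFn, typename SumFn>
void CommitOutputSketch(WriteIntent intent, const MemT &tmp, MemT *dst,
                        CopyFn copy, SumFn sum) {
  switch (intent) {
    case WriteIntent::Direct:   break;                  // op wrote straight into dst
    case WriteIntent::CopyBack: copy(tmp, dst); break;  // reorder/copy tmp into dst
    case WriteIntent::AddBack:  sum(tmp, dst);  break;  // dst += tmp (kAddTo request)
  }
}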