Skip to content

Commit

Permalink
add batch norm test (apache#13625)
Browse files Browse the repository at this point in the history
* add batch norm test

* fix formatting

* use out_arr as input

* fix typo

* remove const

* use ptr

* eval ptr
  • Loading branch information
azai91 authored and anirudh2290 committed Dec 13, 2018
1 parent 439f167 commit b45e127
Showing 1 changed file with 150 additions and 1 deletion.
151 changes: 150 additions & 1 deletion tests/cpp/operator/mkldnn_operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in
return attrs;
}

OpAttrs GetBNOp() {
OpAttrs attrs;
attrs.attrs.op = Op::Get("BatchNorm");
attrs.num_inputs = 5;
attrs.num_outputs = 3;
attrs.accept_dims.insert(4);
attrs.requests.insert(OpReqType::kWriteTo);
attrs.attrs.op->attr_parser(&attrs.attrs);
attrs.input_types = ArrayTypes::Normal |
ArrayTypes::MKLDNN;
attrs.output_types = ArrayTypes::Normal |
ArrayTypes::MKLDNN;
return attrs;
}

OpAttrs GetBNBackwardOp() {
OpAttrs attrs;
attrs.attrs.op = Op::Get("_backward_BatchNorm");
attrs.num_inputs = 8;
attrs.num_outputs = 3;
attrs.attrs.op->attr_parser(&attrs.attrs);
attrs.requests.insert(OpReqType::kWriteTo);
return attrs;
}

void AssertEqual(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs,
float rtol = 1e-5, float atol = 1e-8) {
Expand Down Expand Up @@ -710,7 +735,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {

// If the array is a view, we shouldn't write data to it.
if (in_arr.arr.IsView())
continue;
continue;

NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
for (int i = 0; i < forward_attrs.num_inputs; i++)
Expand All @@ -735,6 +760,124 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
}
}


void TestOpExBNBackward(const OpAttrs &forward_attrs,
const OpAttrs &backwards_attrs,
const OpReqType &req,
const std::vector<NDArray*> &inputs,
const std::vector<NDArray*> &outputs,
const NDArrayAttrs &in_arr,
NDArrayAttrs* out_arr) {
std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);

std::vector<NDArray> backwards_buffer(backwards_attrs.num_outputs);
std::vector<NDArray> backwards_buffer2(backwards_attrs.num_outputs);

std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
std::vector<OpReqType> backwards_req(backwards_attrs.num_outputs);

if (req == kWriteTo) {
backwards_input[0] = &(out_arr->arr); // output grad
backwards_input[1] = outputs[1]; // mean
backwards_input[2] = outputs[2]; // var
backwards_input[3] = inputs[0]; // data
backwards_input[4] = inputs[1]; // gamma
backwards_input[5] = inputs[2]; // beta
backwards_input[6] = inputs[3]; // moving mean
backwards_input[7] = inputs[4]; // moving var

for (size_t i = 0; i < backwards_attrs.num_outputs; i++) {
auto tmp_output = in_arr.arr;
backwards_buffer.emplace_back(tmp_output.Copy(Context()));
backwards_buffer2.emplace_back(tmp_output.Copy(Context()));
backwards_outputs[i] = &backwards_buffer.back();
backwards_ex_outputs[i] = &backwards_buffer2.back();
Engine::Get()->WaitForAll();
backwards_req[i] = kWriteTo;
}

std::cout << "Backwards: ";
PrintVerifyMsg(*out_arr, in_arr);
Imperative::Get()->InvokeOp(
Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
backwards_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
Imperative::Get()->InvokeOp(
Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
backwards_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
AssertEqual(backwards_outputs, backwards_ex_outputs);
}
}

// compares output of fcompute with fcomputex
void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
std::vector<NDArray*> inputs(forward_attrs.num_inputs);
std::vector<NDArray*> inputs2(forward_attrs.num_inputs);
std::vector<NDArray> inputs_buffer(forward_attrs.num_inputs);
std::vector<NDArray> inputs2_buffer(forward_attrs.num_inputs);
std::vector<NDArray*> outputs(forward_attrs.num_outputs);
std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
std::vector<OpReqType> req(forward_attrs.num_outputs);

TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, false);
std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);

if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
for (int i1 = 0; i1 < in_arrs.size(); i1++) {
auto in_arr = in_arrs[i1];

CHECK_NE(forward_attrs.accept_dims.size(), 0);
if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
forward_attrs.accept_dims.end())
continue;
for (int i = 0; i < forward_attrs.num_outputs; i++) {
out_arrs[i] =
GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
ex_out_arrs[i] =
GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
}
for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
inputs_buffer.clear();
inputs2_buffer.clear();

for (int i = 0; i < forward_attrs.num_inputs; i++) {
inputs_buffer.emplace_back(in_arr.arr.Copy(Context()));
inputs2_buffer.emplace_back(in_arr.arr.Copy(Context()));
Engine::Get()->WaitForAll();
inputs[i] = &inputs_buffer.back();
inputs2[i] = &inputs2_buffer.back();
}
for (int i = 0; i < forward_attrs.num_outputs; i++) {
req[i] = kWriteTo;
outputs[i] = &out_arrs[i][output_i].arr;
ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
}
Imperative::Get()->set_is_training(true);

PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
Imperative::Get()->InvokeOp(
Context(), forward_attrs.attrs, inputs, outputs, req,
DispatchMode::kFCompute, mxnet::OpStatePtr());
Imperative::Get()->InvokeOp(
Context(), forward_attrs.attrs, inputs2, ex_outputs, req,
DispatchMode::kFComputeEx, mxnet::OpStatePtr());
Engine::Get()->WaitForAll();
AssertEqual(outputs, ex_outputs);

if (!backwards_attrs.requests.empty()) {
TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
inputs, outputs, in_arr, &out_arrs[0][output_i]);
}
}
}
}
}

// Computes second dimension of FC weight matrix based on input shape
uint32_t GetFCWeightDim2(const nnvm::TShape arr) {
uint32_t dim = 1;
Expand Down Expand Up @@ -1204,4 +1347,10 @@ TEST(IMPERATIVE, DeconvOp) {
}
}

TEST(IMPERATIVE, BNOp) {
OpAttrs forward_attrs = GetBNOp();
OpAttrs backwards_attrs = GetBNBackwardOp();
TestOpExBN(forward_attrs, backwards_attrs);
}

#endif

0 comments on commit b45e127

Please sign in to comment.