Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed May 25, 2018
1 parent a056ae1 commit 990eaf5
Showing 1 changed file with 59 additions and 59 deletions.
118 changes: 59 additions & 59 deletions tests/cpp/operator/mkldnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ static void InitNegPosArray(NDArray *arr, bool is_rand = false) {
}

using InitFunc = std::function<void (NDArray *arr, bool is_rand)>;
using VerifyFunc = std::function<void (const std::vector<NDArray *> &in_arrs, const NDArray &arr)>;

// Init arrays with the specified layout.
static void InitMKLDNNArray(NDArray *arr, const mkldnn::memory::primitive_desc &pd,
Expand Down Expand Up @@ -355,9 +356,9 @@ TEST(MKLDNN_NDArray, GetDataReorder) {
}

struct NDArrayAttrs {
NDArray arr;
NDArray *arr;
std::string desc;
NDArrayAttrs(NDArray arr, std::string desc) : arr(arr), desc(desc) {}
NDArrayAttrs(NDArray* arr, std::string desc) : arr(arr), desc(desc) {}
};

struct OpAttrs {
Expand Down Expand Up @@ -434,15 +435,15 @@ std::vector<NDArrayAttrs> GetTestInputArrays(InitFunc init_fn) {
// Type 1.
NDArray arr(shape, Context());
in_arrs.emplace_back(arr, "Normal NDArray");
init_fn(&in_arrs.back().arr, false);
init_fn(in_arrs.back().arr, false);
for (auto pd : pds) {
if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
continue;

// Type 2, 3.
arr = NDArray(shape, Context());
in_arrs.emplace_back(arr, "MKLDNN NDArray");
InitMKLDNNArray(&in_arrs.back().arr, pd, init_fn);
InitMKLDNNArray(in_arrs.back().arr, pd, init_fn);

// Type 4, 5, 6.
arr = NDArray(shape, Context());
Expand All @@ -458,12 +459,12 @@ TEST(MKLDNN_NDArray, GetTestInputArrays) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray);
int mkldnn_count = 0, mkldnn_view_count = 0;
for (auto arr : in_arrs) {
if (arr.arr.IsView() && arr.arr.IsMKLDNNData()) {
if (arr.arr->IsView() && arr.arr->IsMKLDNNData()) {
mkldnn_view_count++;
continue;
}

if (arr.arr.IsMKLDNNData()) {
if (arr.arr->IsMKLDNNData()) {
mkldnn_count++;
continue;
}
Expand Down Expand Up @@ -498,7 +499,7 @@ std::vector<NDArrayAttrs> GetTestOutputArrays(const TShape &shape,
// Type 1.
NDArray arr(shape, Context());
in_arrs.emplace_back(arr, "Normal NDArray");
init_fn(&in_arrs.back().arr, true);
init_fn(in_arrs.back().arr, true);

// Type 4.
TShape tmp_shape = shape;
Expand Down Expand Up @@ -553,7 +554,7 @@ std::vector<NDArrayAttrs> GetTestOutputArrays(const TShape &shape,
// Type 2, 3.
arr = NDArray(shape, Context());
in_arrs.emplace_back(arr, "MKLDNN NDArray");
InitMKLDNNArray(&in_arrs.back().arr, pd, init_fn, true);
InitMKLDNNArray(in_arrs.back().arr, pd, init_fn, true);

// Type 8, 9.
// Get a reused version.
Expand All @@ -567,8 +568,6 @@ std::vector<NDArrayAttrs> GetTestOutputArrays(const TShape &shape,
return in_arrs;
}

using VerifyFunc = std::function<void (const std::vector<NDArray *> &in_arrs, const NDArray &arr)>;

void VerifyCopyResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) {
NDArray tmp1 = in_arrs[0]->Reorder2Default();
NDArray tmp2 = arr.Reorder2Default();
Expand All @@ -592,19 +591,6 @@ void VerifyActResult(const NDArray &in_arr, const NDArray &arr) {
}
}

void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
TShape t1 = arr1.arr.shape();
TShape t2 = arr2.arr.shape();

printf("Verifying: %s (", arr1.desc.c_str());
for (size_t i = 0; i < t1.ndim(); i++)
printf("%ld, ", t1[i]);
printf(") with %s (", arr2.desc.c_str());
for (size_t i = 0; i < t2.ndim(); i++)
printf("%ld, ", t2[i]);
printf(")\n");
}

void VerifySumResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) {
NDArray in1 = in_arrs[0]->Reorder2Default();
NDArray in2 = in_arrs[1]->Reorder2Default();
Expand All @@ -619,22 +605,35 @@ void VerifySumResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr)
EXPECT_EQ(d1[i] + d2[i], o[i]);
}

void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
TShape t1 = arr1.arr->shape();
TShape t2 = arr2.arr->shape();

printf("Verifying: %s (", arr1.desc.c_str());
for (size_t i = 0; i < t1.ndim(); i++)
printf("%ld, ", t1[i]);
printf(") with %s (", arr2.desc.c_str());
for (size_t i = 0; i < t2.ndim(); i++)
printf("%ld, ", t2[i]);
printf(")\n");
}

TEST(MKLDNN_NDArray, CopyFrom) {
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArray> in_arrs = GetTestInputArrays();
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray);
for (auto in_arr : in_arrs) {
std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr.shape(), pds);
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr->shape(), pds, InitDefaultArray);
for (auto out_arr : out_arrs) {
if (in_arr.IsMKLDNNData() && in_arr.IsView())
in_arr = in_arr.Reorder2Default();
const mkldnn::memory *mem = in_arr.GetMKLDNNData();
out_arr.CopyFrom(*mem);
if (in_arr.arr->IsMKLDNNData() && in_arr.arr->IsView())
*in_arr.arr = in_arr.arr->Reorder2Default();
const mkldnn::memory *mem = in_arr.arr->GetMKLDNNData();
out_arr.arr->CopyFrom(*mem);
MKLDNNStream::Get()->Submit();
std::vector<NDArray *> inputs(1);
inputs[0] = &in_arr;
VerifyCopyResult(inputs, out_arr);
inputs[0] = in_arr.arr;
VerifyCopyResult(inputs, *out_arr.arr);
}
}
}
Expand All @@ -651,16 +650,16 @@ void TestUnaryOp(const OpAttrs &attrs, InitFunc init_fn, VerifyFunc verify_fn) {
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(init_fn);
for (auto in_arr : in_arrs) {
for (auto dispatch : dispatches) {
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds, init_fn);
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr->shape(), pds, init_fn);
for (auto out_arr : out_arrs) {
req[0] = kWriteTo;
inputs[0] = &in_arr.arr;
outputs[0] = &out_arr.arr;
inputs[0] = in_arr.arr;
outputs[0] = out_arr.arr;
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
outputs, req, dispatch, mxnet::OpStatePtr());
out_arr.WaitToRead();
out_arr.arr->WaitToRead();
PrintVerifyMsg(in_arr, out_arr);
verify_fn(in_arr, out_arr.arr);
verify_fn(inputs, *outputs[0]);
}
}
}
Expand All @@ -669,19 +668,20 @@ void TestUnaryOp(const OpAttrs &attrs, InitFunc init_fn, VerifyFunc verify_fn) {
in_arrs = GetTestInputArrays(init_fn);
for (auto arr : in_arrs) {
// If the array is a view, we shouldn't write data to it.
if (arr.arr.IsView())
if (arr.arr->IsView())
continue;

NDArrayAttrs orig(arr.arr.Copy(arr.arr.ctx()), "InPlace Copy");
NDArray *tmp = arr.arr->Copy(arr.arr->ctx();
NDArrayAttrs orig(tmp, "InPlace Copy");
req[0] = kWriteInplace;
inputs[0] = &arr.arr;
outputs[0] = &arr.arr;
inputs[0] = arr.arr;
outputs[0] = arr.arr;
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
dispatch, mxnet::OpStatePtr());
arr.WaitToRead();
inputs[0] = &orig;
arr.arr->WaitToRead();
inputs[0] = tmp;
PrintVerifyMsg(orig, arr);
verify_fn(inputs, arr);
verify_fn(inputs, *outputs[0]);
}
}
}
Expand All @@ -695,42 +695,42 @@ void TestBinaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
TestArrayShapes tas = GetTestArrayShapes();
std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;

std::vector<NDArray> in_arrs = GetTestInputArrays();
std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(InitDefaultArray);
for (auto in_arr1 : in_arrs) {
for (auto dispatch : dispatches) {
std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr1.shape(), pds);
std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr1.arr->shape(), pds, InitDefaultArray);
for (auto out_arr : out_arrs) {
req[0] = kWriteTo;
inputs[0] = &in_arr1;
inputs[1] = &in_arr1;
outputs[0] = &out_arr;
inputs[0] = in_arr1.arr;
inputs[1] = in_arr1.arr;
outputs[0] = out_arr.arr;
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
outputs, req, dispatch, mxnet::OpStatePtr());
out_arr.WaitToRead();
verify_fn(inputs, out_arr);
out_arr.arr->WaitToRead();
verify_fn(inputs, *out_arr.arr);
}
}
}

for (auto dispatch : dispatches) {
in_arrs = GetTestInputArrays();
in_arrs = GetTestInputArrays(InitDefaultArray);
for (auto arr : in_arrs) {
// If the array is a view, we shouldn't write data to it.
if (arr.IsView())
if (arr.arr->IsView())
continue;

NDArray orig = arr.Copy(arr.ctx());
NDArray orig = arr.arr->Copy(arr.arr->ctx());
req[0] = kWriteInplace;
inputs[0] = &arr;
inputs[1] = &arr;
outputs[0] = &arr;
inputs[0] = arr.arr;
inputs[1] = arr.arr;
outputs[0] = arr.arr;
Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
dispatch, mxnet::OpStatePtr());
arr.WaitToRead();
std::vector<NDArray *> orig_inputs(2);
arr.arr->WaitToRead();
std::vector<NDArray*> orig_inputs(2);
orig_inputs[0] = &orig;
orig_inputs[1] = &orig;
verify_fn(orig_inputs, arr);
verify_fn(orig_inputs, *arr.arr);
}
}
}
Expand Down

0 comments on commit 990eaf5

Please sign in to comment.