Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add more output types
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed May 24, 2018
1 parent 62c5ae7 commit e3d553b
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions tests/cpp/operator/mkldnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ TEST(MKLDNN_NDArray, GetTestInputArrays) {
std::vector<NDArray> in_arrs = GetTestInputArrays(InitDefaultArray);
int mkldnn_count = 0, mkldnn_view_count = 0;
for (auto in_arr : in_arrs) {

if (in_arr.isView() && in_arr.IsMKLDNNData()) {
mkldnn_view_count++;
continue;
Expand All @@ -465,16 +465,30 @@ TEST(MKLDNN_NDArray, GetTestInputArrays) {
* pass them to all operators.
* In the inference mode, the MKLDNN memory in the weight array will be
* reordered to 5 dimensions.
* 4. Reused NDArray (this is created by the MXNet executor). This type of
* 4. Reshaped/sliced NDArray
* 5. Reused NDArray (this is created by the MXNet executor). This type of
* NDArrays can only be used as output arrays.
* 6. Reused NDArray converted from an array with a different data type.
* 7. Reused reshaped/sliced NDArray.
* 8. Reused NDArray with MKLDNN layout.
* 9. Reused NDArray with MKLDNN layout of different dimensions.
*/
std::vector<NDArray> GetTestOutputArrays(const TShape &shape,
const std::vector<mkldnn::memory::primitive_desc> &pds,
const InitFunc init_fn) {
std::vector<NDArray> in_arrs;
// Type 1.
in_arrs.emplace_back(shape, Context());
init_fn(&in_arrs.back(), true);

// Type 4.
TShape tmp_shape = shape;
tmp_shape[0] = shape[0] * 2;
NDArray arr0(tmp_shape, Context());
init_fn(&arr0, true);
in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1));

// Type 5.
// Get a reused version.
nnvm::TShape s(1);
s[0] = shape.Size();
Expand All @@ -483,17 +497,34 @@ std::vector<NDArray> GetTestOutputArrays(const TShape &shape,
init_fn(&arr, true);
in_arrs.emplace_back(arr);

// Type 6.
s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag);
NDArray arr2(s, Context(), true, mshadow::kUint8);
arr2 = arr2.AsArray(shape, mshadow::default_type_flag);
init_fn(&arr2, true);
in_arrs.emplace_back(arr2);

// Type 7
s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag) * 2;
NDArray arr3(s, Context(), true, mshadow::kUint8);
tmp_shape[0] = shape[0] * 2;
arr3 = arr3.AsArray(tmp_shape, mshadow::default_type_flag);
init_fn(&arr3, true);
in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1));

for (auto pd : pds) {
if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
continue;

// Type 2, 3.
in_arrs.emplace_back(shape, Context());
InitMKLDNNArray(&in_arrs.back(), pd, init_fn, true);

// Type 8, 9.
// Get a reused version.
nnvm::TShape s(1);
s[0] = shape.Size();
arr = NDArray(s, Context());
NDArray arr = NDArray(s, Context());
arr = arr.AsArray(shape, arr.dtype());
InitMKLDNNArray(&arr, pd, init_fn, true);
in_arrs.emplace_back(arr);
Expand Down

0 comments on commit e3d553b

Please sign in to comment.