diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 6fda8c37b416..e243eb71c477 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -678,10 +678,7 @@ class NDArray { */ NDArray Reorder2Default() const; - void InvalidateMKLDNNData() { - // Removing mkl_mem_ means the NDArray will store data in the default format. - ptr_->mkl_mem_ = nullptr; - } + void InvalidateMKLDNNData(); /* * This function is used inside operators to reshape an array. diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 82de0949ccc3..a28a907a9410 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -620,6 +620,12 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const { } } +void NDArray::InvalidateMKLDNNData() { + // Removing mkl_mem_ means the NDArray will store data in the default format. + if (ptr_->mkl_mem_ && ptr_->mkl_mem_->IsMKLDNN()) + ptr_->mkl_mem_ = nullptr; +} + void NDArray::CopyFrom(const mkldnn::memory &mem) { CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized"; if (ptr_->mkl_mem_ && ptr_->mkl_mem_->GetRaw() == &mem)