From 1e48e1238609b94cae2af3d0d72b33882b9c5a24 Mon Sep 17 00:00:00 2001 From: dtmoodie Date: Sat, 26 Aug 2017 14:51:34 -0400 Subject: [PATCH] gpu access of ndarray (#7496) * gpu access of ndarray * gpu access from C++ api * gpu access fix * Update c_api.cc * Update c_api.cc --- cpp-package/include/mxnet-cpp/ndarray.hpp | 1 - src/c_api/c_api.cc | 8 +------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index 5ed04a547b85..6bf26432359b 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -359,7 +359,6 @@ inline int NDArray::GetDType() const { inline const mx_float *NDArray::GetData() const { void *ret; - CHECK_NE(GetContext().GetDeviceType(), DeviceType::kGPU); MXNDArrayGetData(blob_ptr_->handle_, &ret); if (GetDType() != 0) { return NULL; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 0fe3fe3e302e..088e208c9cdc 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -437,13 +437,7 @@ int MXNDArrayGetData(NDArrayHandle handle, API_BEGIN(); NDArray *arr = static_cast(handle); if (!arr->is_none()) { - CHECK(arr->ctx().dev_mask() == cpu::kDevMask) - << "MXNDArrayGetData can only be called for NDArray on CPU"; - const TBlob &b = arr->data(); - CHECK(b.CheckContiguous()); - MSHADOW_REAL_TYPE_SWITCH(arr->dtype(), DType, { - *out_pdata = b.FlatTo2D().dptr_; - }); + *out_pdata = arr->data().dptr_; } else { *out_pdata = nullptr; }