Skip to content

Commit

Permalink
gpu access of ndarray (apache#7496)
Browse files Browse the repository at this point in the history
* gpu access of ndarray

* gpu access from C++ api

* gpu access fix

* Update c_api.cc

* Update c_api.cc
  • Loading branch information
dtmoodie authored and piiswrong committed Aug 26, 2017
1 parent 2e6ef8c commit 1e48e12
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 8 deletions.
1 change: 0 additions & 1 deletion cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ inline int NDArray::GetDType() const {

inline const mx_float *NDArray::GetData() const {
void *ret;
CHECK_NE(GetContext().GetDeviceType(), DeviceType::kGPU);
MXNDArrayGetData(blob_ptr_->handle_, &ret);
if (GetDType() != 0) {
return NULL;
Expand Down
8 changes: 1 addition & 7 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,7 @@ int MXNDArrayGetData(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
CHECK(arr->ctx().dev_mask() == cpu::kDevMask)
<< "MXNDArrayGetData can only be called for NDArray on CPU";
const TBlob &b = arr->data();
CHECK(b.CheckContiguous());
MSHADOW_REAL_TYPE_SWITCH(arr->dtype(), DType, {
*out_pdata = b.FlatTo2D<cpu, DType>().dptr_;
});
*out_pdata = arr->data().dptr_;
} else {
*out_pdata = nullptr;
}
Expand Down

0 comments on commit 1e48e12

Please sign in to comment.