-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-779]Add DLPack Transformation API #12047
Changes from 16 commits
822706e
8aac3da
ab6fa85
8c6e9d2
9fdfa7d
1142787
16df8d5
bfcffa2
f5c2552
98b5d11
7bdde8f
f225d27
afc1518
8b397fd
d48074a
58c5d87
72edbf8
ef8ffcd
afa1898
a4d3aee
493deb0
adf36ef
26db4d0
dec838d
850c3dc
fc99323
ffe60c6
cbb17c3
e56be1f
b1204bc
fe1387f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -104,6 +104,37 @@ class TBlob { | |
: dptr_(dptr), shape_(shape), type_flag_(type_flag) { | ||
SetDLTensor(dev_mask, dev_id); | ||
} | ||
/*!
 * \brief constructor that constructs a TBlob from a DLTensor
 * \param dltensor the source DLTensor; must be compact, i.e. its strides are
 *        either nullptr or describe a row-major contiguous layout
 */
explicit TBlob(const DLTensor &dltensor) : dptr_(dltensor.data), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to add compactness check There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Specifically, TBlob only support compact tensors, need to check strides == null or the strides reflect a compact setting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. I will move the strides check from ndarray.cpp to tensor_blob.h. |
||
shape_(TShape(dltensor.shape, dltensor.shape + dltensor.ndim)), | ||
type_flag_(DLDataTypeTransform(dltensor.dtype)), dltensor_(dltensor) { | ||
// Compactness check: TBlob only supports compact tensors, so the input
// must either have no explicit strides (implicitly compact) or strides
// that describe a row-major contiguous layout.
if (dltensor.strides != nullptr) { | ||
// Validate the explicit strides against the shape.
const int &ndim = dltensor.ndim; | ||
const int64_t *shape = dltensor.shape; | ||
const int64_t *strides = dltensor.strides; | ||
if (ndim >= 1) { | ||
bool err = false; | ||
// The innermost dimension of a compact tensor always has stride 1.
if (strides[ndim - 1] != 1) { | ||
err = true; | ||
} else { | ||
// Walking outward, each stride must equal the product of all inner
// extents: strides[i] == shape[i + 1] * strides[i + 1].
for (int i = ndim - 2; i >= 0; --i) { | ||
if (strides[i] != shape[i + 1] * strides[i + 1]) { | ||
err = true; | ||
break; | ||
} | ||
} | ||
} | ||
if (err) { | ||
LOG(FATAL) << "Unsupported DLPack because MXNet only support compact tensor now"; | ||
} | ||
} | ||
} | ||
} | ||
/*! | ||
* \brief constructor from tensor | ||
* \param src source tensor | ||
|
@@ -336,14 +367,51 @@ class TBlob { | |
} | ||
} | ||
} | ||
/*!
 * \brief convert a DLPack DLDataType into the corresponding mshadow type flag
 * \param dldata_type the DLPack data type triple {code, bits, lanes}
 * \return the matching mshadow type flag (e.g. mshadow::kFloat32)
 *
 * Only scalar types (lanes == 1) are supported; vectorized types and any
 * unrecognized {code, bits} combination are fatal errors.
 */
static int DLDataTypeTransform(DLDataType dldata_type) { | ||
// TBlob has no notion of vector lanes, so reject vectorized types.
if (dldata_type.lanes != 1) { | ||
LOG(FATAL) << "Unsupported DLDataType whose lanes != 1"; | ||
} | ||
switch (dldata_type.code) { | ||
case kDLFloat: | ||
switch (dldata_type.bits) { | ||
case 16: | ||
return mshadow::kFloat16; | ||
case 32: | ||
return mshadow::kFloat32; | ||
case 64: | ||
return mshadow::kFloat64; | ||
} | ||
break; | ||
case kDLUInt: | ||
switch (dldata_type.bits) { | ||
case 8: | ||
return mshadow::kUint8; | ||
} | ||
break; | ||
case kDLInt: | ||
switch (dldata_type.bits) { | ||
case 8: | ||
return mshadow::kInt8; | ||
case 32: | ||
return mshadow::kInt32; | ||
case 64: | ||
return mshadow::kInt64; | ||
} | ||
break; | ||
} | ||
// Fall-through for any unsupported {code, bits} combination.
LOG(FATAL) << "Unknown DLDataType{" << dldata_type.code | ||
<< ", " << dldata_type.bits | ||
<< ", " << dldata_type.lanes << "}"; | ||
// Unreachable after LOG(FATAL); the return value only satisfies the compiler.
return mshadow::kFloat32; | ||
} | ||
|
||
/*!
 * \brief populate the cached DLTensor view (dltensor_) from this TBlob's fields
 * \param dev_mask device type mask used to build the DLContext
 * \param dev_id device ordinal used to build the DLContext
 */
inline void SetDLTensor(int dev_mask, int dev_id) { | ||
dltensor_.data = dptr_; | ||
dltensor_.ctx = DLContext{static_cast<DLDeviceType>(dev_mask), dev_id}; | ||
dltensor_.ndim = shape_.ndim(); | ||
dltensor_.dtype = DTypeTransform(type_flag_); | ||
dltensor_.shape = shape_.data(); | ||
dltensor_.strides = NULL; | ||
// TBlob is always compact, so no explicit strides are needed.
dltensor_.strides = nullptr; | ||
dltensor_.byte_offset = 0; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,21 +31,24 @@ | |
|
||
class NDArrayBase(object): | ||
"""Base data structure for ndarray""" | ||
__slots__ = ["handle", "writable"] | ||
__slots__ = ["handle", "writable", "dlpack"] | ||
# pylint: disable= no-member | ||
|
||
def __init__(self, handle, writable=True): | ||
def __init__(self, handle, writable=True, dlpack=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dlpack should not be part of the member, the PyCapsule manages itself There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The dlpack in NDArray is PyCapsule which is the return value of
NDArray doesn't have the deleter function, so I made dlpack as a member of NDArray. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A better way is to keep NDArray's shared_ptr inside the manager_ctx itself, you can take a look at TVM's NDArray to DLManagedTesnor impl There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NDArray in MXNet and TVM are different. NDArray in TVM has the function Setting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can create a new NDArray() that copies the original NDArray(which increases refcount) and put that as a context There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In your case, when a get deleted, b still holds a NDArrayDLManager, which is allocated by new, and that object still hold NDArray(which holds a shared_ptr), so the original resource won't be released There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You need to be careful to use shape from the same NDArray in your NDArrayDLManager There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks. In the other case, from torch.utils import dlpack
a = torch.array([1,2,3])
pack = dlpack.to_dlpack(a)
b = mx.nd.from_dlpack(pack)
del a, pack When In my PR, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you copy the NDArray, they hold the same shared_ptr to the data, note that shared_ptr can be copied, and its ref counter is automatically managed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made a copy of NDArray as the member of NDArrayDLManager, and the copy increase the refcount. Which object will call the deleter function? In my case, when |
||
"""initialize a new NDArray | ||
|
||
Parameters | ||
---------- | ||
handle : NDArrayHandle | ||
NDArray handle of C API | ||
dlpack : PyCapsule (DLPack) | ||
DLPack Object | ||
""" | ||
if handle is not None: | ||
assert isinstance(handle, NDArrayHandle) | ||
self.handle = handle | ||
self.writable = writable | ||
self.dlpack = dlpack | ||
|
||
def __del__(self): | ||
    # Free the underlying C NDArray handle when this wrapper is collected.
check_call(_LIB.MXNDArrayFree(self.handle)) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am less certain why we need the deleter function here, can they be directly handled in the python/cython side?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to implement a deleter function in Python; however, the deleter may be released by the Python GC before it is called. See the test code — it will raise a segmentation fault.
The Python frontend of MXNet uses both Python (ctypes) and Cython. It may be impossible to implement the deleter function in ctypes.
So the deleter function should be implemented in C++.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Take a look at destructor at apache/tvm#1573
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is some subtlty here but they can never-the-less be implemented
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I once tried to write a Python function as the destructor, but it couldn't pass CI.
Please see the MXNet CI result.
All Windows test_dlpack runs failed because the destructor written in Python is released before it is called.
PyTorch implemented the destructor using the Python API in C++, and CuPy implemented it in Cython, i.e. the code is built as C++.
However, MXNet uses both ctypes and Cython. I couldn't find a better way to implement the destructor except writing it in the MXNet C++ API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I knew the trick and tried it in my previous PR. But it failed in Windows Test.
Related CI
It seems that the CI of TVM doesn't have Windows Test so the CI is passed.
The reason is that the destructor will be released by the Python GC before it is called.
And the GC release order differs between Linux and Windows.
On Linux, the destructor is called first and then released, so it works.
On Windows, however, the destructor is released before it is called, so it doesn't work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is strange, as the destructor itself sits in the global scope and should be destructed after the dltensors (which have a local scope).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
In the test code, it works in Linux but failed in Windows.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see two problems in your particular gist you paste.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
I found it works on Windows and Linux.
I have updated the PR.