ffi/include/tvm/ffi/container/tensor.h (17 additions, 5 deletions)

@@ -35,6 +35,16 @@
 namespace tvm {
 namespace ffi {
 
+/*!
+ * \brief Check if the device uses direct address, where the address of the data indicates alignment.
+ * \param device The input device.
+ * \return True if the device uses direct address, false otherwise.
+ */
+inline bool IsDirectAddressDevice(const DLDevice& device) {
+  return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged ||
+         device.device_type == kDLROCM || device.device_type == kDLROCMHost;
+}
+
 /*!
  * \brief check if a DLTensor is contiguous.
  * \param arr The input DLTensor.
@@ -67,11 +77,7 @@ inline bool IsContiguous(const DLTensor& arr) {
  * \return True if the data is aligned to the given alignment, false otherwise.
  */
 inline bool IsAligned(const DLTensor& arr, size_t alignment) {
-  // whether the device uses direct address mapping instead of indirect buffer
-  bool direct_address = arr.device.device_type <= kDLCUDAHost ||
-                        arr.device.device_type == kDLCUDAManaged ||
-                        arr.device.device_type == kDLROCM || arr.device.device_type == kDLROCMHost;
-  if (direct_address) {
+  if (IsDirectAddressDevice(arr.device)) {
     return (reinterpret_cast<size_t>(static_cast<char*>(arr.data) + arr.byte_offset) % alignment ==
             0);
   } else {
@@ -278,6 +284,12 @@ class Tensor : public ObjectRef {
    * \return True if the Tensor is contiguous, false otherwise.
    */
   bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); }
+  /*!
+   * \brief Check if the Tensor data is aligned to the given alignment.
+   * \param alignment The alignment to check.
+   * \return True if the Tensor data is aligned to the given alignment, false otherwise.
+   */
+  bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); }
   /*!
    * \brief Create a Tensor from a NDAllocator.
    * \param alloc The NDAllocator.
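The refactor above only moves the device classification into a reusable helper: IsDirectAddressDevice() is true for direct-address devices (CPU, CUDA, CUDA host/managed, ROCm), where alignment can be read straight off the data pointer plus byte_offset. A minimal Python sketch of the same pointer-arithmetic check on a CPU NumPy array; the helper name and the default alignment value here are illustrative, not part of the patch:

    import numpy as np

    def is_aligned_cpu(arr: np.ndarray, alignment: int = 64) -> bool:
        # CPU is a direct-address device: alignment is a property of the raw
        # data pointer, mirroring IsAligned() above (byte_offset is 0 for
        # NumPy-owned memory).
        return arr.ctypes.data % alignment == 0

    x = np.zeros((4, 4), dtype="int32")
    print(is_aligned_cpu(x, 8))  # typically True; NumPy allocations are at least 8-byte aligned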
ffi/python/tvm_ffi/_convert.py (1 addition, 3 deletions)

@@ -61,9 +61,7 @@ def convert(value: Any) -> Any:
     elif value is None:
         return None
     elif hasattr(value, "__dlpack__"):
-        return core.from_dlpack(
-            value, required_alignment=core.__dlpack_auto_import_required_alignment__
-        )
+        return core.from_dlpack(value)
     elif isinstance(value, Exception):
         return core._convert_to_ffi_error(value)
     else:
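With this change, convert() forwards any object exposing __dlpack__ straight to core.from_dlpack() with the new relaxed defaults, instead of enforcing the old implicit 8-byte alignment. A hedged sketch of the user-visible effect, assuming convert is re-exported at the tvm_ffi package level:

    import numpy as np
    import tvm_ffi

    x = np.arange(16, dtype="float32")
    t = tvm_ffi.convert(x)  # now equivalent to core.from_dlpack(x): no alignment check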
ffi/python/tvm_ffi/cython/function.pxi (2 additions, 3 deletions)

@@ -109,8 +109,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
             out[i].v_ptr = (<Object>arg).chandle
         elif torch is not None and isinstance(arg, torch.Tensor):
             is_cuda = arg.is_cuda
-            arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg),
-                              required_alignment=__dlpack_auto_import_required_alignment__)
+            arg = from_dlpack(torch.utils.dlpack.to_dlpack(arg))
             out[i].type_index = kTVMFFITensor
             out[i].v_ptr = (<Tensor>arg).chandle
             temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
@@ -123,7 +122,7 @@
             ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
             temp_args.append(arg)
         elif hasattr(arg, "__dlpack__"):
-            arg = from_dlpack(arg, required_alignment=__dlpack_auto_import_required_alignment__)
+            arg = from_dlpack(arg)
             out[i].type_index = kTVMFFITensor
             out[i].v_ptr = (<Tensor>arg).chandle
             temp_args.append(arg)
ffi/python/tvm_ffi/cython/tensor.pxi (21 additions, 22 deletions)

@@ -16,7 +16,6 @@
 # under the License.
 
 __dlpack_version__ = (1, 1)
-__dlpack_auto_import_required_alignment__ = 8
 _CLASS_TENSOR = None
 
 
@@ -45,13 +44,13 @@ cdef void _c_dlpack_versioned_deleter(object pycaps):
 
 
 cdef inline int _from_dlpack(
-    object dltensor, int required_alignment,
-    int required_contiguous, TVMFFIObjectHandle* out
+    object dltensor, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
 ) except -1:
     cdef DLManagedTensor* ptr
     cdef int c_api_ret_code
-    cdef int c_req_alignment = required_alignment
-    cdef int c_req_contiguous = required_contiguous
+    cdef int c_req_alignment = require_alignment
+    cdef int c_req_contiguous = require_contiguous
     if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
         ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
         with nogil:
@@ -66,13 +65,13 @@
 
 
 cdef inline int _from_dlpack_versioned(
-    object dltensor, int required_alignment,
-    int required_contiguous, TVMFFIObjectHandle* out
+    object dltensor, int require_alignment,
+    int require_contiguous, TVMFFIObjectHandle* out
 ) except -1:
     cdef DLManagedTensorVersioned* ptr
     cdef int c_api_ret_code
-    cdef int c_req_alignment = required_alignment
-    cdef int c_req_contiguous = required_contiguous
+    cdef int c_req_alignment = require_alignment
+    cdef int c_req_contiguous = require_contiguous
     if pycapsule.PyCapsule_IsValid(ext_tensor := dltensor, _c_str_dltensor_versioned):
         ptr = <DLManagedTensorVersioned*>pycapsule.PyCapsule_GetPointer(
             dltensor, _c_str_dltensor_versioned)
@@ -87,7 +86,7 @@
         raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once")
 
 
-def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
+def from_dlpack(ext_tensor, *, require_alignment=0, require_contiguous=False):
     """
     Convert an external tensor to a Tensor.
 
@@ -96,10 +95,10 @@ def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
     ext_tensor : object
         The external tensor to convert.
 
-    required_alignment : int
+    require_alignment : int
         The minimum required alignment to check for the tensor.
 
-    required_contiguous : bool
+    require_contiguous : bool
         Whether to check for contiguous memory.
 
     Returns
@@ -116,38 +115,38 @@ def from_dlpack(ext_tensor, *, required_alignment=8, required_contiguous=True):
         if favor_legacy_dlpack:
             _from_dlpack(
                 ext_tensor.__dlpack__(),
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
             )
         else:
             try:
                 _from_dlpack_versioned(
                     ext_tensor.__dlpack__(max_version=__dlpack_version__),
-                    required_alignment,
-                    required_contiguous,
+                    require_alignment,
+                    require_contiguous,
                     &chandle
                 )
             except TypeError:
                 _from_dlpack(
                     ext_tensor.__dlpack__(),
-                    required_alignment,
-                    required_contiguous,
+                    require_alignment,
+                    require_contiguous,
                     &chandle
                 )
     else:
         if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned):
             _from_dlpack_versioned(
                 ext_tensor,
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
             )
         elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor):
             _from_dlpack(
                 ext_tensor,
-                required_alignment,
-                required_contiguous,
+                require_alignment,
+                require_contiguous,
                 &chandle
            )
        else:
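This is the core API change: from_dlpack() renames required_alignment/required_contiguous to require_alignment/require_contiguous and relaxes the defaults from (8, True) to (0, False), so zero-copy import no longer rejects misaligned or strided tensors unless the caller opts in. A usage sketch; the exact exception raised on a failed check is an assumption based on the error paths above:

    import numpy as np
    import tvm_ffi

    x = np.zeros((4, 4), dtype="int32").T        # transposed view: non-contiguous strides

    t = tvm_ffi.from_dlpack(x)                   # accepted: no checks by default

    # Opt back in to stricter behavior per call:
    t = tvm_ffi.from_dlpack(x, require_alignment=4)   # alignment-only check
    try:
        t = tvm_ffi.from_dlpack(x, require_contiguous=True)
    except (ValueError, RuntimeError):           # exact exception type assumed
        print("non-contiguous input rejected")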
python/tvm/runtime/_tensor.py (6 additions, 4 deletions)

@@ -44,16 +44,18 @@ def from_dlpack(ext_tensor):
     ext_tensor : object
         The external tensor to convert.
 
-    required_alignment : int
+    require_alignment : int
         The minimum required alignment to check for the tensor.
 
-    required_contiguous : bool
+    require_contiguous : bool
         Whether to check for contiguous memory.
     """
+    # TODO(tvm-team): change to require_alignment=0 and require_contiguous=False
+    # once we update the compiler generated code to guard against misaligned access.
     return tvm_ffi.from_dlpack(
         ext_tensor,
-        required_alignment=64,
-        required_contiguous=True,
+        require_alignment=64,
+        require_contiguous=True,
     )
 
 
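Note the asymmetry: the low-level tvm_ffi.from_dlpack() defaults are now fully permissive, while this runtime wrapper pins require_alignment=64 and require_contiguous=True until compiler-generated code guards against misaligned access (see the TODO). A sketch contrasting the two entry points, assuming from_dlpack is re-exported at tvm.runtime; exception types are assumptions:

    import numpy as np
    import tvm.runtime
    import tvm_ffi

    x = np.zeros((4, 4), dtype="int32").T   # non-contiguous view

    tvm_ffi.from_dlpack(x)                  # fine under the relaxed ffi defaults
    try:
        tvm.runtime.from_dlpack(x)          # still enforces 64-byte alignment + contiguity
    except (ValueError, RuntimeError):
        print("runtime wrapper keeps strict checks for compiled code")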
src/tir/ir/stmt.cc (2 additions, 2 deletions)

@@ -603,8 +603,8 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
   // Check data_alignment
   CHECK(source_buffer->data_alignment % buffer->data_alignment == 0)
       << "Trying to match buffer to another one with lower alignment requirement "
-      << " required_alignment=" << buffer->data_alignment
-      << ", provided_alignment=" << source_buffer->data_alignment;
+      << " required alignment=" << buffer->data_alignment
+      << ", provided alignment=" << source_buffer->data_alignment;
 
   // Check BufferType. AutoBroadcast is not allowed for now.
   CHECK(buffer->buffer_type == BufferType::kDefault &&
src/tir/transforms/arg_binder.cc (2 additions, 2 deletions)

@@ -93,8 +93,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
       << "Argument " << arg_name << " Buffer bind data type mismatch";
   if (value->data_alignment % arg->data_alignment != 0) {
     LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
-                 << " required_alignment=" << arg->data_alignment
-                 << ", provided_alignment=" << value->data_alignment;
+                 << " required alignment=" << arg->data_alignment
+                 << ", provided alignment=" << value->data_alignment;
   }
 
   if (value->elem_offset.defined()) {
src/tir/transforms/lower_match_buffer.cc (2 additions, 2 deletions)

@@ -152,8 +152,8 @@ class MatchBufferLower : public StmtExprMutator {
     // Step.1.2. Check data alignment
     if (source_buffer->data_alignment % buffer->data_alignment != 0) {
       LOG(WARNING) << "Trying to bind buffer to another one with lower alignment requirement "
-                   << " required_alignment=" << buffer->data_alignment
-                   << ", provided_alignment=" << source_buffer->data_alignment;
+                   << " required alignment=" << buffer->data_alignment
+                   << ", provided alignment=" << source_buffer->data_alignment;
     }
     if (is_zero(buffer->elem_offset)) {
       ICHECK(is_zero(source_buffer->elem_offset))
tests/python/relax/test_op_inspect.py (1 addition, 1 deletion)

@@ -171,7 +171,7 @@ def main(A: R.Tensor, axis: R.Prim("int64")):
     expected_strides = [1, 4]
     # use transpose to make strides non-compact
     x = np.zeros([4, 4], "int32").T
-    y = tvm_ffi.from_dlpack(x, required_alignment=4, required_contiguous=False)
+    y = tvm_ffi.from_dlpack(x, require_alignment=4, require_contiguous=False)
     res = [vm["main"](y, i) for i, _ in enumerate(view_shape)]
     tvm.ir.assert_structural_equal(res, expected_strides)

Expand Down
Loading