Skip to content

Commit 8b1f792

Browse files
committed
[PYTHON][FFI] Cythonize NDArray.copyto
1 parent 5d328c5 commit 8b1f792

File tree

3 files changed

+16
-7
lines changed

3 files changed

+16
-7
lines changed

python/tvm/_ffi/_ctypes/ndarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ def __del__(self):
8585
def _tvm_handle(self):
8686
return ctypes.cast(self.handle, ctypes.c_void_p).value
8787

88+
def _copyto(self, target_nd):
89+
"""Internal function that implements copy to target ndarray."""
90+
check_call(_LIB.TVMArrayCopyFromTo(
91+
self.handle, target_nd.handle, None))
92+
return target_nd
93+
8894
def to_dlpack(self):
8995
"""Produce an array from a DLPack Tensor without copying memory
9096

python/tvm/_ffi/_cython/ndarray.pxi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ cdef class NDArrayBase:
7676
if self.c_is_view == 0:
7777
CALL(TVMArrayFree(self.chandle))
7878

79+
def _copyto(self, target_nd):
80+
"""Internal function that implements copy to target ndarray."""
81+
CALL(TVMArrayCopyFromTo(self.chandle, (<NDArrayBase>target_nd).chandle, NULL))
82+
return target_nd
83+
7984
def to_dlpack(self):
8085
"""Produce an array from a DLPack Tensor without copying memory
8186

python/tvm/_ffi/ndarray.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -294,14 +294,12 @@ def copyto(self, target):
294294
target : NDArray
295295
The target array to be copied, must have same shape as this array.
296296
"""
297-
if isinstance(target, TVMContext):
298-
target = empty(self.shape, self.dtype, target)
299297
if isinstance(target, NDArrayBase):
300-
check_call(_LIB.TVMArrayCopyFromTo(
301-
self.handle, target.handle, None))
302-
else:
303-
raise ValueError("Unsupported target type %s" % str(type(target)))
304-
return target
298+
return self._copyto(target)
299+
elif isinstance(target, TVMContext):
300+
res = empty(self.shape, self.dtype, target)
301+
return self._copyto(res)
302+
raise ValueError("Unsupported target type %s" % str(type(target)))
305303

306304

307305
def free_extension_handle(handle, type_code):

0 commit comments

Comments
 (0)