Skip to content

Commit

Permalink
[RUNTIME] Ensure NDArray.CopyTo(Device) always sync (#16716)
Browse files Browse the repository at this point in the history
This PR ensures that NDArray.CopyTo(Device) always syncs.
Prior to this PR, the behavior was uncertain because the underlying
DeviceAPI may or may not sync. This PR further clarifies in the
docs the contract (the low-level device API is always async)
as well as the sync/async nature of each NDArray API.
  • Loading branch information
tqchen authored Mar 14, 2024
1 parent af0c038 commit 071fb8a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
2 changes: 2 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class TVM_DLL DeviceAPI {
* \param from The source array.
* \param to The target array.
* \param stream Optional stream object.
* \note The copy may happen asynchronously if it involves a GPU context.
* Call StreamSync to ensure the copy completes from the host's point of view.
*/
virtual void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream);
/*!
Expand Down
12 changes: 2 additions & 10 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,9 @@ class NDArray : public ObjectRef {
* \param dev The target device.
* \param mem_scope The memory scope of the target array.
* \return The array under another device.
* \note The copy always triggers a TVMSynchronize.
*/
inline NDArray CopyTo(const Device& dev, Optional<String> mem_scope = NullOpt) const;
TVM_DLL NDArray CopyTo(const Device& dev, Optional<String> mem_scope = NullOpt) const;
/*!
* \brief Load NDArray from stream
* \param stream The input data stream
Expand Down Expand Up @@ -399,15 +400,6 @@ inline void NDArray::CopyTo(const NDArray& other) const {
CopyFromTo(&(get_mutable()->dl_tensor), &(other.get_mutable()->dl_tensor));
}

/*!
 * \brief Produce a copy of this array on another device.
 * \param dev The device passed to Empty() for the new array.
 * \param mem_scope The memory scope passed to Empty() for the new array.
 * \return A new array with this array's shape and dtype, filled via CopyTo(NDArray).
 * \note This version issues no explicit stream synchronization after the copy.
 */
inline NDArray NDArray::CopyTo(const Device& dev, Optional<String> mem_scope) const {
  // The array must be backed by data before it can be copied.
  ICHECK(data_ != nullptr);
  const DLTensor* src = operator->();
  // The shape is described by the half-open range [shape, shape + ndim).
  auto* dims_begin = src->shape;
  auto* dims_end = dims_begin + src->ndim;
  NDArray result = Empty(ShapeTuple(dims_begin, dims_end), src->dtype, dev, mem_scope);
  this->CopyTo(result);
  return result;
}

/*! \return The count reported by the underlying data_ handle's use_count(). */
inline int NDArray::use_count() const {
  const int count = data_.use_count();
  return count;
}

/*! \return A read-only pointer to the dl_tensor member of the container returned by get_mutable(). */
inline const DLTensor* NDArray::operator->() const {
  const DLTensor& tensor = get_mutable()->dl_tensor;
  return &tensor;
}
Expand Down
11 changes: 11 additions & 0 deletions src/runtime/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,17 @@ void NDArray::CopyFromBytes(const void* data, size_t nbytes) {
ArrayCopyFromBytes(&get_mutable()->dl_tensor, data, nbytes);
}

/*!
 * \brief Produce a copy of this array on another device and wait for the copy.
 * \param dev The device passed to Empty() for the new array.
 * \param mem_scope The memory scope passed to Empty() for the new array.
 * \return A new array with this array's shape and dtype, filled via CopyTo(NDArray).
 * \note StreamSync is invoked (with a null stream handle) before returning.
 */
NDArray NDArray::CopyTo(const Device& dev, Optional<String> mem_scope) const {
  // The array must be backed by data before it can be copied.
  ICHECK(data_ != nullptr);
  const DLTensor* src = operator->();
  // The shape is described by the half-open range [shape, shape + ndim).
  auto* dims_begin = src->shape;
  auto* dims_end = dims_begin + src->ndim;
  NDArray result = Empty(ShapeTuple(dims_begin, dims_end), src->dtype, dev, mem_scope);
  this->CopyTo(result);

  // Choose the device to synchronize on: the source device, unless the source
  // lives on the CPU, in which case the target device is used instead.
  Device sync_dev = dev;
  if (src->device.device_type != kDLCPU) {
    sync_dev = src->device;
  }
  DeviceAPI::Get(sync_dev)->StreamSync(sync_dev, nullptr);
  return result;
}

void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
size_t from_size = GetDataSize(*from);
size_t to_size = GetDataSize(*to);
Expand Down

0 comments on commit 071fb8a

Please sign in to comment.