Skip to content

Commit 4dc673b

Browse files
committed
* GetDataSize is moved to DeviceAPI, and memory_manager now uses this interface.
1 parent b7bb9f4 commit 4dc673b

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

include/tvm/runtime/device_api.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,14 @@ class TVM_DLL DeviceAPI {
9595
*/
9696
virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0;
9797

98+
/*!
 * \brief Get the physical memory size required to hold the tensor data.
 * \param arr The tensor object.
 * \param mem_scope The memory scope, if any.
 * \return The memory size in bytes.
 */
104+
virtual size_t GetDataSize(const DLTensor& arr, Optional<String> mem_scope = NullOpt);
105+
98106
/*!
99107
* \brief Query the device for specified properties.
100108
*

src/runtime/c_runtime_api.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,20 @@ static size_t GetDataAlignment(const DLDataType dtype) {
152152
return align;
153153
}
154154

155+
size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional<String> mem_scope) {
156+
if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") {
157+
size_t size = 1;
158+
for (tvm_index_t i = 0; i < arr.ndim; ++i) {
159+
size *= static_cast<size_t>(arr.shape[i]);
160+
}
161+
size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
162+
return size;
163+
}
164+
LOG(FATAL) << "Device does not support physical mem computation with "
165+
<< "specified memory scope: " << mem_scope.value();
166+
return 0;
167+
}
168+
155169
void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
156170
Optional<String> mem_scope) {
157171
if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") {

src/runtime/memory/memory_manager.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ NDArray StorageObj::AllocNDArray(size_t offset, ShapeTuple shape, DLDataType dty
8585
container->dl_tensor.byte_offset = offset;
8686

8787
container->SetDeleter(StorageObj::Deleter);
88-
size_t needed_size = GetDataSize(container->dl_tensor);
88+
size_t needed_size = DeviceAPI::Get(this->buffer.device)->GetDataSize(container->dl_tensor);
8989
this->IncRef();
9090
// The manager context pointer must continue to point to the storage object
9191
// which owns the backing memory, and keeps track of the reference count.
@@ -159,7 +159,7 @@ NDArray Allocator::Empty(ShapeTuple shape, DLDataType dtype, DLDevice dev,
159159
VerifyDataType(dtype);
160160
NDArray::Container* container = new NDArray::Container(nullptr, shape, dtype, dev);
161161
container->SetDeleter(BufferDeleter);
162-
size_t size = GetDataSize(container->dl_tensor);
162+
size_t size = DeviceAPI::Get(dev)->GetDataSize(container->dl_tensor);
163163
size_t alignment = GetDataAlignment(container->dl_tensor);
164164
Buffer* buffer = new Buffer;
165165
if (!mem_scope.defined() || mem_scope == "global") {
@@ -177,7 +177,7 @@ Buffer Allocator::Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
177177
if (mem_scope.empty() || mem_scope == "global") {
178178
// by default, we can always redirect to the flat memory allocations
179179
NDArray::Container container(nullptr, shape, type_hint, dev);
180-
size_t size = GetDataSize(container.dl_tensor);
180+
size_t size = DeviceAPI::Get(dev)->GetDataSize(container.dl_tensor);
181181
size_t alignment = GetDataAlignment(container.dl_tensor);
182182
return Alloc(size, alignment, type_hint);
183183
}

src/runtime/memory/naive_allocator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ class NaiveAllocator final : public Allocator {
5151
Buffer buf;
5252
size_t nbytes = 1;
5353
buf.shape = shape;
54-
for (int i = 0; i < shape.size(); ++i) {
54+
for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
5555
nbytes *= static_cast<size_t>(shape[i]);
5656
}
5757
nbytes *= (type_hint.bits * type_hint.lanes + 7) / 8;

0 commit comments

Comments
 (0)