
Commit 6255b2f

Trevor Morris authored and icemelon committed
[Relay] Fix memory leak when accessing NDArray (apache#5413)
1 parent 55a173a commit 6255b2f

File tree: 4 files changed, +39 -56 lines

src/relay/backend/contrib/codegen_c/codegen.cc

Lines changed: 2 additions & 3 deletions
@@ -68,7 +68,6 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
       runtime::NDArray array = cn->data;
       const auto& shape = array.Shape();
-      const DLTensor& dl_tensor = array.ToDLPack()->dl_tensor;
 
       // Get the number of elements.
       int64_t num_elems = 1;
@@ -83,11 +82,11 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
       // to avoid possible stack overflow.
       buf_stream << dtype << " " << output.name << "[" << num_elems << "] = {";
       if (dtype == "float") {
-        float* p_flt = static_cast<float*>(dl_tensor.data);
+        float* p_flt = static_cast<float*>(array->data);
         for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
         if (num_elems) buf_stream << p_flt[num_elems - 1];
       } else if (dtype == "int") {
-        int* p_flt = static_cast<int*>(dl_tensor.data);
+        int* p_flt = static_cast<int*>(array->data);
         for (int64_t i = 0; i < num_elems - 1; i++) buf_stream << p_flt[i] << ", ";
         if (num_elems) buf_stream << p_flt[num_elems - 1];
       } else {
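The hunk above shows the root cause of the leak: NDArray::ToDLPack() heap-allocates a DLManagedTensor (holding a fresh reference to the array) whose deleter the caller is expected to invoke, and the codegen paths never did, so every embedded constant leaked one allocation. NDArray's operator-> instead borrows the DLTensor the container already owns, so nothing needs freeing. A minimal sketch of the two patterns; the function names are illustrative, not part of the commit:

    #include <tvm/runtime/ndarray.h>

    using tvm::runtime::NDArray;

    // Leaks: ToDLPack() allocates a DLManagedTensor that takes a new
    // reference to the array; unless managed->deleter(managed) runs,
    // both the wrapper and the reference are lost.
    const float* LeakyAccess(const NDArray& array) {
      DLManagedTensor* managed = array.ToDLPack();
      return static_cast<const float*>(managed->dl_tensor.data);
      // missing: managed->deleter(managed);
    }

    // Borrows: operator-> exposes the DLTensor the NDArray itself owns,
    // so no allocation happens and no cleanup is needed. The pointer
    // stays valid for as long as `array` is alive.
    const float* BorrowingAccess(const NDArray& array) {
      return static_cast<const float*>(array->data);
    }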

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ class CodegenDNNL : public MemoizedExprTranslator<std::vector<Output>>, public C
     CHECK_EQ(GetDtypeString(type_node), "float") << "Only float is supported for now.";
 
     std::ostringstream buf_stream;
-    const float* ptr = static_cast<float*>(array.ToDLPack()->dl_tensor.data);
+    const float* ptr = static_cast<float*>(array->data);
 
     // Allocate large arrays on the static section to avoid stack overflow.
     // Note that this would probably increase compilation time as the source
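ToDLPack() itself is not wrong; it exists for handing buffers across framework boundaries, and in that case the DLPack contract is that the consumer runs the deleter once it is done. A hedged sketch of correct use, where UseExternally stands in for a hypothetical external consumer:

    #include <tvm/runtime/ndarray.h>

    void UseExternally(DLTensor* tensor);  // hypothetical external consumer

    void PassToExternal(const tvm::runtime::NDArray& array) {
      DLManagedTensor* managed = array.ToDLPack();
      UseExternally(&managed->dl_tensor);
      // Per the DLPack contract the deleter may be null; when present it
      // must be called exactly once to drop the reference ToDLPack() took.
      if (managed->deleter != nullptr) {
        managed->deleter(managed);
      }
    }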

src/relay/backend/vm/compiler.cc

Lines changed: 21 additions & 39 deletions
@@ -193,35 +193,26 @@ TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause
   return else_branch;
 }
 
-std::vector<int64_t> ToAllocTensorShape64(NDArray shape) {
+std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  // TODO(@jroesch): we really need to standardize the bit width of
-  // all of the shape manipulating code.
-  CHECK_EQ(tensor.dtype.bits, 64) << "found " << tensor.dtype.bits;
-  int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(int_ptr[i]);
-  }
-  return raw_shape;
-}
-
-std::vector<int64_t> ToAllocTensorShape32(NDArray shape) {
-  std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  // TODO(@jroesch): we really need to standardize the bit width of
-  // all of the shape manipulating code.
-  CHECK_LE(tensor.dtype.bits, 32) << "found " << tensor.dtype.bits;
-  int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
-  for (auto i = 0; i < tensor.shape[0]; i++) {
-    raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << DLDataType2String(shape->dtype);
+
+  if (shape->dtype.bits == 64) {
+    int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(int_ptr[i]);
+    }
+  } else {  // int32
+    int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
+      raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
+    }
   }
   return raw_shape;
 }
@@ -546,17 +537,8 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
 
     if (const_shape) {
       NDArray shape = const_shape->data;
-      std::vector<int64_t> raw_shape;
-      DLTensor tensor = shape.ToDLPack()->dl_tensor;
-      // TODO(@jroesch): we need to get an RFC done to standardize this
-      if (tensor.dtype.bits == 64) {
-        raw_shape = ToAllocTensorShape64(shape);
-      } else if (tensor.dtype.bits == 32) {
-        raw_shape = ToAllocTensorShape32(shape);
-      } else {
-        LOG(FATAL) << "unsupported bitwidth: " << tensor.dtype.bits;
-      }
-
+      // TODO(@jroesch): we need to get an RFC done to standardize shape dtype
+      std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
       // Add context field.
       Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister()));
     } else {
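Folding ToAllocTensorShape64/32 into a single ToAllocTensorShape removes the DLTensor copy obtained via ToDLPack() and moves the bit-width dispatch behind the CHECKs. A usage sketch, assuming the runtime API of this era of TVM (NDArray::Empty and CopyFromBytes; the exact device/context argument type has shifted across releases):

    #include <tvm/runtime/ndarray.h>
    #include <vector>

    // Build a 1-D int64 shape tensor {2, 3, 4} on CPU and convert it.
    tvm::runtime::NDArray shape = tvm::runtime::NDArray::Empty(
        {3}, DLDataType{kDLInt, 64, 1}, {kDLCPU, 0});
    int64_t dims[3] = {2, 3, 4};
    shape.CopyFromBytes(dims, sizeof(dims));
    std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);  // {2, 3, 4}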

src/relay/op/memory/memory.cc

Lines changed: 15 additions & 13 deletions
@@ -23,6 +23,7 @@
  */
 
 #include <topi/elemwise.h>
+#include <tvm/runtime/data_type.h>
 #include <tvm/relay/attrs/memory.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
@@ -107,21 +108,22 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
 std::vector<int64_t> FromConstShape(Constant konst) {
   runtime::NDArray shape = konst->data;
   std::vector<int64_t> raw_shape;
-  DLTensor tensor = shape.ToDLPack()->dl_tensor;
-  CHECK_EQ(tensor.ndim, 1u);
-  CHECK_EQ(tensor.dtype.code, 0U) << "found " << tensor.dtype.code;
-
-  CHECK(tensor.dtype.bits == 64 || tensor.dtype.bits == 32)
-      << "found " << static_cast<int>(tensor.dtype.bits);
-
-  if (tensor.dtype.bits == 32) {
-    const int32_t* int_ptr = reinterpret_cast<int32_t*>(tensor.data);
-    for (auto i = 0; i < tensor.shape[0]; i++) {
+  CHECK_EQ(shape->ndim, 1u);
+  CHECK_EQ(shape->dtype.code, 0U)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << runtime::DLDataType2String(shape->dtype);
+  CHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
+      << "The dtype of constant shape must be int32 or int64, but got "
+      << runtime::DLDataType2String(shape->dtype);
+
+  if (shape->dtype.bits == 32) {
+    const int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
       raw_shape.push_back(int_ptr[i]);
     }
-  } else if (tensor.dtype.bits == 64) {
-    const int64_t* int_ptr = reinterpret_cast<int64_t*>(tensor.data);
-    for (auto i = 0; i < tensor.shape[0]; i++) {
+  } else if (shape->dtype.bits == 64) {
+    const int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
+    for (auto i = 0; i < shape->shape[0]; i++) {
       raw_shape.push_back(int_ptr[i]);
     }
   }
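For call sites that genuinely need the DLPack handle and cannot switch to operator->, an RAII guard is a common alternative to remembering the deleter manually. This is not part of the commit, just a sketch of that pattern:

    #include <memory>
    #include <tvm/runtime/ndarray.h>

    // Runs the DLPack deleter (when present) on scope exit, so the
    // temporary reference taken by ToDLPack() cannot leak.
    struct DLManagedTensorDeleter {
      void operator()(DLManagedTensor* t) const {
        if (t != nullptr && t->deleter != nullptr) t->deleter(t);
      }
    };
    using ManagedTensorPtr = std::unique_ptr<DLManagedTensor, DLManagedTensorDeleter>;

    const float* GuardedAccess(const tvm::runtime::NDArray& array) {
      ManagedTensorPtr guard(array.ToDLPack());
      // The returned pointer stays valid while `array` holds its own
      // reference to the data, even after the guard is destroyed.
      return static_cast<const float*>(guard->dl_tensor.data);
    }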
