Skip to content

Commit eabae30

Browse files
authored
[Rust] Fix memory leak #2 (apache#8725)
* Add C++ API for computing type key from type index * Try and isolate leak * Rewrite the bindings to fix the ArgValue lifetime issue There are still quite a few issues left to resolve in this patch, but I believe the runtime changes stablize memory consumption as long as the parameters are only set once. ByteArray also has some totally broken unsafe code which I am unsure of how it was introduced. * Finish handling tvm-rt issues due to ArgValue lifetime This patch further refactors the bindings to better handle the lifetime issues introduced by detecting the argument memory leak. * WIP memory leak * There is issue using TVMCb function which is breaking refcount * Fix fallout from the lifetime refactor * Another tweak * Follow up work from the memory leak, attempt to clean up ByteArray * Add some todos for future work * Fix doc string * Clean up the changes * Format
1 parent e883dcb commit eabae30

File tree

22 files changed

+389
-192
lines changed

22 files changed

+389
-192
lines changed

include/tvm/runtime/c_runtime_api.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,14 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex);
520520
*/
521521
TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex);
522522

523+
/*!
524+
* \brief Convert type index to type key.
525+
* \param tindex The type index.
526+
* \param out_type_key The output type key.
527+
* \return 0 when success, nonzero when failure happens
528+
*/
529+
TVM_DLL int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key);
530+
523531
/*!
524532
* \brief Increase the reference count of an object.
525533
*

rust/tvm-macros/src/object.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,27 +147,20 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream {
147147
}
148148
}
149149

150-
impl<'a> From<#ref_id> for #tvm_rt_crate::ArgValue<'a> {
151-
fn from(object_ref: #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
150+
impl<'a> From<&'a #ref_id> for #tvm_rt_crate::ArgValue<'a> {
151+
fn from(object_ref: &'a #ref_id) -> #tvm_rt_crate::ArgValue<'a> {
152152
use std::ffi::c_void;
153153
let object_ptr = &object_ref.0;
154154
match object_ptr {
155155
None => {
156156
#tvm_rt_crate::ArgValue::
157157
ObjectHandle(std::ptr::null::<c_void>() as *mut c_void)
158158
}
159-
Some(value) => value.clone().into()
159+
Some(value) => value.into()
160160
}
161161
}
162162
}
163163

164-
impl<'a> From<&#ref_id> for #tvm_rt_crate::ArgValue<'a> {
165-
fn from(object_ref: &#ref_id) -> #tvm_rt_crate::ArgValue<'a> {
166-
let oref: #ref_id = object_ref.clone();
167-
#tvm_rt_crate::ArgValue::<'a>::from(oref)
168-
}
169-
}
170-
171164
impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id {
172165
type Error = #error;
173166

rust/tvm-rt/src/array.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,22 @@ external! {
4545
fn array_size(array: ObjectRef) -> i64;
4646
}
4747

48-
impl<T: IsObjectRef> IsObjectRef for Array<T> {
48+
impl<T: IsObjectRef + 'static> IsObjectRef for Array<T> {
4949
type Object = Object;
5050
fn as_ptr(&self) -> Option<&ObjectPtr<Self::Object>> {
5151
self.object.as_ptr()
5252
}
53+
5354
fn into_ptr(self) -> Option<ObjectPtr<Self::Object>> {
5455
self.object.into_ptr()
5556
}
57+
5658
fn from_ptr(object_ptr: Option<ObjectPtr<Self::Object>>) -> Self {
5759
let object_ref = match object_ptr {
5860
Some(o) => o.into(),
5961
_ => panic!(),
6062
};
63+
6164
Array {
6265
object: object_ref,
6366
_data: PhantomData,
@@ -67,7 +70,7 @@ impl<T: IsObjectRef> IsObjectRef for Array<T> {
6770

6871
impl<T: IsObjectRef> Array<T> {
6972
pub fn from_vec(data: Vec<T>) -> Result<Array<T>> {
70-
let iter = data.into_iter().map(T::into_arg_value).collect();
73+
let iter = data.iter().map(T::into_arg_value).collect();
7174

7275
let func = Function::get("runtime.Array").expect(
7376
"runtime.Array function is not registered, this is most likely a build or linking error",
@@ -151,9 +154,9 @@ impl<T: IsObjectRef> FromIterator<T> for Array<T> {
151154
}
152155
}
153156

154-
impl<'a, T: IsObjectRef> From<Array<T>> for ArgValue<'a> {
155-
fn from(array: Array<T>) -> ArgValue<'a> {
156-
array.object.into()
157+
impl<'a, T: IsObjectRef> From<&'a Array<T>> for ArgValue<'a> {
158+
fn from(array: &'a Array<T>) -> ArgValue<'a> {
159+
(&array.object).into()
157160
}
158161
}
159162

rust/tvm-rt/src/function.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ use std::{
3535

3636
use crate::errors::Error;
3737

38-
pub use super::to_function::{ToFunction, Typed};
38+
pub use super::to_function::{RawArgs, ToFunction, Typed};
39+
use crate::object::AsArgValue;
3940
pub use tvm_sys::{ffi, ArgValue, RetValue};
4041

4142
pub type Result<T> = std::result::Result<T, Error>;
@@ -153,12 +154,12 @@ macro_rules! impl_to_fn {
153154
where
154155
Error: From<Err>,
155156
Out: TryFrom<RetValue, Error = Err>,
156-
$($t: Into<ArgValue<'static>>),*
157+
$($t: for<'a> AsArgValue<'a>),*
157158
{
158159
fn from(func: Function) -> Self {
159160
#[allow(non_snake_case)]
160161
Box::new(move |$($t : $t),*| {
161-
let args = vec![ $($t.into()),* ];
162+
let args = vec![ $((&$t).as_arg_value()),* ];
162163
Ok(func.invoke(args)?.try_into()?)
163164
})
164165
}
@@ -196,8 +197,8 @@ impl TryFrom<RetValue> for Function {
196197
}
197198
}
198199

199-
impl<'a> From<Function> for ArgValue<'a> {
200-
fn from(func: Function) -> ArgValue<'a> {
200+
impl<'a> From<&'a Function> for ArgValue<'a> {
201+
fn from(func: &'a Function) -> ArgValue<'a> {
201202
if func.handle().is_null() {
202203
ArgValue::Null
203204
} else {
@@ -291,12 +292,12 @@ where
291292
}
292293

293294
pub fn register_untyped<S: Into<String>>(
294-
f: fn(Vec<ArgValue<'static>>) -> Result<RetValue>,
295+
f: for<'a> fn(Vec<ArgValue<'a>>) -> Result<RetValue>,
295296
name: S,
296297
override_: bool,
297298
) -> Result<()> {
298-
// TODO(@jroesch): can we unify all the code.
299-
let func = f.to_function();
299+
//TODO(@jroesch): can we unify the untpyed and typed registration functions.
300+
let func = ToFunction::<RawArgs, RetValue>::to_function(f);
300301
let name = name.into();
301302
// Not sure about this code
302303
let handle = func.handle();

rust/tvm-rt/src/graph_rt.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ impl GraphRt {
5050

5151
let runtime_create_fn_ret = runtime_create_fn.invoke(vec![
5252
graph.into(),
53-
lib.into(),
53+
(&lib).into(),
5454
(&dev.device_type).into(),
5555
// NOTE you must pass the device id in as i32 because that's what TVM expects
5656
(dev.device_id as i32).into(),
5757
]);
58+
5859
let graph_executor_module: Module = runtime_create_fn_ret?.try_into()?;
5960
Ok(Self {
6061
module: graph_executor_module,
@@ -79,7 +80,7 @@ impl GraphRt {
7980
pub fn set_input(&mut self, name: &str, input: NDArray) -> Result<()> {
8081
let ref set_input_fn = self.module.get_function("set_input", false)?;
8182

82-
set_input_fn.invoke(vec![name.into(), input.into()])?;
83+
set_input_fn.invoke(vec![name.into(), (&input).into()])?;
8384
Ok(())
8485
}
8586

@@ -101,7 +102,7 @@ impl GraphRt {
101102
/// Extract the ith output from the graph executor and write the results into output.
102103
pub fn get_output_into(&mut self, i: i64, output: NDArray) -> Result<()> {
103104
let get_output_fn = self.module.get_function("get_output", false)?;
104-
get_output_fn.invoke(vec![i.into(), output.into()])?;
105+
get_output_fn.invoke(vec![i.into(), (&output).into()])?;
105106
Ok(())
106107
}
107108
}

rust/tvm-rt/src/lib.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,17 @@ mod tests {
130130
);
131131
}
132132

133-
#[test]
134-
fn bytearray() {
135-
let w = vec![1u8, 2, 3, 4, 5];
136-
let v = ByteArray::from(w.as_slice());
137-
let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
138-
assert_eq!(
139-
tvm.data(),
140-
w.iter().copied().collect::<Vec<u8>>().as_slice()
141-
);
142-
}
133+
// todo(@jroesch): #8800 Follow up with ByteArray RetValue ownership.
134+
// #[test]
135+
// fn bytearray() {
136+
// let w = vec![1u8, 2, 3, 4, 5];
137+
// let v = ByteArray::from(w.as_slice());
138+
// let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
139+
// assert_eq!(
140+
// tvm.data(),
141+
// w.iter().copied().collect::<Vec<u8>>().as_slice()
142+
// );
143+
// }
143144

144145
#[test]
145146
fn ty() {

rust/tvm-rt/src/map.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,18 @@ external! {
5858
fn map_items(map: ObjectRef) -> Array<ObjectRef>;
5959
}
6060

61-
impl<K, V> FromIterator<(K, V)> for Map<K, V>
61+
impl<'a, K: 'a, V: 'a> FromIterator<(&'a K, &'a V)> for Map<K, V>
6262
where
6363
K: IsObjectRef,
6464
V: IsObjectRef,
6565
{
66-
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
66+
fn from_iter<T: IntoIterator<Item = (&'a K, &'a V)>>(iter: T) -> Self {
6767
let iter = iter.into_iter();
6868
let (lower_bound, upper_bound) = iter.size_hint();
6969
let mut buffer: Vec<ArgValue> = Vec::with_capacity(upper_bound.unwrap_or(lower_bound) * 2);
7070
for (k, v) in iter {
71-
buffer.push(k.into());
72-
buffer.push(v.into())
71+
buffer.push(k.into_arg_value());
72+
buffer.push(v.into_arg_value());
7373
}
7474
Self::from_data(buffer).expect("failed to convert from data")
7575
}
@@ -202,13 +202,13 @@ where
202202
}
203203
}
204204

205-
impl<'a, K, V> From<Map<K, V>> for ArgValue<'a>
205+
impl<'a, K, V> From<&'a Map<K, V>> for ArgValue<'a>
206206
where
207207
K: IsObjectRef,
208208
V: IsObjectRef,
209209
{
210-
fn from(map: Map<K, V>) -> ArgValue<'a> {
211-
map.object.into()
210+
fn from(map: &'a Map<K, V>) -> ArgValue<'a> {
211+
(&map.object).into()
212212
}
213213
}
214214

@@ -268,7 +268,7 @@ mod test {
268268
let mut std_map: HashMap<TString, TString> = HashMap::new();
269269
std_map.insert("key1".into(), "value1".into());
270270
std_map.insert("key2".into(), "value2".into());
271-
let tvm_map = Map::from_iter(std_map.clone().into_iter());
271+
let tvm_map = Map::from_iter(std_map.iter());
272272
let back_map = tvm_map.into();
273273
assert_eq!(std_map, back_map);
274274
}

rust/tvm-rt/src/ndarray.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ impl NDArrayContainer {
101101
.cast::<NDArrayContainer>()
102102
}
103103
}
104+
105+
pub fn as_mut_ptr<'a>(object_ptr: &ObjectPtr<NDArrayContainer>) -> *mut NDArrayContainer
106+
where
107+
NDArrayContainer: 'a,
108+
{
109+
let base_offset = memoffset::offset_of!(NDArrayContainer, dl_tensor) as isize;
110+
unsafe {
111+
object_ptr
112+
.ptr
113+
.as_ptr()
114+
.cast::<u8>()
115+
.offset(base_offset)
116+
.cast::<NDArrayContainer>()
117+
}
118+
}
104119
}
105120

106121
fn cow_usize<'a>(slice: &[i64]) -> Cow<'a, [usize]> {

rust/tvm-rt/src/object/mod.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ mod object_ptr;
2929

3030
pub use object_ptr::{IsObject, Object, ObjectPtr, ObjectRef};
3131

32+
pub trait AsArgValue<'a> {
33+
fn as_arg_value(&'a self) -> ArgValue<'a>;
34+
}
35+
36+
impl<'a, T: 'static> AsArgValue<'a> for T
37+
where
38+
&'a T: Into<ArgValue<'a>>,
39+
{
40+
fn as_arg_value(&'a self) -> ArgValue<'a> {
41+
self.into()
42+
}
43+
}
44+
3245
// TODO we would prefer to blanket impl From/TryFrom ArgValue/RetValue, but we
3346
// can't because of coherence rules. Instead, we generate them in the macro, and
3447
// add what we can (including Into instead of From) as subtraits.
@@ -37,8 +50,8 @@ pub trait IsObjectRef:
3750
Sized
3851
+ Clone
3952
+ Into<RetValue>
53+
+ for<'a> AsArgValue<'a>
4054
+ TryFrom<RetValue, Error = Error>
41-
+ for<'a> Into<ArgValue<'a>>
4255
+ for<'a> TryFrom<ArgValue<'a>, Error = Error>
4356
+ std::fmt::Debug
4457
{
@@ -51,8 +64,8 @@ pub trait IsObjectRef:
5164
Self::from_ptr(None)
5265
}
5366

54-
fn into_arg_value<'a>(self) -> ArgValue<'a> {
55-
self.into()
67+
fn into_arg_value<'a>(&'a self) -> ArgValue<'a> {
68+
self.as_arg_value()
5669
}
5770

5871
fn from_arg_value<'a>(arg_value: ArgValue<'a>) -> Result<Self, Error> {

0 commit comments

Comments
 (0)