@@ -12,7 +12,8 @@ use crate::{
12
12
memory:: MemoryInfo ,
13
13
ortsys,
14
14
session:: { output:: SessionOutputs , NoSelectedOutputs , RunOptions , Session } ,
15
- value:: { DynValue , Value , ValueInner , ValueTypeMarker }
15
+ value:: { DynValue , Value , ValueInner , ValueTypeMarker } ,
16
+ SharedSessionInner
16
17
} ;
17
18
18
19
/// Enables binding of session inputs and/or outputs to pre-allocated memory.
@@ -86,21 +87,21 @@ use crate::{
86
87
/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition`
87
88
/// tensor is only copied to the device once instead of 20 times.
88
89
#[ derive( Debug ) ]
89
- pub struct IoBinding < ' s > {
90
+ pub struct IoBinding {
90
91
pub ( crate ) ptr : NonNull < ort_sys:: OrtIoBinding > ,
91
- session : & ' s Session ,
92
92
held_inputs : HashMap < String , Arc < ValueInner > > ,
93
93
output_names : Vec < String > ,
94
- output_values : HashMap < String , DynValue >
94
+ output_values : HashMap < String , DynValue > ,
95
+ session : Arc < SharedSessionInner >
95
96
}
96
97
97
- impl < ' s > IoBinding < ' s > {
98
- pub ( crate ) fn new ( session : & ' s Session ) -> Result < Self > {
98
+ impl IoBinding {
99
+ pub ( crate ) fn new ( session : & Session ) -> Result < Self > {
99
100
let mut ptr: * mut ort_sys:: OrtIoBinding = ptr:: null_mut ( ) ;
100
101
ortsys ! [ unsafe CreateIoBinding ( session. inner. session_ptr. as_ptr( ) , & mut ptr) ?; nonNull( ptr) ] ;
101
102
Ok ( Self {
102
103
ptr : unsafe { NonNull :: new_unchecked ( ptr) } ,
103
- session,
104
+ session : session . inner ( ) ,
104
105
held_inputs : HashMap :: new ( ) ,
105
106
output_names : Vec :: new ( ) ,
106
107
output_values : HashMap :: new ( )
@@ -177,24 +178,23 @@ impl<'s> IoBinding<'s> {
177
178
}
178
179
179
180
/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
180
- pub fn run_with_options ( & mut self , run_options : & RunOptions < NoSelectedOutputs > ) -> Result < SessionOutputs < ' _ , ' s > > {
181
+ pub fn run_with_options ( & mut self , run_options : & RunOptions < NoSelectedOutputs > ) -> Result < SessionOutputs < ' _ , ' _ > > {
181
182
self . run_inner ( Some ( run_options) )
182
183
}
183
184
184
- fn run_inner ( & mut self , run_options : Option < & RunOptions < NoSelectedOutputs > > ) -> Result < SessionOutputs < ' _ , ' s > > {
185
+ fn run_inner ( & mut self , run_options : Option < & RunOptions < NoSelectedOutputs > > ) -> Result < SessionOutputs < ' _ , ' _ > > {
185
186
let run_options_ptr = if let Some ( run_options) = run_options {
186
187
run_options. run_options_ptr . as_ptr ( )
187
188
} else {
188
189
std:: ptr:: null_mut ( )
189
190
} ;
190
- ortsys ! [ unsafe RunWithBinding ( self . session. inner . session_ptr. as_ptr( ) , run_options_ptr, self . ptr. as_ptr( ) ) ?] ;
191
+ ortsys ! [ unsafe RunWithBinding ( self . session. session_ptr. as_ptr( ) , run_options_ptr, self . ptr. as_ptr( ) ) ?] ;
191
192
192
193
let owned_ptrs: HashMap < * mut ort_sys:: OrtValue , & Arc < ValueInner > > = self . output_values . values ( ) . map ( |c| ( c. ptr ( ) , & c. inner ) ) . collect ( ) ;
193
194
let mut count = self . output_names . len ( ) as ort_sys:: size_t ;
194
195
if count > 0 {
195
196
let mut output_values_ptr: * mut * mut ort_sys:: OrtValue = ptr:: null_mut ( ) ;
196
- let allocator = self . session . allocator ( ) ;
197
- ortsys ! [ unsafe GetBoundOutputValues ( self . ptr. as_ptr( ) , allocator. ptr. as_ptr( ) , & mut output_values_ptr, & mut count) ?; nonNull( output_values_ptr) ] ;
197
+ ortsys ! [ unsafe GetBoundOutputValues ( self . ptr. as_ptr( ) , self . session. allocator. ptr. as_ptr( ) , & mut output_values_ptr, & mut count) ?; nonNull( output_values_ptr) ] ;
198
198
199
199
let output_values = unsafe { std:: slice:: from_raw_parts ( output_values_ptr, count as _ ) . to_vec ( ) }
200
200
. into_iter ( )
@@ -207,21 +207,23 @@ impl<'s> IoBinding<'s> {
207
207
} else {
208
208
DynValue :: from_ptr (
209
209
NonNull :: new ( v) . expect ( "OrtValue ptrs returned by GetBoundOutputValues should not be null" ) ,
210
- Some ( Arc :: clone ( & self . session . inner ) )
210
+ Some ( Arc :: clone ( & self . session ) )
211
211
)
212
212
}
213
213
} ) ;
214
214
215
215
// output values will be freed when the `Value`s in `SessionOutputs` drop
216
216
217
- Ok ( SessionOutputs :: new_backed ( self . output_names . iter ( ) . map ( String :: as_str) , output_values, allocator, output_values_ptr. cast ( ) ) )
217
+ Ok ( SessionOutputs :: new_backed ( self . output_names . iter ( ) . map ( String :: as_str) , output_values, & self . session . allocator , output_values_ptr. cast ( ) ) )
218
218
} else {
219
219
Ok ( SessionOutputs :: new_empty ( ) )
220
220
}
221
221
}
222
222
}
223
223
224
- impl < ' s > Drop for IoBinding < ' s > {
224
+ unsafe impl Send for IoBinding { }
225
+
226
+ impl Drop for IoBinding {
225
227
fn drop ( & mut self ) {
226
228
ortsys ! [ unsafe ReleaseIoBinding ( self . ptr. as_ptr( ) ) ] ;
227
229
}
@@ -295,7 +297,31 @@ mod tests {
295
297
}
296
298
297
299
#[ test]
298
- fn test_mnist_clears_bound ( ) -> Result < ( ) > {
300
+ fn test_send_iobinding ( ) -> Result < ( ) > {
301
+ let session = Session :: builder ( ) ?. commit_from_url ( "https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx" ) ?;
302
+
303
+ let array = get_image ( ) ;
304
+
305
+ let mut binding = session. create_binding ( ) ?;
306
+ let output = Array2 :: from_shape_simple_fn ( ( 1 , 10 ) , || 0.0_f32 ) ;
307
+ binding. bind_output ( & session. outputs [ 0 ] . name , Tensor :: from_array ( output) ?) ?;
308
+
309
+ let probabilities = std:: thread:: spawn ( move || {
310
+ binding. bind_input ( & session. inputs [ 0 ] . name , & Tensor :: from_array ( array) ?) ?;
311
+ let outputs = binding. run ( ) ?;
312
+ let probabilities = extract_probabilities ( & outputs[ 0 ] ) ?;
313
+ Ok :: < Vec < ( usize , f32 ) > , crate :: Error > ( probabilities)
314
+ } )
315
+ . join ( )
316
+ . expect ( "" ) ?;
317
+
318
+ assert_eq ! ( probabilities[ 0 ] . 0 , 5 ) ;
319
+
320
+ Ok ( ( ) )
321
+ }
322
+
323
+ #[ test]
324
+ fn test_mnist_clear_bounds ( ) -> Result < ( ) > {
299
325
let session = Session :: builder ( ) ?. commit_from_url ( "https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx" ) ?;
300
326
301
327
let array = get_image ( ) ;
0 commit comments