Skip to content

Commit cf1be86

Browse files
committed
refactor: remove lifetime bound for IoBinding
And implement `Send` for it. yw audioxd
1 parent bd3c891 commit cf1be86

File tree

2 files changed

+43
-17
lines changed

2 files changed

+43
-17
lines changed

src/io_binding.rs

+42-16
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use crate::{
1212
memory::MemoryInfo,
1313
ortsys,
1414
session::{output::SessionOutputs, NoSelectedOutputs, RunOptions, Session},
15-
value::{DynValue, Value, ValueInner, ValueTypeMarker}
15+
value::{DynValue, Value, ValueInner, ValueTypeMarker},
16+
SharedSessionInner
1617
};
1718

1819
/// Enables binding of session inputs and/or outputs to pre-allocated memory.
@@ -86,21 +87,21 @@ use crate::{
8687
/// of `unet.run()`, and this copying can come with significant latency & overhead. With [`IoBinding`], the `condition`
8788
/// tensor is only copied to the device once instead of 20 times.
8889
#[derive(Debug)]
89-
pub struct IoBinding<'s> {
90+
pub struct IoBinding {
9091
pub(crate) ptr: NonNull<ort_sys::OrtIoBinding>,
91-
session: &'s Session,
9292
held_inputs: HashMap<String, Arc<ValueInner>>,
9393
output_names: Vec<String>,
94-
output_values: HashMap<String, DynValue>
94+
output_values: HashMap<String, DynValue>,
95+
session: Arc<SharedSessionInner>
9596
}
9697

97-
impl<'s> IoBinding<'s> {
98-
pub(crate) fn new(session: &'s Session) -> Result<Self> {
98+
impl IoBinding {
99+
pub(crate) fn new(session: &Session) -> Result<Self> {
99100
let mut ptr: *mut ort_sys::OrtIoBinding = ptr::null_mut();
100101
ortsys![unsafe CreateIoBinding(session.inner.session_ptr.as_ptr(), &mut ptr)?; nonNull(ptr)];
101102
Ok(Self {
102103
ptr: unsafe { NonNull::new_unchecked(ptr) },
103-
session,
104+
session: session.inner(),
104105
held_inputs: HashMap::new(),
105106
output_names: Vec::new(),
106107
output_values: HashMap::new()
@@ -177,24 +178,23 @@ impl<'s> IoBinding<'s> {
177178
}
178179

179180
/// Performs inference on the session using the bound inputs specified by [`IoBinding::bind_input`].
180-
pub fn run_with_options(&mut self, run_options: &RunOptions<NoSelectedOutputs>) -> Result<SessionOutputs<'_, 's>> {
181+
pub fn run_with_options(&mut self, run_options: &RunOptions<NoSelectedOutputs>) -> Result<SessionOutputs<'_, '_>> {
181182
self.run_inner(Some(run_options))
182183
}
183184

184-
fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, 's>> {
185+
fn run_inner(&mut self, run_options: Option<&RunOptions<NoSelectedOutputs>>) -> Result<SessionOutputs<'_, '_>> {
185186
let run_options_ptr = if let Some(run_options) = run_options {
186187
run_options.run_options_ptr.as_ptr()
187188
} else {
188189
std::ptr::null_mut()
189190
};
190-
ortsys![unsafe RunWithBinding(self.session.inner.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr())?];
191+
ortsys![unsafe RunWithBinding(self.session.session_ptr.as_ptr(), run_options_ptr, self.ptr.as_ptr())?];
191192

192193
let owned_ptrs: HashMap<*mut ort_sys::OrtValue, &Arc<ValueInner>> = self.output_values.values().map(|c| (c.ptr(), &c.inner)).collect();
193194
let mut count = self.output_names.len() as ort_sys::size_t;
194195
if count > 0 {
195196
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
196-
let allocator = self.session.allocator();
197-
ortsys![unsafe GetBoundOutputValues(self.ptr.as_ptr(), allocator.ptr.as_ptr(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];
197+
ortsys![unsafe GetBoundOutputValues(self.ptr.as_ptr(), self.session.allocator.ptr.as_ptr(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];
198198

199199
let output_values = unsafe { std::slice::from_raw_parts(output_values_ptr, count as _).to_vec() }
200200
.into_iter()
@@ -207,21 +207,23 @@ impl<'s> IoBinding<'s> {
207207
} else {
208208
DynValue::from_ptr(
209209
NonNull::new(v).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"),
210-
Some(Arc::clone(&self.session.inner))
210+
Some(Arc::clone(&self.session))
211211
)
212212
}
213213
});
214214

215215
// output values will be freed when the `Value`s in `SessionOutputs` drop
216216

217-
Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, allocator, output_values_ptr.cast()))
217+
Ok(SessionOutputs::new_backed(self.output_names.iter().map(String::as_str), output_values, &self.session.allocator, output_values_ptr.cast()))
218218
} else {
219219
Ok(SessionOutputs::new_empty())
220220
}
221221
}
222222
}
223223

224-
impl<'s> Drop for IoBinding<'s> {
224+
unsafe impl Send for IoBinding {}
225+
226+
impl Drop for IoBinding {
225227
fn drop(&mut self) {
226228
ortsys![unsafe ReleaseIoBinding(self.ptr.as_ptr())];
227229
}
@@ -295,7 +297,31 @@ mod tests {
295297
}
296298

297299
#[test]
298-
fn test_mnist_clears_bound() -> Result<()> {
300+
fn test_send_iobinding() -> Result<()> {
301+
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;
302+
303+
let array = get_image();
304+
305+
let mut binding = session.create_binding()?;
306+
let output = Array2::from_shape_simple_fn((1, 10), || 0.0_f32);
307+
binding.bind_output(&session.outputs[0].name, Tensor::from_array(output)?)?;
308+
309+
let probabilities = std::thread::spawn(move || {
310+
binding.bind_input(&session.inputs[0].name, &Tensor::from_array(array)?)?;
311+
let outputs = binding.run()?;
312+
let probabilities = extract_probabilities(&outputs[0])?;
313+
Ok::<Vec<(usize, f32)>, crate::Error>(probabilities)
314+
})
315+
.join()
316+
.expect("")?;
317+
318+
assert_eq!(probabilities[0].0, 5);
319+
320+
Ok(())
321+
}
322+
323+
#[test]
324+
fn test_mnist_clear_bounds() -> Result<()> {
299325
let session = Session::builder()?.commit_from_url("https://parcel.pyke.io/v2/cdn/assetdelivery/ortrsv2/ex_models/mnist.onnx")?;
300326

301327
let array = get_image();

src/session/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pub use self::{
4343
#[derive(Debug)]
4444
pub struct SharedSessionInner {
4545
pub(crate) session_ptr: NonNull<ort_sys::OrtSession>,
46-
allocator: Allocator,
46+
pub(crate) allocator: Allocator,
4747
/// Additional things we may need to hold onto for the duration of this session, like [`crate::OperatorDomain`]s and
4848
/// DLL handles for operator libraries.
4949
_extras: Vec<Box<dyn Any>>,

0 commit comments

Comments
 (0)