Skip to content

Commit 6b3e7a0

Browse files
committed
feat: threadpool access in operator kernels
1 parent 41ef65a commit 6b3e7a0

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/operator/kernel.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
ffi::{c_char, CString},
2+
ffi::{c_char, c_void, CString},
33
ops::{Deref, DerefMut},
44
ptr::{self, NonNull}
55
};
@@ -249,6 +249,15 @@ impl KernelContext {
249249
Ok(NonNull::new(resource_ptr))
250250
}
251251

252+
pub fn par_for<F>(&self, total: usize, max_num_batches: usize, f: F) -> Result<()>
253+
where
254+
F: Fn(usize) + Sync + Send
255+
{
256+
let executor = Box::new(f) as Box<dyn Fn(usize) + Sync + Send>;
257+
ortsys![unsafe KernelContext_ParallelFor(self.ptr.as_ptr(), Some(parallel_for_cb), total as _, max_num_batches as _, &executor as *const _ as *mut c_void)?];
258+
Ok(())
259+
}
260+
252261
// TODO: STATUS_ACCESS_VIOLATION inside `KernelContext_GetScratchBuffer`. gonna assume this one is just an internal ONNX
253262
// Runtime bug.
254263
//
@@ -280,3 +289,8 @@ impl KernelContext {
280289
Ok(NonNull::new(stream_ptr))
281290
}
282291
}
292+
293+
extern "C" fn parallel_for_cb(user_data: *mut c_void, iterator: ort_sys::size_t) {
294+
let executor = unsafe { &*user_data.cast::<Box<dyn Fn(usize) + Sync + Send>>() };
295+
executor(iterator as _)
296+
}

tools/api-coverage.ts

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const IGNORED_SYMBOLS = new Set<string>([
1414
'RegisterCustomOpsUsingFunction',
1515
'SessionOptionsAppendExecutionProvider_CUDA', // we use V2
1616
'SessionOptionsAppendExecutionProvider_TensorRT', // we use V2
17+
'GetValueType', // we get value types via GetTypeInfo -> GetOnnxTypeFromTypeInfo, which is equivalent
1718
'SetLanguageProjection', // someday we shall have `ORT_PROJECTION_RUST`, but alas, today is not that day...
1819

1920
// we use allocator APIs directly on the Allocator struct

0 commit comments

Comments
 (0)