Feature
Currently, the ONNX backend in wasmtime-wasi-nn always uses the default CPU execution provider and ignores the ExecutionTarget requested by the Wasm caller: the `target` argument is stored on the `ONNXGraph` but never used to configure the session.
wasmtime/crates/wasi-nn/src/backend/onnxruntime.rs
Lines 21 to 33 in 24c1388
```rust
fn load(&mut self, builders: &[&[u8]], target: ExecutionTarget) -> Result<Graph, BackendError> {
    if builders.len() != 1 {
        return Err(BackendError::InvalidNumberOfBuilders(1, builders.len()).into());
    }
    let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_model_from_memory(builders[0])?;
    let box_: Box<dyn BackendGraph> =
        Box::new(ONNXGraph(Arc::new(Mutex::new(session)), target));
    Ok(box_.into())
}
```
I would like to suggest adding support for additional execution providers (CUDA, TensorRT, ROCm, ...) to wasmtime-wasi-nn.
Benefit
Improved performance for WASM modules using the wasi-nn API.
Implementation
ort already supports many execution providers, so integrating them into wasmtime-wasi-nn should not be too much work.
I would be interested in looking into this; however, I only have the means to test the DirectML and NVIDIA CUDA/TensorRT EPs.
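As a rough illustration of the dispatch logic the `load` function could grow, here is a minimal sketch of mapping the requested ExecutionTarget to a provider choice. The `ExecutionTarget` enum below is a local stand-in for the wasi-nn type, and the string names are placeholders for illustration; the real implementation would build `ort` execution-provider registrations (guarded by the EPs the host build actually enables) rather than return strings.

```rust
// Minimal stand-in for wasi-nn's ExecutionTarget, for illustration only.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ExecutionTarget {
    Cpu,
    Gpu,
    Tpu,
}

/// Pick an execution provider for the requested target. In the real
/// backend this would construct ort execution-provider registrations
/// (CUDA, TensorRT, ROCm, DirectML, ...) instead of returning a name,
/// and would fall back to CPU when the requested EP is unavailable.
fn provider_for(target: ExecutionTarget) -> &'static str {
    match target {
        // Gpu could dispatch to CUDA, TensorRT, ROCm, or DirectML
        // depending on platform and enabled features (hypothetical choice).
        ExecutionTarget::Gpu => "CUDAExecutionProvider",
        // ONNX Runtime has no TPU provider, so fall back to CPU.
        ExecutionTarget::Tpu | ExecutionTarget::Cpu => "CPUExecutionProvider",
    }
}

fn main() {
    // The session builder would then register the chosen provider
    // before loading the model from memory.
    println!("{}", provider_for(ExecutionTarget::Gpu));
}
```

The key design question is the fallback policy: whether an unavailable requested EP should be a hard error or a silent fallback to CPU, as other wasi-nn backends do.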
Alternatives
Leave it to the users to add support for additional execution providers.