Skip to content

Commit bfa791d

Browse files
committed
feat: session builder optimization options
1 parent e31720d commit bfa791d

File tree

5 files changed

+419
-286
lines changed

5 files changed

+419
-286
lines changed

src/error.rs

+5
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ pub enum Error {
3838
/// Error occurred when creating ONNX session options.
3939
#[error("Failed to create ONNX Runtime session options: {0}")]
4040
CreateSessionOptions(ErrorInternal),
41+
/// Failed to enable `onnxruntime-extensions` for session.
42+
#[error("Failed to enable `onnxruntime-extensions`: {0}")]
43+
EnableExtensions(ErrorInternal),
44+
#[error("Failed to add configuration entry to session builder: {0}")]
45+
AddSessionConfigEntry(ErrorInternal),
4146
/// Error occurred when creating an allocator from a [`crate::MemoryInfo`] struct while building a session.
4247
#[error("Failed to create allocator from memory info: {0}")]
4348
CreateAllocator(ErrorInternal),

src/session/builder/impl_commit.rs

+203
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
#[cfg(feature = "fetch-models")]
2+
use std::fmt::Write;
3+
use std::{any::Any, marker::PhantomData, path::Path, ptr::NonNull, sync::Arc};
4+
5+
use super::SessionBuilder;
6+
#[cfg(feature = "fetch-models")]
7+
use crate::error::FetchModelError;
8+
use crate::{
9+
environment::get_environment,
10+
error::{Error, Result},
11+
execution_providers::apply_execution_providers,
12+
memory::Allocator,
13+
ortsys,
14+
session::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner}
15+
};
16+
17+
impl SessionBuilder {
18+
/// Downloads a pre-trained ONNX model from the given URL and builds the session.
19+
#[cfg(feature = "fetch-models")]
20+
#[cfg_attr(docsrs, doc(cfg(feature = "fetch-models")))]
21+
pub fn commit_from_url(self, model_url: impl AsRef<str>) -> Result<Session> {
22+
let mut download_dir = ort_sys::internal::dirs::cache_dir()
23+
.expect("could not determine cache directory")
24+
.join("models");
25+
if std::fs::create_dir_all(&download_dir).is_err() {
26+
download_dir = std::env::current_dir().expect("Failed to obtain current working directory");
27+
}
28+
29+
let url = model_url.as_ref();
30+
let model_filename = <sha2::Sha256 as sha2::Digest>::digest(url).into_iter().fold(String::new(), |mut s, b| {
31+
let _ = write!(&mut s, "{:02x}", b);
32+
s
33+
});
34+
let model_filepath = download_dir.join(model_filename);
35+
let downloaded_path = if model_filepath.exists() {
36+
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
37+
model_filepath
38+
} else {
39+
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
40+
41+
let resp = ureq::get(url).call().map_err(Box::new).map_err(FetchModelError::FetchError)?;
42+
43+
let len = resp
44+
.header("Content-Length")
45+
.and_then(|s| s.parse::<usize>().ok())
46+
.expect("Missing Content-Length header");
47+
tracing::info!(len, "Downloading {} bytes", len);
48+
49+
let mut reader = resp.into_reader();
50+
51+
let f = std::fs::File::create(&model_filepath).expect("Failed to create model file");
52+
let mut writer = std::io::BufWriter::new(f);
53+
54+
let bytes_io_count = std::io::copy(&mut reader, &mut writer).map_err(FetchModelError::IoError)?;
55+
if bytes_io_count == len as u64 {
56+
model_filepath
57+
} else {
58+
return Err(FetchModelError::CopyError {
59+
expected: len as u64,
60+
io: bytes_io_count
61+
}
62+
.into());
63+
}
64+
};
65+
66+
self.commit_from_file(downloaded_path)
67+
}
68+
69+
/// Loads an ONNX model from a file and builds the session.
70+
pub fn commit_from_file<P>(mut self, model_filepath_ref: P) -> Result<Session>
71+
where
72+
P: AsRef<Path>
73+
{
74+
let model_filepath = model_filepath_ref.as_ref();
75+
if !model_filepath.exists() {
76+
return Err(Error::FileDoesNotExist {
77+
filename: model_filepath.to_path_buf()
78+
});
79+
}
80+
81+
let model_path = crate::util::path_to_os_char(model_filepath);
82+
83+
let env = get_environment()?;
84+
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;
85+
86+
if env.has_global_threadpool {
87+
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
88+
}
89+
90+
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
91+
ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
92+
93+
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };
94+
95+
let allocator = match &self.memory_info {
96+
Some(info) => {
97+
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
98+
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
99+
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
100+
}
101+
None => Allocator::default()
102+
};
103+
104+
// Extract input and output properties
105+
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
106+
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
107+
let inputs = (0..num_input_nodes)
108+
.map(|i| dangerous::extract_input(session_ptr, &allocator, i))
109+
.collect::<Result<Vec<Input>>>()?;
110+
let outputs = (0..num_output_nodes)
111+
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
112+
.collect::<Result<Vec<Output>>>()?;
113+
114+
let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>);
115+
#[cfg(feature = "operator-libraries")]
116+
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>));
117+
let extras: Vec<Box<dyn Any>> = extras.collect();
118+
119+
Ok(Session {
120+
inner: Arc::new(SharedSessionInner {
121+
session_ptr,
122+
allocator,
123+
_extras: extras,
124+
_environment: env
125+
}),
126+
inputs,
127+
outputs
128+
})
129+
}
130+
131+
/// Load an ONNX graph from memory and commit the session
132+
/// For `.ort` models, we enable `session.use_ort_model_bytes_directly`.
133+
/// For more information, check [Load ORT format model from an in-memory byte array](https://onnxruntime.ai/docs/performance/model-optimizations/ort-format-models.html#load-ort-format-model-from-an-in-memory-byte-array).
134+
///
135+
/// If you wish to store the model bytes and the [`InMemorySession`] in the same struct, look for crates that
136+
/// facilitate creating self-referential structs, such as [`ouroboros`](https://github.com/joshua-maros/ouroboros).
137+
pub fn commit_from_memory_directly(mut self, model_bytes: &[u8]) -> Result<InMemorySession<'_>> {
138+
// Enable zero-copy deserialization for models in `.ort` format.
139+
self.add_config_entry("session.use_ort_model_bytes_directly", "1")?;
140+
self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?;
141+
142+
let session = self.commit_from_memory(model_bytes)?;
143+
144+
Ok(InMemorySession { session, phantom: PhantomData })
145+
}
146+
147+
/// Load an ONNX graph from memory and commit the session.
148+
pub fn commit_from_memory(mut self, model_bytes: &[u8]) -> Result<Session> {
149+
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
150+
151+
let env = get_environment()?;
152+
apply_execution_providers(&self, env.execution_providers.iter().cloned())?;
153+
154+
if env.has_global_threadpool {
155+
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
156+
}
157+
158+
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
159+
let model_data_length = model_bytes.len();
160+
ortsys![
161+
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession;
162+
nonNull(session_ptr)
163+
];
164+
165+
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };
166+
167+
let allocator = match &self.memory_info {
168+
Some(info) => {
169+
let mut allocator_ptr: *mut ort_sys::OrtAllocator = std::ptr::null_mut();
170+
ortsys![unsafe CreateAllocator(session_ptr.as_ptr(), info.ptr.as_ptr(), &mut allocator_ptr) -> Error::CreateAllocator; nonNull(allocator_ptr)];
171+
unsafe { Allocator::from_raw_unchecked(allocator_ptr) }
172+
}
173+
None => Allocator::default()
174+
};
175+
176+
// Extract input and output properties
177+
let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
178+
let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
179+
let inputs = (0..num_input_nodes)
180+
.map(|i| dangerous::extract_input(session_ptr, &allocator, i))
181+
.collect::<Result<Vec<Input>>>()?;
182+
let outputs = (0..num_output_nodes)
183+
.map(|i| dangerous::extract_output(session_ptr, &allocator, i))
184+
.collect::<Result<Vec<Output>>>()?;
185+
186+
let extras = self.operator_domains.drain(..).map(|d| Box::new(d) as Box<dyn Any>);
187+
#[cfg(feature = "operator-libraries")]
188+
let extras = extras.chain(self.custom_runtime_handles.drain(..).map(|d| Box::new(d) as Box<dyn Any>));
189+
let extras: Vec<Box<dyn Any>> = extras.collect();
190+
191+
let session = Session {
192+
inner: Arc::new(SharedSessionInner {
193+
session_ptr,
194+
allocator,
195+
_extras: extras,
196+
_environment: env
197+
}),
198+
inputs,
199+
outputs
200+
};
201+
Ok(session)
202+
}
203+
}
+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use super::SessionBuilder;
2+
use crate::Result;
3+
4+
// https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
5+
6+
impl SessionBuilder {
7+
/// Enable/disable the usage of prepacking.
8+
///
9+
/// This option is **enabled** by default.
10+
pub fn with_prepacking(mut self, enable: bool) -> Result<Self> {
11+
self.add_config_entry("session.disable_prepacking", if enable { "0" } else { "1" })?;
12+
Ok(self)
13+
}
14+
15+
/// Use allocators from the registered environment.
16+
///
17+
/// This option is **disabled** by default.
18+
pub fn with_env_allocators(mut self) -> Result<Self> {
19+
self.add_config_entry("session.use_env_allocators", "1")?;
20+
Ok(self)
21+
}
22+
23+
/// Enable flush-to-zero and denormal-as-zero.
24+
///
25+
/// This option is **disabled** by default, as it may hurt model accuracy.
26+
pub fn with_denormal_as_zero(mut self) -> Result<Self> {
27+
self.add_config_entry("session.set_denormal_as_zero", "1")?;
28+
Ok(self)
29+
}
30+
31+
/// Enable/disable fusion for quantized models in QDQ (QuantizeLinear/DequantizeLinear) format.
32+
///
33+
/// This option is **enabled** by default for all EPs except DirectML.
34+
pub fn with_quant_qdq(mut self, enable: bool) -> Result<Self> {
35+
self.add_config_entry("session.disable_quant_qdq", if enable { "0" } else { "1" })?;
36+
Ok(self)
37+
}
38+
39+
/// Enable/disable the optimization step removing double QDQ nodes.
40+
///
41+
/// This option is **enabled** by default.
42+
pub fn with_double_qdq_remover(mut self, enable: bool) -> Result<Self> {
43+
self.add_config_entry("session.disable_double_qdq_remover", if enable { "0" } else { "1" })?;
44+
Ok(self)
45+
}
46+
47+
/// Enable the removal of Q/DQ node pairs once all QDQ handling has been completed.
48+
///
49+
/// This option is **disabled** by default.
50+
pub fn with_qdq_cleanup(mut self) -> Result<Self> {
51+
self.add_config_entry("session.enable_quant_qdq_cleanup", "1")?;
52+
Ok(self)
53+
}
54+
55+
/// Enable fast GELU approximation.
56+
///
57+
/// This option is **disabled** by default, as it may hurt accuracy.
58+
pub fn with_approximate_gelu(mut self) -> Result<Self> {
59+
self.add_config_entry("optimization.enable_gelu_approximation", "1")?;
60+
Ok(self)
61+
}
62+
63+
/// Enable/disable ahead-of-time function inlining.
64+
///
65+
/// This option is **enabled** by default.
66+
pub fn with_aot_inlining(mut self, enable: bool) -> Result<Self> {
67+
self.add_config_entry("session.disable_aot_function_inlining", if enable { "0" } else { "1" })?;
68+
Ok(self)
69+
}
70+
71+
/// Accepts a comma-separated list of optimizers to disable.
72+
pub fn with_disabled_optimizers(mut self, optimizers: &str) -> Result<Self> {
73+
self.add_config_entry("optimization.disable_specified_optimizers", optimizers)?;
74+
Ok(self)
75+
}
76+
77+
/// Enable using device allocator for allocating initialized tensor memory.
78+
///
79+
/// This option is **disabled** by default.
80+
pub fn with_device_allocator_for_initializers(mut self) -> Result<Self> {
81+
self.add_config_entry("session.use_device_allocator_for_initializers", "1")?;
82+
Ok(self)
83+
}
84+
85+
/// Enable/disable allowing the inter-op threads to spin for a short period before blocking.
86+
///
87+
/// This option is **enabled** by defualt.
88+
pub fn with_inter_op_spinning(mut self, enable: bool) -> Result<Self> {
89+
self.add_config_entry("session.inter_op.allow_spinning", if enable { "1" } else { "0" })?;
90+
Ok(self)
91+
}
92+
93+
/// Enable/disable allowing the intra-op threads to spin for a short period before blocking.
94+
///
95+
/// This option is **enabled** by defualt.
96+
pub fn with_intra_op_spinning(mut self, enable: bool) -> Result<Self> {
97+
self.add_config_entry("session.intra_op.allow_spinning", if enable { "1" } else { "0" })?;
98+
Ok(self)
99+
}
100+
}

0 commit comments

Comments
 (0)