Skip to content

Commit edcb219

Browse files
committed
fix: thread-safe environment initialization
1 parent 9f4527c commit edcb219

File tree

2 files changed

+37
-51
lines changed

2 files changed

+37
-51
lines changed

src/environment.rs

+32-42
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
use std::{
2-
cell::UnsafeCell,
32
ffi::{self, CStr, CString},
4-
ptr,
5-
sync::{
6-
atomic::{AtomicPtr, Ordering},
7-
Arc
8-
}
3+
ptr::{self, NonNull},
4+
sync::{Arc, RwLock}
95
};
106

117
use ort_sys::c_char;
@@ -20,12 +16,12 @@ use crate::{
2016
};
2117

2218
struct EnvironmentSingleton {
23-
cell: UnsafeCell<Option<Arc<Environment>>>
19+
lock: RwLock<Option<Arc<Environment>>>
2420
}
2521

2622
unsafe impl Sync for EnvironmentSingleton {}
2723

28-
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };
24+
static G_ENV: EnvironmentSingleton = EnvironmentSingleton { lock: RwLock::new(None) };
2925

3026
/// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created.
3127
///
@@ -41,14 +37,14 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::ne
4137
#[derive(Debug)]
4238
pub struct Environment {
4339
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
44-
pub(crate) env_ptr: AtomicPtr<ort_sys::OrtEnv>,
40+
pub(crate) env_ptr: NonNull<ort_sys::OrtEnv>,
4541
pub(crate) has_global_threadpool: bool
4642
}
4743

4844
impl Environment {
4945
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
5046
pub fn ptr(&self) -> *mut ort_sys::OrtEnv {
51-
self.env_ptr.load(Ordering::Relaxed)
47+
self.env_ptr.as_ptr()
5248
}
5349
}
5450

@@ -57,23 +53,22 @@ impl Drop for Environment {
5753
fn drop(&mut self) {
5854
debug!("Releasing environment");
5955

60-
let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut();
61-
62-
assert_ne!(env_ptr, std::ptr::null_mut());
63-
ortsys![unsafe ReleaseEnv(env_ptr)];
56+
ortsys![unsafe ReleaseEnv(self.env_ptr.as_ptr())];
6457
}
6558
}
6659

6760
/// Gets a reference to the global environment, creating one if an environment has not been
6861
/// [`commit`](EnvironmentBuilder::commit)ted yet.
69-
pub fn get_environment() -> Result<&'static Arc<Environment>> {
70-
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
71-
Ok(c)
62+
pub fn get_environment() -> Result<Arc<Environment>> {
63+
let env = G_ENV.lock.read().expect("poisoned lock");
64+
if let Some(env) = env.as_ref() {
65+
Ok(Arc::clone(env))
7266
} else {
73-
debug!("Environment not yet initialized, creating a new one");
74-
EnvironmentBuilder::new().commit()?;
67+
// drop our read lock so we dont deadlock when `commit` takes a write lock
68+
drop(env);
7569

76-
Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() })
70+
debug!("Environment not yet initialized, creating a new one");
71+
Ok(EnvironmentBuilder::new().commit()?)
7772
}
7873
}
7974

@@ -151,12 +146,7 @@ impl EnvironmentBuilder {
151146
}
152147

153148
/// Commit the environment configuration and set the global environment.
154-
pub fn commit(self) -> Result<()> {
155-
// drop global reference to previous environment
156-
if let Some(env_arc) = unsafe { (*G_ENV.cell.get()).take() } {
157-
drop(env_arc);
158-
}
159-
149+
pub fn commit(self) -> Result<Arc<Environment>> {
160150
let (env_ptr, has_global_threadpool) = if let Some(global_thread_pool) = self.global_thread_pool_options {
161151
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
162152
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
@@ -218,15 +208,20 @@ impl EnvironmentBuilder {
218208
ortsys![unsafe DisableTelemetryEvents(env_ptr) -> Error::CreateEnvironment];
219209
}
220210

221-
unsafe {
222-
*G_ENV.cell.get() = Some(Arc::new(Environment {
223-
execution_providers: self.execution_providers,
224-
env_ptr: AtomicPtr::new(env_ptr),
225-
has_global_threadpool
226-
}));
227-
};
228-
229-
Ok(())
211+
let mut env_lock = G_ENV.lock.write().expect("poisoned lock");
212+
// drop global reference to previous environment
213+
if let Some(env_arc) = env_lock.take() {
214+
drop(env_arc);
215+
}
216+
let env = Arc::new(Environment {
217+
execution_providers: self.execution_providers,
218+
// we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
219+
env_ptr: unsafe { NonNull::new_unchecked(env_ptr) },
220+
has_global_threadpool
221+
});
222+
env_lock.replace(Arc::clone(&env));
223+
224+
Ok(env)
230225
}
231226
}
232227

@@ -316,16 +311,11 @@ mod tests {
316311
use super::*;
317312

318313
fn is_env_initialized() -> bool {
319-
unsafe { (*G_ENV.cell.get()).as_ref() }.is_some()
320-
&& !unsafe { (*G_ENV.cell.get()).as_ref() }
321-
.unwrap_or_else(|| unreachable!())
322-
.env_ptr
323-
.load(Ordering::Relaxed)
324-
.is_null()
314+
G_ENV.lock.read().expect("poisoned lock").is_some()
325315
}
326316

327317
fn env_ptr() -> Option<*mut ort_sys::OrtEnv> {
328-
unsafe { (*G_ENV.cell.get()).as_ref() }.map(|f| f.env_ptr.load(Ordering::Relaxed))
318+
(*G_ENV.lock.read().expect("poisoned lock")).as_ref().map(|f| f.env_ptr.as_ptr())
329319
}
330320

331321
struct ConcurrentTestRun {

src/session/builder.rs

+5-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use std::{
88
path::Path,
99
ptr::{self, NonNull},
1010
rc::Rc,
11-
sync::{atomic::Ordering, Arc}
11+
sync::Arc
1212
};
1313

1414
use super::{dangerous, InMemorySession, Input, Output, Session, SharedSessionInner};
@@ -312,10 +312,8 @@ impl SessionBuilder {
312312
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
313313
}
314314

315-
let env_ptr = env.env_ptr.load(Ordering::Relaxed);
316-
317315
let mut session_ptr: *mut ort_sys::OrtSession = std::ptr::null_mut();
318-
ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
316+
ortsys![unsafe CreateSession(env.env_ptr.as_ptr(), model_path.as_ptr(), self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession; nonNull(session_ptr)];
319317

320318
let session_ptr = unsafe { NonNull::new_unchecked(session_ptr) };
321319

@@ -348,7 +346,7 @@ impl SessionBuilder {
348346
session_ptr,
349347
allocator,
350348
_extras: extras,
351-
_environment: Arc::clone(env)
349+
_environment: env
352350
}),
353351
inputs,
354352
outputs
@@ -389,12 +387,10 @@ impl SessionBuilder {
389387
ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr.as_ptr()) -> Error::CreateSessionOptions];
390388
}
391389

392-
let env_ptr = env.env_ptr.load(Ordering::Relaxed);
393-
394390
let model_data = model_bytes.as_ptr().cast::<std::ffi::c_void>();
395391
let model_data_length = model_bytes.len();
396392
ortsys![
397-
unsafe CreateSessionFromArray(env_ptr, model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession;
393+
unsafe CreateSessionFromArray(env.env_ptr.as_ptr(), model_data, model_data_length as _, self.session_options_ptr.as_ptr(), &mut session_ptr) -> Error::CreateSession;
398394
nonNull(session_ptr)
399395
];
400396

@@ -429,7 +425,7 @@ impl SessionBuilder {
429425
session_ptr,
430426
allocator,
431427
_extras: extras,
432-
_environment: Arc::clone(env)
428+
_environment: env
433429
}),
434430
inputs,
435431
outputs

0 commit comments

Comments
 (0)