1
1
use std:: {
2
- cell:: UnsafeCell ,
3
2
ffi:: { self , CStr , CString } ,
4
- ptr,
5
- sync:: {
6
- atomic:: { AtomicPtr , Ordering } ,
7
- Arc
8
- }
3
+ ptr:: { self , NonNull } ,
4
+ sync:: { Arc , RwLock }
9
5
} ;
10
6
11
7
use ort_sys:: c_char;
@@ -20,12 +16,12 @@ use crate::{
20
16
} ;
21
17
22
18
struct EnvironmentSingleton {
23
- cell : UnsafeCell < Option < Arc < Environment > > >
19
+ lock : RwLock < Option < Arc < Environment > > >
24
20
}
25
21
26
22
unsafe impl Sync for EnvironmentSingleton { }
27
23
28
- static G_ENV : EnvironmentSingleton = EnvironmentSingleton { cell : UnsafeCell :: new ( None ) } ;
24
+ static G_ENV : EnvironmentSingleton = EnvironmentSingleton { lock : RwLock :: new ( None ) } ;
29
25
30
26
/// An `Environment` is a process-global structure, under which [`Session`](crate::Session)s are created.
31
27
///
@@ -41,14 +37,14 @@ static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::ne
41
37
#[ derive( Debug ) ]
42
38
pub struct Environment {
43
39
pub ( crate ) execution_providers : Vec < ExecutionProviderDispatch > ,
44
- pub ( crate ) env_ptr : AtomicPtr < ort_sys:: OrtEnv > ,
40
+ pub ( crate ) env_ptr : NonNull < ort_sys:: OrtEnv > ,
45
41
pub ( crate ) has_global_threadpool : bool
46
42
}
47
43
48
44
impl Environment {
49
45
/// Returns the underlying [`ort_sys::OrtEnv`] pointer.
50
46
pub fn ptr ( & self ) -> * mut ort_sys:: OrtEnv {
51
- self . env_ptr . load ( Ordering :: Relaxed )
47
+ self . env_ptr . as_ptr ( )
52
48
}
53
49
}
54
50
@@ -57,23 +53,22 @@ impl Drop for Environment {
57
53
fn drop ( & mut self ) {
58
54
debug ! ( "Releasing environment" ) ;
59
55
60
- let env_ptr: * mut ort_sys:: OrtEnv = * self . env_ptr . get_mut ( ) ;
61
-
62
- assert_ne ! ( env_ptr, std:: ptr:: null_mut( ) ) ;
63
- ortsys ! [ unsafe ReleaseEnv ( env_ptr) ] ;
56
+ ortsys ! [ unsafe ReleaseEnv ( self . env_ptr. as_ptr( ) ) ] ;
64
57
}
65
58
}
66
59
67
60
/// Gets a reference to the global environment, creating one if an environment has not been
68
61
/// [`commit`](EnvironmentBuilder::commit)ted yet.
69
- pub fn get_environment ( ) -> Result < & ' static Arc < Environment > > {
70
- if let Some ( c) = unsafe { & * G_ENV . cell . get ( ) } {
71
- Ok ( c)
62
+ pub fn get_environment ( ) -> Result < Arc < Environment > > {
63
+ let env = G_ENV . lock . read ( ) . expect ( "poisoned lock" ) ;
64
+ if let Some ( env) = env. as_ref ( ) {
65
+ Ok ( Arc :: clone ( env) )
72
66
} else {
73
- debug ! ( "Environment not yet initialized, creating a new one" ) ;
74
- EnvironmentBuilder :: new ( ) . commit ( ) ? ;
67
+ // drop our read lock so we dont deadlock when `commit` takes a write lock
68
+ drop ( env ) ;
75
69
76
- Ok ( unsafe { ( * G_ENV . cell . get ( ) ) . as_ref ( ) . unwrap_unchecked ( ) } )
70
+ debug ! ( "Environment not yet initialized, creating a new one" ) ;
71
+ Ok ( EnvironmentBuilder :: new ( ) . commit ( ) ?)
77
72
}
78
73
}
79
74
@@ -151,12 +146,7 @@ impl EnvironmentBuilder {
151
146
}
152
147
153
148
/// Commit the environment configuration and set the global environment.
154
- pub fn commit ( self ) -> Result < ( ) > {
155
- // drop global reference to previous environment
156
- if let Some ( env_arc) = unsafe { ( * G_ENV . cell . get ( ) ) . take ( ) } {
157
- drop ( env_arc) ;
158
- }
159
-
149
+ pub fn commit ( self ) -> Result < Arc < Environment > > {
160
150
let ( env_ptr, has_global_threadpool) = if let Some ( global_thread_pool) = self . global_thread_pool_options {
161
151
let mut env_ptr: * mut ort_sys:: OrtEnv = std:: ptr:: null_mut ( ) ;
162
152
let logging_function: ort_sys:: OrtLoggingFunction = Some ( custom_logger) ;
@@ -218,15 +208,20 @@ impl EnvironmentBuilder {
218
208
ortsys ! [ unsafe DisableTelemetryEvents ( env_ptr) -> Error :: CreateEnvironment ] ;
219
209
}
220
210
221
- unsafe {
222
- * G_ENV . cell . get ( ) = Some ( Arc :: new ( Environment {
223
- execution_providers : self . execution_providers ,
224
- env_ptr : AtomicPtr :: new ( env_ptr) ,
225
- has_global_threadpool
226
- } ) ) ;
227
- } ;
228
-
229
- Ok ( ( ) )
211
+ let mut env_lock = G_ENV . lock . write ( ) . expect ( "poisoned lock" ) ;
212
+ // drop global reference to previous environment
213
+ if let Some ( env_arc) = env_lock. take ( ) {
214
+ drop ( env_arc) ;
215
+ }
216
+ let env = Arc :: new ( Environment {
217
+ execution_providers : self . execution_providers ,
218
+ // we already asserted the env pointer is non-null in the `CreateEnvWithCustomLogger` call
219
+ env_ptr : unsafe { NonNull :: new_unchecked ( env_ptr) } ,
220
+ has_global_threadpool
221
+ } ) ;
222
+ env_lock. replace ( Arc :: clone ( & env) ) ;
223
+
224
+ Ok ( env)
230
225
}
231
226
}
232
227
@@ -316,16 +311,11 @@ mod tests {
316
311
use super :: * ;
317
312
318
313
fn is_env_initialized ( ) -> bool {
319
- unsafe { ( * G_ENV . cell . get ( ) ) . as_ref ( ) } . is_some ( )
320
- && !unsafe { ( * G_ENV . cell . get ( ) ) . as_ref ( ) }
321
- . unwrap_or_else ( || unreachable ! ( ) )
322
- . env_ptr
323
- . load ( Ordering :: Relaxed )
324
- . is_null ( )
314
+ G_ENV . lock . read ( ) . expect ( "poisoned lock" ) . is_some ( )
325
315
}
326
316
327
317
fn env_ptr ( ) -> Option < * mut ort_sys:: OrtEnv > {
328
- unsafe { ( * G_ENV . cell . get ( ) ) . as_ref ( ) } . map ( |f| f. env_ptr . load ( Ordering :: Relaxed ) )
318
+ ( * G_ENV . lock . read ( ) . expect ( "poisoned lock" ) ) . as_ref ( ) . map ( |f| f. env_ptr . as_ptr ( ) )
329
319
}
330
320
331
321
struct ConcurrentTestRun {
0 commit comments