22// © Copyright 2025, by Marco Mengelkoch
33// Licensed under MIT License, see License file for more details
44// git clone https://github.com/marcomq/async_py
5-
65use crate :: { print_path_for_python, CmdType , PyCommand } ;
76use pyo3:: {
87 exceptions:: PyKeyError ,
98 prelude:: * ,
10- types:: { PyBool , PyDict , PyFloat , PyInt , PyList , PyString } ,
9+ types:: { PyBool , PyCFunction , PyDict , PyFloat , PyInt , PyList , PyString , PyTuple } ,
1110 IntoPyObjectExt ,
1211} ;
1312use serde_json:: Value ;
13+ use std:: collections:: HashMap ;
1414use std:: ffi:: CString ;
15- use tokio:: sync:: { mpsc, oneshot} ;
15+ use tokio:: sync:: { mpsc, oneshot, Notify } ;
16+
17+ /// Holds the Python objects related to the async infrastructure.
18+ struct AsyncPyState {
19+ loop_obj : Py < PyAny > ,
20+ result_queue : Py < PyAny > ,
21+ make_callback_fn : Py < PyAny > ,
22+ }
1623
1724/// The main loop for the Python thread. This function is spawned in a new
1825/// thread and is responsible for all Python interaction.
1926pub ( crate ) async fn python_thread_main ( mut receiver : mpsc:: Receiver < PyCommand > ) {
2027 Python :: initialize ( ) ;
21- let globals = Python :: attach ( |py| PyDict :: new ( py) . unbind ( ) ) ;
22- while let Some ( mut cmd) = receiver. recv ( ) . await {
23- Python :: attach ( |py| {
24- let globals = globals. bind ( py) ;
25- let result = match std:: mem:: replace ( & mut cmd. cmd_type , CmdType :: Stop ) {
26- CmdType :: RunCode ( code) => {
27- let c_code = CString :: new ( code) . expect ( "CString::new failed" ) ;
28- py. run ( & c_code, Some ( globals) , None ) . map ( |_| Value :: Null )
29- }
30- CmdType :: EvalCode ( code) => {
31- let c_code = CString :: new ( code) . expect ( "CString::new failed" ) ;
32- py. eval ( & c_code, Some ( globals) , None )
33- . and_then ( |obj| py_any_to_json ( & obj) )
34- }
35- CmdType :: RunFile ( file) => handle_run_file ( py, globals, file) ,
36- CmdType :: ReadVariable ( var_name) => {
37- get_py_object ( globals, & var_name) . and_then ( |obj| py_any_to_json ( & obj) )
28+
29+ // State for async operations
30+ let mut pending: HashMap < usize , oneshot:: Sender < Result < Value , String > > > = HashMap :: new ( ) ;
31+ let mut next_id: usize = 1 ;
32+ let mut async_state: Option < AsyncPyState > = None ;
33+
34+ // Notifier to wake up the Rust select! loop from Python.
35+ let notify = std:: sync:: Arc :: new ( Notify :: new ( ) ) ;
36+
37+ // Create globals and inject the notifier callback.
38+ let globals = Python :: attach ( |py| -> PyResult < Py < PyDict > > {
39+ let globals = PyDict :: new ( py) ;
40+ let rust_notify_fn = {
41+ let notify = notify. clone ( ) ;
42+ PyCFunction :: new_closure ( py, None , None , move |py, _| {
43+ notify. notify_one ( ) ;
44+ Ok :: < pyo3:: Py < pyo3:: PyAny > , PyErr > ( py. py ( ) . None ( ) )
45+ } ) ?
46+ . unbind ( )
47+ } ;
48+ globals. set_item ( "_rust_notify" , rust_notify_fn) ?;
49+ Ok ( globals. unbind ( ) )
50+ } )
51+ . expect ( "Failed to initialize Python globals" ) ;
52+
53+ loop {
54+ tokio:: select! {
55+ // Branch 1: Wait for a new command from the channel.
56+ Some ( cmd) = receiver. recv( ) => {
57+ if let CmdType :: Stop = cmd. cmd_type {
58+ receiver. close( ) ;
59+ } else {
60+ handle_command( & globals, & mut async_state, & mut pending, & mut next_id, cmd) ;
3861 }
39- CmdType :: CallFunction { name, args } => {
40- handle_call_function ( py, globals, name, args)
62+ } ,
63+ // Branch 2: Wait for a notification from Python that a result is ready.
64+ _ = notify. notified( ) , if !pending. is_empty( ) && async_state. is_some( ) => {
65+ handle_notification( async_state. as_ref( ) . unwrap( ) , & mut pending) ;
66+ }
67+ // If the receiver is closed and there are no more pending tasks, exit.
68+ else => {
69+ if receiver. is_closed( ) && pending. is_empty( ) {
70+ break ;
4171 }
42- CmdType :: CallAsyncFunction { name, args } => {
43- let result: PyResult < _ > = ( || {
44- let func = get_py_object ( globals, & name) ?;
45- check_func_callable ( & func, & name) ?;
46- Ok ( func. unbind ( ) )
47- } ) ( ) ;
48-
49- match result {
50- Ok ( func) => {
51- py. detach ( || {
52- tokio:: spawn ( handle_call_async_function ( func, args, cmd. responder ) )
53- } ) ;
54- return ; // The response is sent async, so we can return early.
72+ }
73+ }
74+ }
75+ }
76+
77+ /// Sets up the asyncio event loop and related infrastructure in Python.
78+ fn setup_async_infrastructure ( py : Python , globals : & Bound < PyDict > ) -> PyResult < AsyncPyState > {
79+ let code = r#"
80+ import asyncio, threading, queue, traceback
81+ _loop = asyncio.new_event_loop()
82+ def _run_loop():
83+ asyncio.set_event_loop(_loop)
84+ _loop.run_forever()
85+ _thread = threading.Thread(target=_run_loop, daemon=True)
86+ _thread.start()
87+ _result_queue = queue.Queue()
88+ def _async_done_cb(fut, id):
89+ try:
90+ res = fut.result()
91+ _result_queue.put({'id': id, 'ok': True, 'payload': res})
92+ except Exception as e:
93+ tb = traceback.format_exception_only(type(e), e)
94+ _result_queue.put({'id': id, 'ok': False, 'payload': ''.join(tb)})
95+ finally:
96+ _rust_notify()
97+ def _make_callback(id):
98+ def _cb(fut):
99+ _async_done_cb(fut, id)
100+ return _cb
101+ "# ;
102+ let c_code = CString :: new ( code) . expect ( "CString::new failed" ) ;
103+ py. run ( & c_code, Some ( globals) , None ) ?;
104+ Ok ( AsyncPyState {
105+ loop_obj : globals. get_item ( "_loop" ) ?. unwrap ( ) . unbind ( ) ,
106+ result_queue : globals. get_item ( "_result_queue" ) ?. unwrap ( ) . unbind ( ) ,
107+ make_callback_fn : globals. get_item ( "_make_callback" ) ?. unwrap ( ) . unbind ( ) ,
108+ } )
109+ }
110+
111+ /// Processes a command received from the Rust side.
112+ fn handle_command (
113+ globals : & Py < PyDict > ,
114+ async_state : & mut Option < AsyncPyState > ,
115+ pending : & mut HashMap < usize , oneshot:: Sender < Result < Value , String > > > ,
116+ next_id : & mut usize ,
117+ mut cmd : PyCommand ,
118+ ) {
119+ Python :: attach ( |py| {
120+ let globals = globals. bind ( py) ;
121+ let result = match std:: mem:: replace ( & mut cmd. cmd_type , CmdType :: Stop ) {
122+ CmdType :: RunCode ( code) => {
123+ let c_code = CString :: new ( code) . expect ( "CString::new failed" ) ;
124+ py. run ( & c_code, Some ( globals) , None ) . map ( |_| Value :: Null )
125+ }
126+ CmdType :: EvalCode ( code) => {
127+ let c_code = CString :: new ( code) . expect ( "CString::new failed" ) ;
128+ py. eval ( & c_code, Some ( globals) , None )
129+ . and_then ( |obj| py_any_to_json ( & obj) )
130+ }
131+ CmdType :: RunFile ( file) => handle_run_file ( py, globals, file) ,
132+ CmdType :: ReadVariable ( var_name) => {
133+ get_py_object ( globals, & var_name) . and_then ( |obj| py_any_to_json ( & obj) )
134+ }
135+ CmdType :: CallFunction { name, args } => handle_call_function ( py, globals, name, args) ,
136+ CmdType :: CallAsyncFunction { name, args } => {
137+ let id = * next_id;
138+ * next_id += 1 ;
139+
140+ // Initialize async infrastructure on first use.
141+ if async_state. is_none ( ) {
142+ match setup_async_infrastructure ( py, globals) {
143+ Ok ( state) => * async_state = Some ( state) ,
144+ Err ( e) => {
145+ let _ = cmd. responder . send ( Err ( e. to_string ( ) ) ) ;
146+ return ;
55147 }
56- Err ( e) => Err ( e) ,
57148 }
58149 }
59- CmdType :: Stop => return receiver. close ( ) ,
60- } ;
61-
62- // Convert PyErr to a string representation to avoid exposing it outside this module.
63- let response = match result {
64- Ok ( value) => Ok ( value) ,
65- Err ( e) => Err ( e. to_string ( ) ) ,
66- } ;
67- let _ = cmd. responder . send ( response) ;
68- } ) ;
69- // After the loop, we can send a final confirmation for the Stop command if needed,
70- // but the current implementation in lib.rs handles the channel closing.
71- }
150+ let state = async_state. as_ref ( ) . unwrap ( ) ;
151+ pending. insert ( id, cmd. responder ) ;
152+
153+ if let Err ( e) = handle_call_async_function ( py, globals, state, id, & name, args) {
154+ if let Some ( tx) = pending. remove ( & id) {
155+ let _ = tx. send ( Err ( e) ) ;
156+ }
157+ }
158+ return ; // Response is sent async, so we return early.
159+ }
160+ CmdType :: Stop => return , // Handled in the select! loop.
161+ } ;
162+
163+ let response = result. map_err ( |e| e. to_string ( ) ) ;
164+ let _ = cmd. responder . send ( response) ;
165+ } ) ;
72166}
73167
74168/// Resolves a potentially dot-separated Python object name from the globals dictionary.
@@ -103,7 +197,7 @@ fn check_func_callable(func: &Bound<PyAny>, name: &str) -> PyResult<()> {
103197
104198fn handle_run_file (
105199 py : Python ,
106- globals : & pyo3 :: Bound < ' _ , PyDict > ,
200+ globals : & Bound < ' _ , PyDict > ,
107201 file : std:: path:: PathBuf ,
108202) -> PyResult < Value > {
109203 let code = format ! (
@@ -123,7 +217,7 @@ with open({}, 'r') as f:
123217/// Handles the `CallFunction` command.
124218fn handle_call_function (
125219 py : Python ,
126- globals : & pyo3 :: Bound < ' _ , PyDict > ,
220+ globals : & Bound < ' _ , PyDict > ,
127221 name : String ,
128222 args : Vec < Value > ,
129223) -> PyResult < Value > {
@@ -142,30 +236,82 @@ fn vec_to_py_tuple<'py>(
142236 . into_iter ( )
143237 . map ( |v| json_value_to_pyobject ( * py, v) )
144238 . collect :: < PyResult < Vec < _ > > > ( ) ?;
145- pyo3 :: types :: PyTuple :: new ( * py, py_args)
239+ PyTuple :: new ( * py, py_args)
146240}
147241
148242/// Handles the `CallAsyncFunction` command.
149- async fn handle_call_async_function (
150- func : Py < PyAny > ,
243+ fn handle_call_async_function (
244+ py : Python ,
245+ globals : & Bound < PyDict > ,
246+ async_state : & AsyncPyState ,
247+ id : usize ,
248+ name : & str ,
151249 args : Vec < Value > ,
152- responder : oneshot:: Sender < Result < Value , String > > ,
250+ ) -> Result < ( ) , String > {
251+ let func = get_py_object ( globals, name)
252+ . and_then ( |f| check_func_callable ( & f, name) . map ( |_| f) )
253+ . map_err ( |e| e. to_string ( ) ) ?;
254+
255+ let t_args = vec_to_py_tuple ( & py, args) . map_err ( |e| e. to_string ( ) ) ?;
256+ let coroutine = func. call1 ( t_args) . map_err ( |e| e. to_string ( ) ) ?;
257+
258+ let asyncio = py. import ( "asyncio" ) . map_err ( |e| e. to_string ( ) ) ?;
259+ let run_threadsafe = asyncio
260+ . getattr ( "run_coroutine_threadsafe" )
261+ . map_err ( |e| e. to_string ( ) ) ?;
262+
263+ let fut = run_threadsafe
264+ . call1 ( ( coroutine, async_state. loop_obj . bind ( py) ) )
265+ . map_err ( |e| e. to_string ( ) ) ?;
266+
267+ let cb = async_state
268+ . make_callback_fn
269+ . bind ( py)
270+ . call1 ( ( id, ) )
271+ . map_err ( |e| e. to_string ( ) ) ?;
272+
273+ fut. call_method1 ( "add_done_callback" , ( cb, ) )
274+ . map_err ( |e| e. to_string ( ) ) ?;
275+
276+ Ok ( ( ) )
277+ }
278+
279+ /// Drains the Python-side result queue and completes any pending responders.
280+ fn handle_notification (
281+ async_state : & AsyncPyState ,
282+ pending : & mut HashMap < usize , oneshot:: Sender < Result < Value , String > > > ,
153283) {
154- let result = Python :: attach ( |py| {
155- // Note: This approach creates a new asyncio event loop for each async call,
156- // which can be inefficient for a high volume of calls. It ensures that each
157- // call is isolated but does not share an event loop for concurrent execution
158- // on the Python side.
159- let func = func. bind ( py) ;
160- let t_args = vec_to_py_tuple ( & py, args) ?;
161- let coroutine = func. call1 ( t_args) ?;
162-
163- let asyncio = py. import ( "asyncio" ) ?;
164- let result = asyncio. call_method1 ( "run" , ( coroutine, ) ) ?;
165-
166- py_any_to_json ( & result)
284+ Python :: attach ( |py| {
285+ let get_nowait = match async_state. result_queue . bind ( py) . getattr ( "get_nowait" ) {
286+ Ok ( f) => f,
287+ Err ( _) => return , // Should not happen if setup is correct
288+ } ;
289+ while let Ok ( item) = get_nowait. call0 ( ) {
290+ if let Ok ( dict) = item. downcast :: < PyDict > ( ) {
291+ let id = dict
292+ . get_item ( "id" )
293+ . unwrap ( )
294+ . unwrap ( )
295+ . extract :: < usize > ( )
296+ . unwrap ( ) ;
297+ if let Some ( tx) = pending. remove ( & id) {
298+ let ok = dict
299+ . get_item ( "ok" )
300+ . unwrap ( )
301+ . unwrap ( )
302+ . extract :: < bool > ( )
303+ . unwrap ( ) ;
304+ let payload = dict. get_item ( "payload" ) . unwrap ( ) . unwrap ( ) ;
305+ let res = if ok {
306+ py_any_to_json ( & payload) . map_err ( |e| e. to_string ( ) )
307+ } else {
308+ Err ( payload. to_string ( ) )
309+ } ;
310+ let _ = tx. send ( res) ;
311+ }
312+ }
313+ }
167314 } ) ;
168- let _ = responder. send ( result. map_err ( |e| e. to_string ( ) ) ) ;
169315}
170316
171317/// Recursively converts a Python object to a `serde_json::Value`.
0 commit comments