Skip to content

Commit 63baf92

Browse files
committed
change to single event loop
1 parent 0053900 commit 63baf92

File tree

3 files changed

+239
-86
lines changed

3 files changed

+239
-86
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ pyo3 = ["dep:pyo3"]
2424
rustpython = ["dep:rustpython-vm", "dep:rustpython", "dep:rustpython-stdlib"]
2525

2626
[dependencies.pyo3]
27-
version = "0.26.0"
27+
version = "0.26"
2828
features = ["auto-initialize"]
2929
optional = true
3030

src/pyo3_runner.rs

Lines changed: 215 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,73 +2,167 @@
22
// © Copyright 2025, by Marco Mengelkoch
33
// Licensed under MIT License, see License file for more details
44
// git clone https://github.com/marcomq/async_py
5-
65
use crate::{print_path_for_python, CmdType, PyCommand};
76
use pyo3::{
87
exceptions::PyKeyError,
98
prelude::*,
10-
types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyString},
9+
types::{PyBool, PyCFunction, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple},
1110
IntoPyObjectExt,
1211
};
1312
use serde_json::Value;
13+
use std::collections::HashMap;
1414
use std::ffi::CString;
15-
use tokio::sync::{mpsc, oneshot};
15+
use tokio::sync::{mpsc, oneshot, Notify};
16+
17+
/// Holds the Python objects related to the async infrastructure.
18+
struct AsyncPyState {
19+
loop_obj: Py<PyAny>,
20+
result_queue: Py<PyAny>,
21+
make_callback_fn: Py<PyAny>,
22+
}
1623

1724
/// The main loop for the Python thread. This function is spawned in a new
1825
/// thread and is responsible for all Python interaction.
1926
pub(crate) async fn python_thread_main(mut receiver: mpsc::Receiver<PyCommand>) {
2027
Python::initialize();
21-
let globals = Python::attach(|py| PyDict::new(py).unbind());
22-
while let Some(mut cmd) = receiver.recv().await {
23-
Python::attach(|py| {
24-
let globals = globals.bind(py);
25-
let result = match std::mem::replace(&mut cmd.cmd_type, CmdType::Stop) {
26-
CmdType::RunCode(code) => {
27-
let c_code = CString::new(code).expect("CString::new failed");
28-
py.run(&c_code, Some(globals), None).map(|_| Value::Null)
29-
}
30-
CmdType::EvalCode(code) => {
31-
let c_code = CString::new(code).expect("CString::new failed");
32-
py.eval(&c_code, Some(globals), None)
33-
.and_then(|obj| py_any_to_json(&obj))
34-
}
35-
CmdType::RunFile(file) => handle_run_file(py, globals, file),
36-
CmdType::ReadVariable(var_name) => {
37-
get_py_object(globals, &var_name).and_then(|obj| py_any_to_json(&obj))
28+
29+
// State for async operations
30+
let mut pending: HashMap<usize, oneshot::Sender<Result<Value, String>>> = HashMap::new();
31+
let mut next_id: usize = 1;
32+
let mut async_state: Option<AsyncPyState> = None;
33+
34+
// Notifier to wake up the Rust select! loop from Python.
35+
let notify = std::sync::Arc::new(Notify::new());
36+
37+
// Create globals and inject the notifier callback.
38+
let globals = Python::attach(|py| -> PyResult<Py<PyDict>> {
39+
let globals = PyDict::new(py);
40+
let rust_notify_fn = {
41+
let notify = notify.clone();
42+
PyCFunction::new_closure(py, None, None, move |py, _| {
43+
notify.notify_one();
44+
Ok::<pyo3::Py<pyo3::PyAny>, PyErr>(py.py().None())
45+
})?
46+
.unbind()
47+
};
48+
globals.set_item("_rust_notify", rust_notify_fn)?;
49+
Ok(globals.unbind())
50+
})
51+
.expect("Failed to initialize Python globals");
52+
53+
loop {
54+
tokio::select! {
55+
// Branch 1: Wait for a new command from the channel.
56+
Some(cmd) = receiver.recv() => {
57+
if let CmdType::Stop = cmd.cmd_type {
58+
receiver.close();
59+
} else {
60+
handle_command(&globals, &mut async_state, &mut pending, &mut next_id, cmd);
3861
}
39-
CmdType::CallFunction { name, args } => {
40-
handle_call_function(py, globals, name, args)
62+
},
63+
// Branch 2: Wait for a notification from Python that a result is ready.
64+
_ = notify.notified(), if !pending.is_empty() && async_state.is_some() => {
65+
handle_notification(async_state.as_ref().unwrap(), &mut pending);
66+
}
67+
// If the receiver is closed and there are no more pending tasks, exit.
68+
else => {
69+
if receiver.is_closed() && pending.is_empty() {
70+
break;
4171
}
42-
CmdType::CallAsyncFunction { name, args } => {
43-
let result: PyResult<_> = (|| {
44-
let func = get_py_object(globals, &name)?;
45-
check_func_callable(&func, &name)?;
46-
Ok(func.unbind())
47-
})();
48-
49-
match result {
50-
Ok(func) => {
51-
py.detach(|| {
52-
tokio::spawn(handle_call_async_function(func, args, cmd.responder))
53-
});
54-
return; // The response is sent async, so we can return early.
72+
}
73+
}
74+
}
75+
}
76+
77+
/// Sets up the asyncio event loop and related infrastructure in Python.
78+
fn setup_async_infrastructure(py: Python, globals: &Bound<PyDict>) -> PyResult<AsyncPyState> {
79+
let code = r#"
80+
import asyncio, threading, queue, traceback
81+
_loop = asyncio.new_event_loop()
82+
def _run_loop():
83+
asyncio.set_event_loop(_loop)
84+
_loop.run_forever()
85+
_thread = threading.Thread(target=_run_loop, daemon=True)
86+
_thread.start()
87+
_result_queue = queue.Queue()
88+
def _async_done_cb(fut, id):
89+
try:
90+
res = fut.result()
91+
_result_queue.put({'id': id, 'ok': True, 'payload': res})
92+
except Exception as e:
93+
tb = traceback.format_exception_only(type(e), e)
94+
_result_queue.put({'id': id, 'ok': False, 'payload': ''.join(tb)})
95+
finally:
96+
_rust_notify()
97+
def _make_callback(id):
98+
def _cb(fut):
99+
_async_done_cb(fut, id)
100+
return _cb
101+
"#;
102+
let c_code = CString::new(code).expect("CString::new failed");
103+
py.run(&c_code, Some(globals), None)?;
104+
Ok(AsyncPyState {
105+
loop_obj: globals.get_item("_loop")?.unwrap().unbind(),
106+
result_queue: globals.get_item("_result_queue")?.unwrap().unbind(),
107+
make_callback_fn: globals.get_item("_make_callback")?.unwrap().unbind(),
108+
})
109+
}
110+
111+
/// Processes a command received from the Rust side.
112+
fn handle_command(
113+
globals: &Py<PyDict>,
114+
async_state: &mut Option<AsyncPyState>,
115+
pending: &mut HashMap<usize, oneshot::Sender<Result<Value, String>>>,
116+
next_id: &mut usize,
117+
mut cmd: PyCommand,
118+
) {
119+
Python::attach(|py| {
120+
let globals = globals.bind(py);
121+
let result = match std::mem::replace(&mut cmd.cmd_type, CmdType::Stop) {
122+
CmdType::RunCode(code) => {
123+
let c_code = CString::new(code).expect("CString::new failed");
124+
py.run(&c_code, Some(globals), None).map(|_| Value::Null)
125+
}
126+
CmdType::EvalCode(code) => {
127+
let c_code = CString::new(code).expect("CString::new failed");
128+
py.eval(&c_code, Some(globals), None)
129+
.and_then(|obj| py_any_to_json(&obj))
130+
}
131+
CmdType::RunFile(file) => handle_run_file(py, globals, file),
132+
CmdType::ReadVariable(var_name) => {
133+
get_py_object(globals, &var_name).and_then(|obj| py_any_to_json(&obj))
134+
}
135+
CmdType::CallFunction { name, args } => handle_call_function(py, globals, name, args),
136+
CmdType::CallAsyncFunction { name, args } => {
137+
let id = *next_id;
138+
*next_id += 1;
139+
140+
// Initialize async infrastructure on first use.
141+
if async_state.is_none() {
142+
match setup_async_infrastructure(py, globals) {
143+
Ok(state) => *async_state = Some(state),
144+
Err(e) => {
145+
let _ = cmd.responder.send(Err(e.to_string()));
146+
return;
55147
}
56-
Err(e) => Err(e),
57148
}
58149
}
59-
CmdType::Stop => return receiver.close(),
60-
};
61-
62-
// Convert PyErr to a string representation to avoid exposing it outside this module.
63-
let response = match result {
64-
Ok(value) => Ok(value),
65-
Err(e) => Err(e.to_string()),
66-
};
67-
let _ = cmd.responder.send(response);
68-
});
69-
// After the loop, we can send a final confirmation for the Stop command if needed,
70-
// but the current implementation in lib.rs handles the channel closing.
71-
}
150+
let state = async_state.as_ref().unwrap();
151+
pending.insert(id, cmd.responder);
152+
153+
if let Err(e) = handle_call_async_function(py, globals, state, id, &name, args) {
154+
if let Some(tx) = pending.remove(&id) {
155+
let _ = tx.send(Err(e));
156+
}
157+
}
158+
return; // Response is sent async, so we return early.
159+
}
160+
CmdType::Stop => return, // Handled in the select! loop.
161+
};
162+
163+
let response = result.map_err(|e| e.to_string());
164+
let _ = cmd.responder.send(response);
165+
});
72166
}
73167

74168
/// Resolves a potentially dot-separated Python object name from the globals dictionary.
@@ -103,7 +197,7 @@ fn check_func_callable(func: &Bound<PyAny>, name: &str) -> PyResult<()> {
103197

104198
fn handle_run_file(
105199
py: Python,
106-
globals: &pyo3::Bound<'_, PyDict>,
200+
globals: &Bound<'_, PyDict>,
107201
file: std::path::PathBuf,
108202
) -> PyResult<Value> {
109203
let code = format!(
@@ -123,7 +217,7 @@ with open({}, 'r') as f:
123217
/// Handles the `CallFunction` command.
124218
fn handle_call_function(
125219
py: Python,
126-
globals: &pyo3::Bound<'_, PyDict>,
220+
globals: &Bound<'_, PyDict>,
127221
name: String,
128222
args: Vec<Value>,
129223
) -> PyResult<Value> {
@@ -142,30 +236,82 @@ fn vec_to_py_tuple<'py>(
142236
.into_iter()
143237
.map(|v| json_value_to_pyobject(*py, v))
144238
.collect::<PyResult<Vec<_>>>()?;
145-
pyo3::types::PyTuple::new(*py, py_args)
239+
PyTuple::new(*py, py_args)
146240
}
147241

148242
/// Handles the `CallAsyncFunction` command.
149-
async fn handle_call_async_function(
150-
func: Py<PyAny>,
243+
fn handle_call_async_function(
244+
py: Python,
245+
globals: &Bound<PyDict>,
246+
async_state: &AsyncPyState,
247+
id: usize,
248+
name: &str,
151249
args: Vec<Value>,
152-
responder: oneshot::Sender<Result<Value, String>>,
250+
) -> Result<(), String> {
251+
let func = get_py_object(globals, name)
252+
.and_then(|f| check_func_callable(&f, name).map(|_| f))
253+
.map_err(|e| e.to_string())?;
254+
255+
let t_args = vec_to_py_tuple(&py, args).map_err(|e| e.to_string())?;
256+
let coroutine = func.call1(t_args).map_err(|e| e.to_string())?;
257+
258+
let asyncio = py.import("asyncio").map_err(|e| e.to_string())?;
259+
let run_threadsafe = asyncio
260+
.getattr("run_coroutine_threadsafe")
261+
.map_err(|e| e.to_string())?;
262+
263+
let fut = run_threadsafe
264+
.call1((coroutine, async_state.loop_obj.bind(py)))
265+
.map_err(|e| e.to_string())?;
266+
267+
let cb = async_state
268+
.make_callback_fn
269+
.bind(py)
270+
.call1((id,))
271+
.map_err(|e| e.to_string())?;
272+
273+
fut.call_method1("add_done_callback", (cb,))
274+
.map_err(|e| e.to_string())?;
275+
276+
Ok(())
277+
}
278+
279+
/// Drains the Python-side result queue and completes any pending responders.
280+
fn handle_notification(
281+
async_state: &AsyncPyState,
282+
pending: &mut HashMap<usize, oneshot::Sender<Result<Value, String>>>,
153283
) {
154-
let result = Python::attach(|py| {
155-
// Note: This approach creates a new asyncio event loop for each async call,
156-
// which can be inefficient for a high volume of calls. It ensures that each
157-
// call is isolated but does not share an event loop for concurrent execution
158-
// on the Python side.
159-
let func = func.bind(py);
160-
let t_args = vec_to_py_tuple(&py, args)?;
161-
let coroutine = func.call1(t_args)?;
162-
163-
let asyncio = py.import("asyncio")?;
164-
let result = asyncio.call_method1("run", (coroutine,))?;
165-
166-
py_any_to_json(&result)
284+
Python::attach(|py| {
285+
let get_nowait = match async_state.result_queue.bind(py).getattr("get_nowait") {
286+
Ok(f) => f,
287+
Err(_) => return, // Should not happen if setup is correct
288+
};
289+
while let Ok(item) = get_nowait.call0() {
290+
if let Ok(dict) = item.downcast::<PyDict>() {
291+
let id = dict
292+
.get_item("id")
293+
.unwrap()
294+
.unwrap()
295+
.extract::<usize>()
296+
.unwrap();
297+
if let Some(tx) = pending.remove(&id) {
298+
let ok = dict
299+
.get_item("ok")
300+
.unwrap()
301+
.unwrap()
302+
.extract::<bool>()
303+
.unwrap();
304+
let payload = dict.get_item("payload").unwrap().unwrap();
305+
let res = if ok {
306+
py_any_to_json(&payload).map_err(|e| e.to_string())
307+
} else {
308+
Err(payload.to_string())
309+
};
310+
let _ = tx.send(res);
311+
}
312+
}
313+
}
167314
});
168-
let _ = responder.send(result.map_err(|e| e.to_string()));
169315
}
170316

171317
/// Recursively converts a Python object to a `serde_json::Value`.

0 commit comments

Comments
 (0)