Skip to content

Commit f16aca6

Browse files
committed
Terminate underlying Rust future when AsyncThread is dropped.
Before this change, Lua GC was responsible to collect and destroy the future if `AsyncThread` dropped in yielded state. Now we will propagate "drop" event immediately so Lua GC need to only free the memory.
1 parent 9b45663 commit f16aca6

File tree

5 files changed

+94
-69
lines changed

5 files changed

+94
-69
lines changed

src/state.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,6 +2036,13 @@ impl Lua {
20362036
LightUserData(&ASYNC_POLL_PENDING as *const u8 as *mut std::os::raw::c_void)
20372037
}
20382038

2039+
#[cfg(feature = "async")]
2040+
#[inline(always)]
2041+
pub(crate) fn poll_terminate() -> LightUserData {
2042+
static ASYNC_POLL_TERMINATE: u8 = 0;
2043+
LightUserData(&ASYNC_POLL_TERMINATE as *const u8 as *mut std::os::raw::c_void)
2044+
}
2045+
20392046
/// Returns a weak reference to the Lua instance.
20402047
///
20412048
/// This is useful for creating a reference to the Lua instance that does not prevent it from

src/state/raw.rs

Lines changed: 20 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -613,46 +613,11 @@ impl RawLua {
613613
self.create_thread(func)
614614
}
615615

616-
/// Resets thread (coroutine) and returns it to the pool for later use.
616+
/// Returns the thread to the pool for later use.
617617
#[cfg(feature = "async")]
618618
pub(crate) unsafe fn recycle_thread(&self, thread: &mut Thread) {
619-
let thread_state = thread.1;
620619
let extra = &mut *self.extra.get();
621-
if extra.thread_pool.len() == extra.thread_pool.capacity() {
622-
#[cfg(feature = "lua54")]
623-
if ffi::lua_status(thread_state) != ffi::LUA_OK {
624-
// Close all to-be-closed variables without returning thread to the pool
625-
#[cfg(not(feature = "vendored"))]
626-
ffi::lua_resetthread(thread_state);
627-
#[cfg(feature = "vendored")]
628-
ffi::lua_closethread(thread_state, self.state());
629-
}
630-
return;
631-
}
632-
633-
let mut reset_ok = false;
634-
if ffi::lua_status(thread_state) == ffi::LUA_OK {
635-
if ffi::lua_gettop(thread_state) > 0 {
636-
ffi::lua_settop(thread_state, 0);
637-
}
638-
reset_ok = true;
639-
}
640-
641-
#[cfg(feature = "lua54")]
642-
if !reset_ok {
643-
#[cfg(not(feature = "vendored"))]
644-
let status = ffi::lua_resetthread(thread_state);
645-
#[cfg(feature = "vendored")]
646-
let status = ffi::lua_closethread(thread_state, self.state());
647-
reset_ok = status == ffi::LUA_OK;
648-
}
649-
#[cfg(feature = "luau")]
650-
if !reset_ok {
651-
ffi::lua_resetthread(thread_state);
652-
reset_ok = true;
653-
}
654-
655-
if reset_ok {
620+
if extra.thread_pool.len() < extra.thread_pool.capacity() {
656621
extra.thread_pool.push(thread.0.index);
657622
thread.0.drop = false; // Prevent thread from being garbage collected
658623
}
@@ -1244,7 +1209,7 @@ impl RawLua {
12441209
let rawlua = (*extra).raw_lua();
12451210

12461211
let func = &*(*upvalue).data;
1247-
let fut = func(rawlua, nargs);
1212+
let fut = Some(func(rawlua, nargs));
12481213
let extra = XRc::clone(&(*upvalue).extra);
12491214
let protect = !rawlua.unlikely_memory_error();
12501215
push_internal_userdata(state, AsyncPollUpvalue { data: fut, extra }, protect)?;
@@ -1262,20 +1227,27 @@ impl RawLua {
12621227

12631228
unsafe extern "C-unwind" fn poll_future(state: *mut ffi::lua_State) -> c_int {
12641229
let upvalue = get_userdata::<AsyncPollUpvalue>(state, ffi::lua_upvalueindex(1));
1265-
callback_error_ext(state, (*upvalue).extra.get(), true, |extra, _| {
1230+
callback_error_ext(state, (*upvalue).extra.get(), true, |extra, nargs| {
12661231
// Lua ensures that `LUA_MINSTACK` stack spaces are available (after pushing arguments)
12671232
// The lock must be already held as the future is polled
12681233
let rawlua = (*extra).raw_lua();
12691234

1235+
if nargs == 1 && ffi::lua_tolightuserdata(state, -1) == Lua::poll_terminate().0 {
1236+
// Destroy the future and terminate the Lua thread
1237+
(*upvalue).data.take();
1238+
ffi::lua_pushinteger(state, 0);
1239+
return Ok(1);
1240+
}
1241+
12701242
let fut = &mut (*upvalue).data;
12711243
let mut ctx = Context::from_waker(rawlua.waker());
1272-
match fut.as_mut().poll(&mut ctx) {
1273-
Poll::Pending => {
1244+
match fut.as_mut().map(|fut| fut.as_mut().poll(&mut ctx)) {
1245+
Some(Poll::Pending) => {
12741246
ffi::lua_pushnil(state);
12751247
ffi::lua_pushlightuserdata(state, Lua::poll_pending().0);
12761248
Ok(2)
12771249
}
1278-
Poll::Ready(nresults) => {
1250+
Some(Poll::Ready(nresults)) => {
12791251
match nresults? {
12801252
nresults if nresults < 3 => {
12811253
// Fast path for up to 2 results without creating a table
@@ -1293,6 +1265,7 @@ impl RawLua {
12931265
}
12941266
}
12951267
}
1268+
None => Err(Error::CallbackDestructed),
12961269
}
12971270
})
12981271
}
@@ -1338,8 +1311,8 @@ impl RawLua {
13381311
lua.load(
13391312
r#"
13401313
local poll = get_poll(...)
1314+
local nres, res, res2 = poll()
13411315
while true do
1342-
local nres, res, res2 = poll()
13431316
if nres ~= nil then
13441317
if nres == 0 then
13451318
return
@@ -1351,7 +1324,10 @@ impl RawLua {
13511324
return unpack(res, nres)
13521325
end
13531326
end
1354-
yield(res) -- `res` is a "pending" value
1327+
-- `res` is a "pending" value
1328+
-- `yield` can return a signal to drop the future that we should propagate
1329+
-- to the poller
1330+
nres, res, res2 = poll(yield(res))
13551331
end
13561332
"#,
13571333
)

src/thread.rs

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -305,40 +305,57 @@ impl Thread {
305305
pub fn reset(&self, func: Function) -> Result<()> {
306306
let lua = self.0.lua.lock();
307307
let thread_state = self.state();
308-
match self.status_inner(&lua) {
309-
ThreadStatusInner::Running => return Err(Error::runtime("cannot reset a running thread")),
310-
// Any Lua can reuse new or finished thread
311-
ThreadStatusInner::New(_) => unsafe { ffi::lua_settop(thread_state, 0) },
312-
ThreadStatusInner::Finished => {}
308+
unsafe {
309+
let status = self.status_inner(&lua);
310+
self.reset_inner(status)?;
311+
312+
// Push function to the top of the thread stack
313+
ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index);
314+
315+
#[cfg(feature = "luau")]
316+
{
317+
// Inherit `LUA_GLOBALSINDEX` from the main thread
318+
ffi::lua_xpush(lua.main_state(), thread_state, ffi::LUA_GLOBALSINDEX);
319+
ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
320+
}
321+
322+
Ok(())
323+
}
324+
}
325+
326+
unsafe fn reset_inner(&self, status: ThreadStatusInner) -> Result<()> {
327+
match status {
328+
ThreadStatusInner::New(_) => {
329+
// The thread is new, so we can just set the top to 0
330+
ffi::lua_settop(self.state(), 0);
331+
Ok(())
332+
}
333+
ThreadStatusInner::Running => Err(Error::runtime("cannot reset a running thread")),
334+
ThreadStatusInner::Finished => Ok(()),
313335
#[cfg(not(any(feature = "lua54", feature = "luau")))]
314-
_ => return Err(Error::runtime("cannot reset non-finished thread")),
336+
ThreadStatusInner::Yielded(_) | ThreadStatusInner::Error => {
337+
Err(Error::runtime("cannot reset non-finished thread"))
338+
}
315339
#[cfg(any(feature = "lua54", feature = "luau"))]
316-
_ => unsafe {
340+
ThreadStatusInner::Yielded(_) | ThreadStatusInner::Error => {
341+
let thread_state = self.state();
342+
317343
#[cfg(all(feature = "lua54", not(feature = "vendored")))]
318344
let status = ffi::lua_resetthread(thread_state);
319345
#[cfg(all(feature = "lua54", feature = "vendored"))]
320-
let status = ffi::lua_closethread(thread_state, lua.state());
346+
let status = {
347+
let lua = self.0.lua.lock();
348+
ffi::lua_closethread(thread_state, lua.state())
349+
};
321350
#[cfg(feature = "lua54")]
322351
if status != ffi::LUA_OK {
323352
return Err(pop_error(thread_state, status));
324353
}
325354
#[cfg(feature = "luau")]
326355
ffi::lua_resetthread(thread_state);
327-
},
328-
}
329-
330-
unsafe {
331-
// Push function to the top of the thread stack
332-
ffi::lua_xpush(lua.ref_thread(), thread_state, func.0.index);
333356

334-
#[cfg(feature = "luau")]
335-
{
336-
// Inherit `LUA_GLOBALSINDEX` from the main thread
337-
ffi::lua_xpush(lua.main_state(), thread_state, ffi::LUA_GLOBALSINDEX);
338-
ffi::lua_replace(thread_state, ffi::LUA_GLOBALSINDEX);
357+
Ok(())
339358
}
340-
341-
Ok(())
342359
}
343360
}
344361

@@ -505,8 +522,21 @@ impl<R> Drop for AsyncThread<R> {
505522
fn drop(&mut self) {
506523
if self.recycle {
507524
if let Some(lua) = self.thread.0.lua.try_lock() {
508-
// For Lua 5.4 this also closes all pending to-be-closed variables
509-
unsafe { lua.recycle_thread(&mut self.thread) };
525+
unsafe {
526+
let mut status = self.thread.status_inner(&lua);
527+
if matches!(status, ThreadStatusInner::Yielded(0)) {
528+
// The thread is dropped while yielded, resume it with the "terminate" signal
529+
ffi::lua_pushlightuserdata(self.thread.1, crate::Lua::poll_terminate().0);
530+
if let Ok((new_status, _)) = self.thread.resume_inner(&lua, 1) {
531+
status = new_status;
532+
}
533+
}
534+
535+
// For Lua 5.4 this also closes all pending to-be-closed variables
536+
if self.thread.reset_inner(status).is_ok() {
537+
lua.recycle_thread(&mut self.thread);
538+
}
539+
}
510540
}
511541
}
512542
}

src/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ pub(crate) type AsyncCallback =
6161
pub(crate) type AsyncCallbackUpvalue = Upvalue<AsyncCallback>;
6262

6363
#[cfg(feature = "async")]
64-
pub(crate) type AsyncPollUpvalue = Upvalue<BoxFuture<'static, Result<c_int>>>;
64+
pub(crate) type AsyncPollUpvalue = Upvalue<Option<BoxFuture<'static, Result<c_int>>>>;
6565

6666
/// Type to set next Lua VM action after executing interrupt or hook function.
6767
pub enum VmState {

tests/async.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use tokio::sync::Mutex;
99

1010
use mlua::{
1111
Error, Function, Lua, LuaOptions, MultiValue, ObjectLike, Result, StdLib, Table, UserData,
12-
UserDataMethods, Value,
12+
UserDataMethods, UserDataRef, Value,
1313
};
1414

1515
#[cfg(not(target_arch = "wasm32"))]
@@ -547,6 +547,7 @@ async fn test_async_thread_error() -> Result<()> {
547547

548548
#[tokio::test]
549549
async fn test_async_terminate() -> Result<()> {
550+
// Future captures `Lua` instance and dropped all together
550551
let mutex = Arc::new(Mutex::new(0u32));
551552
{
552553
let lua = Lua::new();
@@ -565,6 +566,17 @@ async fn test_async_terminate() -> Result<()> {
565566
}
566567
assert!(mutex.try_lock().is_ok());
567568

569+
// Future is dropped, but `Lua` instance is still alive
570+
let lua = Lua::new();
571+
let func = lua.create_async_function(move |_, mutex: UserDataRef<Arc<Mutex<u32>>>| async move {
572+
let _guard = mutex.lock().await;
573+
sleep_ms(100).await;
574+
Ok(())
575+
})?;
576+
let mutex2 = lua.create_any_userdata(mutex.clone())?;
577+
let _ = tokio::time::timeout(Duration::from_millis(30), func.call_async::<()>(mutex2)).await;
578+
assert!(mutex.try_lock().is_ok());
579+
568580
Ok(())
569581
}
570582

0 commit comments

Comments
 (0)