From fac72c82b09ea1501a2c2eb6c00fd8cd89363392 Mon Sep 17 00:00:00 2001 From: Raymond Chen Date: Tue, 12 Sep 2023 11:25:16 -0700 Subject: [PATCH] Add `resume_agile` to allow coroutine to resume in any apartment (#1356) --- strings/base_coroutine_foundation.h | 48 ++++++++++------ strings/base_coroutine_threadpool.h | 17 ++---- strings/base_meta.h | 19 +++++++ test/test/await_adapter.cpp | 85 ++++++++++++++++++++--------- 4 files changed, 112 insertions(+), 57 deletions(-) diff --git a/strings/base_coroutine_foundation.h b/strings/base_coroutine_foundation.h index 4c467f921..ba61cd49a 100644 --- a/strings/base_coroutine_foundation.h +++ b/strings/base_coroutine_foundation.h @@ -99,44 +99,49 @@ namespace winrt::impl return async.GetResults(); } - template - struct disconnect_aware_handler + struct ignore_apartment_context {}; + + template + struct disconnect_aware_handler : private std::conditional_t { disconnect_aware_handler(Awaiter* awaiter, coroutine_handle<> handle) noexcept : m_awaiter(awaiter), m_handle(handle) { } - disconnect_aware_handler(disconnect_aware_handler&& other) noexcept - : m_context(std::move(other.m_context)) - , m_awaiter(std::exchange(other.m_awaiter, {})) - , m_handle(std::exchange(other.m_handle, {})) { } + disconnect_aware_handler(disconnect_aware_handler&& other) = default; ~disconnect_aware_handler() { - if (m_handle) Complete(); + if (m_handle.value) Complete(); } template void operator()(Async&&, Windows::Foundation::AsyncStatus status) { - m_awaiter->status = status; + m_awaiter.value->status = status; Complete(); } private: - resume_apartment_context m_context; - Awaiter* m_awaiter; - coroutine_handle<> m_handle; + movable_primitive m_awaiter; + movable_primitive, nullptr> m_handle; void Complete() { - if (m_awaiter->suspending.exchange(false, std::memory_order_release)) + if (m_awaiter.value->suspending.exchange(false, std::memory_order_release)) { - m_handle = nullptr; // resumption deferred to await_suspend + m_handle.value = nullptr; // resumption deferred to await_suspend } else { - auto handle = std::exchange(m_handle, {}); - if (!resume_apartment(m_context, handle, &m_awaiter->failure)) + auto handle = m_handle.detach(); + if constexpr (preserve_context) + { + if (!resume_apartment(*this, handle, &m_awaiter.value->failure)) + { + handle.resume(); + } + } + else { handle.resume(); } @@ -145,7 +150,7 @@ namespace winrt::impl }; #ifdef WINRT_IMPL_COROUTINES - template + template struct await_adapter : cancellable_awaiter> { await_adapter(Async const& async) : async(async) { } @@ -185,7 +190,7 @@ namespace winrt::impl private: bool register_completed_callback(coroutine_handle<> handle) { - async.Completed(disconnect_aware_handler(this, handle)); + async.Completed(disconnect_aware_handler(this, handle)); return suspending.exchange(false, std::memory_order_acquire); } @@ -249,6 +254,15 @@ namespace winrt::impl } #ifdef WINRT_IMPL_COROUTINES +WINRT_EXPORT namespace winrt +{ + template>> + inline impl::await_adapter resume_agile(Async const& async) + { + return { async }; + }; +} + WINRT_EXPORT namespace winrt::Windows::Foundation { inline impl::await_adapter operator co_await(IAsyncAction const& async) diff --git a/strings/base_coroutine_threadpool.h b/strings/base_coroutine_threadpool.h index e3a283e33..0faaa1acd 100644 --- a/strings/base_coroutine_threadpool.h +++ b/strings/base_coroutine_threadpool.h @@ -52,23 +52,14 @@ namespace winrt::impl { resume_apartment_context() = default; resume_apartment_context(std::nullptr_t) : m_context(nullptr), m_context_type(-1) {} - resume_apartment_context(resume_apartment_context const&) = default; - resume_apartment_context(resume_apartment_context&& other) noexcept : - m_context(std::move(other.m_context)), m_context_type(std::exchange(other.m_context_type, -1)) {} - resume_apartment_context& operator=(resume_apartment_context const&) = default; - resume_apartment_context& operator=(resume_apartment_context&& other) noexcept - { - m_context = std::move(other.m_context); - m_context_type = std::exchange(other.m_context_type, -1); - return *this; - } + bool valid() const noexcept { - return m_context_type >= 0; + return m_context_type.value >= 0; } com_ptr m_context = try_capture(WINRT_IMPL_CoGetObjectContext); - int32_t m_context_type = get_apartment_type().first; + movable_primitive m_context_type = get_apartment_type().first; }; inline int32_t __stdcall resume_apartment_callback(com_callback_args* args) noexcept @@ -124,7 +115,7 @@ namespace winrt::impl { return false; } - else if (context.m_context_type == 1 /* APTTYPE_MTA */) + else if (context.m_context_type.value == 1 /* APTTYPE_MTA */) { resume_background(handle); return true; diff --git a/strings/base_meta.h b/strings/base_meta.h index f474fced4..2c1796e9c 100644 --- a/strings/base_meta.h +++ b/strings/base_meta.h @@ -193,6 +193,25 @@ namespace winrt::impl } } + template + struct movable_primitive + { + T value = empty_value; + movable_primitive() = default; + movable_primitive(T const& init) : value(init) {} + movable_primitive(movable_primitive const&) = default; + movable_primitive(movable_primitive&& other) : + value(other.detach()) {} + movable_primitive& operator=(movable_primitive const&) = default; + movable_primitive& operator=(movable_primitive&& other) + { + value = other.detach(); + return *this; + } + + T detach() { return std::exchange(value, empty_value); } + }; + template struct arg { diff --git a/test/test/await_adapter.cpp b/test/test/await_adapter.cpp index 809567695..701bc6b15 100644 --- a/test/test/await_adapter.cpp +++ b/test/test/await_adapter.cpp @@ -18,13 +18,9 @@ namespace static handle signal{ CreateEventW(nullptr, false, false, nullptr) }; - IAsyncAction OtherForegroundAsync() + IAsyncAction OtherForegroundAsync(DispatcherQueue dispatcher) { - // Simple coroutine that completes on a unique STA thread. - - auto controller = DispatcherQueueController::CreateOnDedicatedThread(); - auto dispatcher = controller.DispatcherQueue(); - + // Simple coroutine that completes on the specified STA thread. co_await resume_foreground(dispatcher); } @@ -35,37 +31,37 @@ namespace co_await resume_background(); } - IAsyncAction ForegroundAsync(DispatcherQueue dispatcher) + // Coroutine that completes on dispatcher1, while potentially blocking dispatcher2. + IAsyncAction ForegroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2) { REQUIRE(!is_sta()); - co_await resume_foreground(dispatcher); + co_await resume_foreground(dispatcher1); REQUIRE(is_sta()); // This exercises one STA thread waiting on another thus one context callback // completing on another. uint32_t id = GetCurrentThreadId(); - co_await OtherForegroundAsync(); + co_await OtherForegroundAsync(dispatcher2); REQUIRE(id == GetCurrentThreadId()); - // This just avoids the ForegroundAsync coroutine completing before - // BackgroundAsync waits on the result, forcing the Completed handler - // to be called on the foreground thread. This just makes the test - // success/failure more predictable. + // This Sleep() makes it more likely that the caller will actually suspend in await_suspend, + // so that the Completed handler triggers a resumption from the dispatcher1 thread. Sleep(100); } - fire_and_forget SignalFromForeground(DispatcherQueue dispatcher) + fire_and_forget SignalFromForeground(DispatcherQueue dispatcher1) { REQUIRE(!is_sta()); - co_await resume_foreground(dispatcher); + co_await resume_foreground(dispatcher1); REQUIRE(is_sta()); - // Previously, this signal was never raised because the foreground thread - // was always blocked waiting for ContextCallback to return. + // Previously, we never got here because of a deadlock: + // The dispatcher1 thread was blocked waiting for ContextCallback to return, + // but the ContextCallback is waiting for this event to get signaled. REQUIRE(SetEvent(signal.get())); } - IAsyncAction BackgroundAsync(DispatcherQueue dispatcher) + IAsyncAction BackgroundAsync(DispatcherQueue dispatcher1, DispatcherQueue dispatcher2) { // Switch to a background (MTA) thread. co_await resume_background(); @@ -76,19 +72,19 @@ namespace co_await OtherBackgroundAsync(); REQUIRE(!is_sta()); - // Wait for a coroutine that completes on a foreground (STA) thread. - co_await ForegroundAsync(dispatcher); + // Wait for a coroutine that completes on a the dispatcher1 thread (STA). + co_await ForegroundAsync(dispatcher1, dispatcher2); // Resumption should automatically switch to a background (MTA) thread - // without blocking the Completed handler (which would in turn block the foreground thread). + // without blocking the Completed handler (which would in turn block the dispatcher1 thread). REQUIRE(!is_sta()); - // Attempt to signal from the foreground thread under the assumption - // that the foreground thread is not blocked. - SignalFromForeground(dispatcher); + // Attempt to signal from the dispatcher1 thread under the assumption + // that the dispatcher1 thread is not blocked. + SignalFromForeground(dispatcher1); - // Block the background (MTA) thread indefinitely until the signal is raied. - // Previously this would deadlock. + // Block the background (MTA) thread indefinitely until the signal is raised. + // Previously this would hang because the signal never got raised. REQUIRE(WAIT_OBJECT_0 == WaitForSingleObject(signal.get(), INFINITE)); } } @@ -99,9 +95,44 @@ TEST_CASE("await_adapter", "[.clang-crash]") #else TEST_CASE("await_adapter") #endif +{ + auto controller1 = DispatcherQueueController::CreateOnDedicatedThread(); + auto controller2 = DispatcherQueueController::CreateOnDedicatedThread(); + + BackgroundAsync(controller1.DispatcherQueue(), controller2.DispatcherQueue()).get(); + controller1.ShutdownQueueAsync().get(); + controller2.ShutdownQueueAsync().get(); +} + +namespace +{ + IAsyncAction OtherBackgroundDelayAsync() + { + // Simple coroutine that completes on some MTA thread after a brief delay + // to ensure that the caller has suspended. + + co_await resume_after(100ms); + } + + IAsyncAction AgileAsync(DispatcherQueue dispatcher) + { + // Switch to the STA. + co_await resume_foreground(dispatcher); + REQUIRE(is_sta()); + + // Ask for agile resumption of a coroutine that finishes on a background thread. + // Add a 100ms delay to ensure we suspend. + co_await resume_agile(OtherBackgroundDelayAsync()); + // We should be on the background thread now. + REQUIRE(!is_sta()); + } +} + +TEST_CASE("await_adapter_agile") { auto controller = DispatcherQueueController::CreateOnDedicatedThread(); auto dispatcher = controller.DispatcherQueue(); - BackgroundAsync(dispatcher).get(); + AgileAsync(dispatcher).get(); + controller.ShutdownQueueAsync().get(); }