1616
1717#pragma once
1818
19+ #include < optional>
1920#include < type_traits>
2021
2122#if __has_include(<variant>)
@@ -298,9 +299,53 @@ inline bool detect_promise_return_object_eager_conversion() {
298299template <typename >
299300class ExtendedCoroutinePromiseCrtp ;
300301
302+ namespace detail {
303+ template <typename , typename , typename >
304+ class TaskPromiseWrapperBase ;
305+ }
306+
301307// Extended version of coroutine_handle<void>
302308// Assumes (and enforces) assumption that coroutine_handle is a pointer
303309class ExtendedCoroutineHandle {
310+ protected:
311+ template <typename >
312+ friend class ExtendedCoroutinePromiseCrtp ;
313+ template <typename , typename , typename >
314+ friend class detail ::TaskPromiseWrapperBase;
315+ // This passkey aims to stop end users from calling `getPromiseBase`, which
316+ // is an unsafe implementation detail, and to prevent overload ambiguity.
317+ //
318+ // It also doubles as the sigil for `use_extended_handle_concept`, another
319+ // private detail.
320+ class PrivateTag {
321+ private:
322+ friend ExtendedCoroutineHandle;
323+ PrivateTag () = default ;
324+ };
325+
326+ private:
327+ // SFINAE detection for the `use_extended_handle_concept` member type alias
328+ // that classes implementing `getErrorHandle` must expose. We don't want to
329+ // use any kind of common base on `TaskWrapperPromise`, be it non-empty
330+ // `PromiseBase`, or a dedicated empty tag, since either one would break
331+ // empty-base optimization.
332+
333+ template <typename T>
334+ using use_extended_handle_of_ = typename T::use_extended_handle_concept;
335+
336+ template <typename T, typename Void = void >
337+ struct use_extended_handle {
338+ static_assert (
339+ require_sizeof<T>, " `use_extended_handle` on incomplete type" );
340+ static constexpr bool value = false ;
341+ };
342+
343+ template <typename T>
344+ struct use_extended_handle <T, void_t <use_extended_handle_of_<T>>> {
345+ static constexpr bool value =
346+ std::is_same_v<use_extended_handle_of_<T>, PrivateTag>;
347+ };
348+
304349 public:
305350 using ErrorHandle = std::pair<ExtendedCoroutineHandle, AsyncStackFrame*>;
306351
@@ -310,15 +355,17 @@ class ExtendedCoroutineHandle {
310355 template <typename >
311356 friend class ExtendedCoroutinePromiseCrtp ;
312357
313- using Fn = ErrorHandle(PromiseBase*, exception_wrapper& ex);
358+ using Fn = std::optional< ErrorHandle> (PromiseBase*, exception_wrapper& ex);
314359
315360 explicit PromiseBase (Fn* fn) : getErrorHandlePtr_(fn) {}
316361 ~PromiseBase () = default ;
317362
318- // A manual vtable with 1 function. The benefit over virtual inheritance
319- // is that derived classes like `TaskPromise` don't have to be `final` in
320- // order to for the compiler to treat them as non-polymorphic.
321- // Specifically, this enables `SafeTask`, a more type-safe `Task`.
363+ // A manual vtable with 1 function. Benefits over virtual inheritance:
364+ // - `TaskWrapperPromise` can implement `getErrorHandle` without bloating
365+ // itself with with a vtable it does not need.
366+ // - A tiny binary size win.
367+ // - Derived classes like `TaskPromise` don't have to be `final` in order
368+ // for the compiler to treat them as non-polymorphic.
322369 Fn* getErrorHandlePtr_;
323370 };
324371
@@ -332,10 +379,10 @@ class ExtendedCoroutineHandle {
332379
333380 template <
334381 typename Promise,
335- std::enable_if_t <std::is_base_of_v<PromiseBase, Promise>, int > = 0 >
382+ std::enable_if_t <use_extended_handle< Promise>::value , int > = 0 >
336383 /* implicit*/ ExtendedCoroutineHandle(Promise* promise) noexcept
337384 : basic_(coroutine_handle<Promise>::from_promise(*promise)),
338- extended_ (static_cast <PromiseBase*>( promise)) {}
385+ extended_ (Promise::getPromiseBase(PrivateTag{}, promise)) {}
339386
340387 ExtendedCoroutineHandle () noexcept = default;
341388
@@ -347,7 +394,9 @@ class ExtendedCoroutineHandle {
347394
348395 ErrorHandle getErrorHandle (exception_wrapper& ex) {
349396 if (extended_) {
350- return extended_->getErrorHandlePtr_ (extended_, ex);
397+ if (auto res = extended_->getErrorHandlePtr_ (extended_, ex)) {
398+ return *res;
399+ }
351400 }
352401 return {basic_, nullptr };
353402 }
@@ -357,8 +406,8 @@ class ExtendedCoroutineHandle {
357406 private:
358407 template <typename Promise>
359408 static auto fromBasic (coroutine_handle<Promise> handle) noexcept {
360- if constexpr (std::is_convertible_v <Promise*, PromiseBase*> ) {
361- return static_cast <PromiseBase*>( &handle.promise ());
409+ if constexpr (use_extended_handle <Promise>::value ) {
410+ return Promise::getPromiseBase (PrivateTag{}, &handle.promise ());
362411 } else {
363412 return nullptr ;
364413 }
@@ -368,24 +417,42 @@ class ExtendedCoroutineHandle {
368417 PromiseBase* extended_{nullptr };
369418};
370419
371- // folly::coro types are expected to implement this extended promise interface:
372- // (1) Publicly inherit from `ExtendedCoroutinePromiseCrtp<YourPromise>`,
373- // (2) Implement this static method on `YourPromise`:
420+ // folly::coro types are expected to implement this extended promise interface.
421+ //
422+ // It allows types to provide a more efficient resumption path when they know
423+ // they will be receiving an error result from the awaitee.
424+ //
425+ // First, publicly inherit from `ExtendedCoroutinePromiseCrtp<YourPromise>`,
426+ // Second, implement this static method on `YourPromise`:
374427//
375- // static ExtendedCoroutineHandle::ErrorHandle getErrorHandle(
376- // YourPromise&, exception_wrapper&)
428+ // static std::optional< ExtendedCoroutineHandle::ErrorHandle>
429+ // getErrorHandleImpl( YourPromise&, exception_wrapper&);
377430//
378- // Rationale: Types may provide a more efficient resumption path when they
379- // know they will be receiving an error result from the awaitee. If they
380- // do, they might also update the active stack frame.
431+ // Return `std::nullopt` to avoid changing the resumption path. Otherwise,
432+ // return the `ExtendedCoroutineHandle` to resume & the active stack frame.
433+ //
434+ // DANGER: `YourPromise& promise` is a promise instance, but it might NOT
435+ // directly correspond to a coro frame. For example, if your coro is wrapped,
436+ // that promise is a **member** inside a larger wrapper promise for the coro.
437+ // Therefore, you must NOT call `coroutine_handle<...>::from_promise(promise)`.
438+ // In the future, the true handle could be supplied, but none of the current
439+ // coros required it.
381440template <typename Promise>
382441class ExtendedCoroutinePromiseCrtp
383442 : public ExtendedCoroutineHandle::PromiseBase {
443+ public:
444+ using use_extended_handle_concept = ExtendedCoroutineHandle::PrivateTag;
445+
446+ static ExtendedCoroutineHandle::PromiseBase* getPromiseBase (
447+ ExtendedCoroutineHandle::PrivateTag, ExtendedCoroutinePromiseCrtp* me) {
448+ return me;
449+ }
450+
384451 protected:
385452 using PromiseBase = typename ExtendedCoroutineHandle::PromiseBase;
386453 ExtendedCoroutinePromiseCrtp ()
387- : PromiseBase(+[](PromiseBase* promise , exception_wrapper& ex) {
388- return Promise::getErrorHandle (*static_cast <Promise*>(promise ), ex);
454+ : PromiseBase(+[](PromiseBase* p , exception_wrapper& ex) {
455+ return Promise::getErrorHandleImpl (*static_cast <Promise*>(p ), ex);
389456 }) {}
390457 ~ExtendedCoroutinePromiseCrtp () = default ;
391458};
0 commit comments