Skip to content

Commit fa28acf

Browse files
snarkmaster authored and facebook-github-bot committed
Eliminate promise-punning UB from getErrorHandle
Summary: As discussed in the preceding D82907354, it is quite hard to eliminate promise-punning UB from `TaskWrapper::get_return_object`, so we just leave it be with the safeguard of `is_promise_type_punning_safe`. However, there's no inherent reason for `getErrorHandle` to interact with `coroutine_handle<Promise>`. It only needs either the `Promise&` (which may be stored as a member of a wrapper promise without issue), or the type-erased `coroutine_handle<>`. Furthermore, the latter is only needed on the "no-op" path, i.e. the per-coro implementation does not need to see a handle at all. This diff renames `getErrorHandle` implementations to `getErrorHandleImpl` for clarity, and switches them all to the new, UB-free design.
Reviewed By: ispeters
Differential Revision: D82926069
fbshipit-source-id: 3230102ffa12ae466ca822ce6a058092c720ac5d
1 parent 27d9ad8 commit fa28acf

File tree

5 files changed

+128
-42
lines changed

5 files changed

+128
-42
lines changed

third-party/folly/src/folly/coro/AsyncGenerator.h

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -565,9 +565,8 @@ class AsyncGeneratorPromise final
565565
class YieldAwaiter {
566566
public:
567567
bool await_ready() noexcept { return false; }
568-
coroutine_handle<> await_suspend(
569-
coroutine_handle<AsyncGeneratorPromise> h) noexcept {
570-
AsyncGeneratorPromise& promise = h.promise();
568+
coroutine_handle<> await_suspend_promise(
569+
AsyncGeneratorPromise& promise) noexcept {
571570
// Pop AsyncStackFrame first as clearContext() clears the frame state.
572571
folly::popAsyncStackFrameCallee(promise.getAsyncFrame());
573572
promise.clearContext();
@@ -578,6 +577,10 @@ class AsyncGeneratorPromise final
578577
}
579578
return promise.continuation_.getHandle();
580579
}
580+
coroutine_handle<> await_suspend(
581+
coroutine_handle<AsyncGeneratorPromise> h) noexcept {
582+
return await_suspend_promise(h.promise());
583+
}
581584
void await_resume() noexcept {}
582585
};
583586

@@ -815,18 +818,17 @@ class AsyncGeneratorPromise final
815818

816819
folly::AsyncStackFrame& getAsyncFrame() noexcept { return asyncFrame_; }
817820

818-
static ExtendedCoroutineHandle::ErrorHandle getErrorHandle(
821+
static std::optional<ExtendedCoroutineHandle::ErrorHandle> getErrorHandleImpl(
819822
AsyncGeneratorPromise& me, exception_wrapper& ex) {
820823
if (me.bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
821824
auto yieldAwaiter = me.yield_value(co_error(std::move(ex)));
822825
DCHECK(!yieldAwaiter.await_ready());
823-
return {
824-
yieldAwaiter.await_suspend(
825-
coroutine_handle<AsyncGeneratorPromise>::from_promise(me)),
826+
return ExtendedCoroutineHandle::ErrorHandle{
827+
yieldAwaiter.await_suspend_promise(me),
826828
// yieldAwaiter.await_suspend pops a frame
827829
me.getAsyncFrame().getParentFrame()};
828830
}
829-
return {coroutine_handle<AsyncGeneratorPromise>::from_promise(me), nullptr};
831+
return std::nullopt;
830832
}
831833

832834
private:

third-party/folly/src/folly/coro/Coroutine.h

Lines changed: 87 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
#pragma once
1818

19+
#include <optional>
1920
#include <type_traits>
2021

2122
#if __has_include(<variant>)
@@ -298,9 +299,53 @@ inline bool detect_promise_return_object_eager_conversion() {
298299
template <typename>
299300
class ExtendedCoroutinePromiseCrtp;
300301

302+
namespace detail {
303+
template <typename, typename, typename>
304+
class TaskPromiseWrapperBase;
305+
}
306+
301307
// Extended version of coroutine_handle<void>
302308
// Assumes (and enforces) assumption that coroutine_handle is a pointer
303309
class ExtendedCoroutineHandle {
310+
protected:
311+
template <typename>
312+
friend class ExtendedCoroutinePromiseCrtp;
313+
template <typename, typename, typename>
314+
friend class detail::TaskPromiseWrapperBase;
315+
// This passkey aims to stop end users from calling `getPromiseBase`, which
316+
// is an unsafe implementation detail, and to prevent overload ambiguity.
317+
//
318+
// It also doubles as the sigil for `use_extended_handle_concept`, another
319+
// private detail.
320+
class PrivateTag {
321+
private:
322+
friend ExtendedCoroutineHandle;
323+
PrivateTag() = default;
324+
};
325+
326+
private:
327+
// SFINAE detection for the `use_extended_handle_concept` member type alias
328+
// that classes implementing `getErrorHandle` must expose. We don't want to
329+
// use any kind of common base on `TaskWrapperPromise`, be it non-empty
330+
// `PromiseBase`, or a dedicated empty tag, since either one would break
331+
// empty-base optimization.
332+
333+
template <typename T>
334+
using use_extended_handle_of_ = typename T::use_extended_handle_concept;
335+
336+
template <typename T, typename Void = void>
337+
struct use_extended_handle {
338+
static_assert(
339+
require_sizeof<T>, "`use_extended_handle` on incomplete type");
340+
static constexpr bool value = false;
341+
};
342+
343+
template <typename T>
344+
struct use_extended_handle<T, void_t<use_extended_handle_of_<T>>> {
345+
static constexpr bool value =
346+
std::is_same_v<use_extended_handle_of_<T>, PrivateTag>;
347+
};
348+
304349
public:
305350
using ErrorHandle = std::pair<ExtendedCoroutineHandle, AsyncStackFrame*>;
306351

@@ -310,15 +355,17 @@ class ExtendedCoroutineHandle {
310355
template <typename>
311356
friend class ExtendedCoroutinePromiseCrtp;
312357

313-
using Fn = ErrorHandle(PromiseBase*, exception_wrapper& ex);
358+
using Fn = std::optional<ErrorHandle>(PromiseBase*, exception_wrapper& ex);
314359

315360
explicit PromiseBase(Fn* fn) : getErrorHandlePtr_(fn) {}
316361
~PromiseBase() = default;
317362

318-
// A manual vtable with 1 function. The benefit over virtual inheritance
319-
// is that derived classes like `TaskPromise` don't have to be `final` in
320-
// order to for the compiler to treat them as non-polymorphic.
321-
// Specifically, this enables `SafeTask`, a more type-safe `Task`.
363+
// A manual vtable with 1 function. Benefits over virtual inheritance:
364+
// - `TaskWrapperPromise` can implement `getErrorHandle` without bloating
365+
// itself with a vtable it does not need.
366+
// - A tiny binary size win.
367+
// - Derived classes like `TaskPromise` don't have to be `final` in order
368+
// for the compiler to treat them as non-polymorphic.
322369
Fn* getErrorHandlePtr_;
323370
};
324371

@@ -332,10 +379,10 @@ class ExtendedCoroutineHandle {
332379

333380
template <
334381
typename Promise,
335-
std::enable_if_t<std::is_base_of_v<PromiseBase, Promise>, int> = 0>
382+
std::enable_if_t<use_extended_handle<Promise>::value, int> = 0>
336383
/*implicit*/ ExtendedCoroutineHandle(Promise* promise) noexcept
337384
: basic_(coroutine_handle<Promise>::from_promise(*promise)),
338-
extended_(static_cast<PromiseBase*>(promise)) {}
385+
extended_(Promise::getPromiseBase(PrivateTag{}, promise)) {}
339386

340387
ExtendedCoroutineHandle() noexcept = default;
341388

@@ -347,7 +394,9 @@ class ExtendedCoroutineHandle {
347394

348395
ErrorHandle getErrorHandle(exception_wrapper& ex) {
349396
if (extended_) {
350-
return extended_->getErrorHandlePtr_(extended_, ex);
397+
if (auto res = extended_->getErrorHandlePtr_(extended_, ex)) {
398+
return *res;
399+
}
351400
}
352401
return {basic_, nullptr};
353402
}
@@ -357,8 +406,8 @@ class ExtendedCoroutineHandle {
357406
private:
358407
template <typename Promise>
359408
static auto fromBasic(coroutine_handle<Promise> handle) noexcept {
360-
if constexpr (std::is_convertible_v<Promise*, PromiseBase*>) {
361-
return static_cast<PromiseBase*>(&handle.promise());
409+
if constexpr (use_extended_handle<Promise>::value) {
410+
return Promise::getPromiseBase(PrivateTag{}, &handle.promise());
362411
} else {
363412
return nullptr;
364413
}
@@ -368,24 +417,42 @@ class ExtendedCoroutineHandle {
368417
PromiseBase* extended_{nullptr};
369418
};
370419

371-
// folly::coro types are expected to implement this extended promise interface:
372-
// (1) Publicly inherit from `ExtendedCoroutinePromiseCrtp<YourPromise>`,
373-
// (2) Implement this static method on `YourPromise`:
420+
// folly::coro types are expected to implement this extended promise interface.
421+
//
422+
// It allows types to provide a more efficient resumption path when they know
423+
// they will be receiving an error result from the awaitee.
424+
//
425+
// First, publicly inherit from `ExtendedCoroutinePromiseCrtp<YourPromise>`.
426+
// Second, implement this static method on `YourPromise`:
374427
//
375-
// static ExtendedCoroutineHandle::ErrorHandle getErrorHandle(
376-
// YourPromise&, exception_wrapper&)
428+
// static std::optional<ExtendedCoroutineHandle::ErrorHandle>
429+
// getErrorHandleImpl(YourPromise&, exception_wrapper&);
377430
//
378-
// Rationale: Types may provide a more efficient resumption path when they
379-
// know they will be receiving an error result from the awaitee. If they
380-
// do, they might also update the active stack frame.
431+
// Return `std::nullopt` to avoid changing the resumption path. Otherwise,
432+
// return the `ExtendedCoroutineHandle` to resume & the active stack frame.
433+
//
434+
// DANGER: `YourPromise& promise` is a promise instance, but it might NOT
435+
// directly correspond to a coro frame. For example, if your coro is wrapped,
436+
// that promise is a **member** inside a larger wrapper promise for the coro.
437+
// Therefore, you must NOT call `coroutine_handle<...>::from_promise(promise)`.
438+
// In the future, the true handle could be supplied, but none of the current
439+
// coros required it.
381440
template <typename Promise>
382441
class ExtendedCoroutinePromiseCrtp
383442
: public ExtendedCoroutineHandle::PromiseBase {
443+
public:
444+
using use_extended_handle_concept = ExtendedCoroutineHandle::PrivateTag;
445+
446+
static ExtendedCoroutineHandle::PromiseBase* getPromiseBase(
447+
ExtendedCoroutineHandle::PrivateTag, ExtendedCoroutinePromiseCrtp* me) {
448+
return me;
449+
}
450+
384451
protected:
385452
using PromiseBase = typename ExtendedCoroutineHandle::PromiseBase;
386453
ExtendedCoroutinePromiseCrtp()
387-
: PromiseBase(+[](PromiseBase* promise, exception_wrapper& ex) {
388-
return Promise::getErrorHandle(*static_cast<Promise*>(promise), ex);
454+
: PromiseBase(+[](PromiseBase* p, exception_wrapper& ex) {
455+
return Promise::getErrorHandleImpl(*static_cast<Promise*>(p), ex);
389456
}) {}
390457
~ExtendedCoroutinePromiseCrtp() = default;
391458
};

third-party/folly/src/folly/coro/Task.h

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -82,9 +82,7 @@ class TaskPromiseBase {
8282
bool await_ready() noexcept { return false; }
8383

8484
template <typename Promise>
85-
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES coroutine_handle<>
86-
await_suspend(coroutine_handle<Promise> coro) noexcept {
87-
auto& promise = coro.promise();
85+
coroutine_handle<> await_suspend_promise(Promise& promise) noexcept {
8886
// If ScopeExitTask has been attached, then we expect that the
8987
// ScopeExitTask will handle the lifetime of the async stack. See
9088
// ScopeExitTaskPromise's FinalAwaiter for more details.
@@ -114,6 +112,12 @@ class TaskPromiseBase {
114112
return promise.continuationRef(privateTag()).getHandle();
115113
}
116114

115+
template <typename Promise>
116+
FOLLY_CORO_AWAIT_SUSPEND_NONTRIVIAL_ATTRIBUTES coroutine_handle<>
117+
await_suspend(coroutine_handle<Promise> coro) noexcept {
118+
return await_suspend_promise(coro.promise());
119+
}
120+
117121
[[noreturn]] void await_resume() noexcept { folly::assume_unreachable(); }
118122
};
119123

@@ -276,18 +280,17 @@ class TaskPromiseCrtpBase
276280
return do_safe_point(*this);
277281
}
278282

279-
static ExtendedCoroutineHandle::ErrorHandle getErrorHandle(
283+
static std::optional<ExtendedCoroutineHandle::ErrorHandle> getErrorHandleImpl(
280284
Promise& me, exception_wrapper& ex) {
281285
if (me.bypassExceptionThrowing_ == BypassExceptionThrowing::ACTIVE) {
282286
auto finalAwaiter = me.yield_value(co_error(std::move(ex)));
283287
DCHECK(!finalAwaiter.await_ready());
284-
return {
285-
finalAwaiter.await_suspend(
286-
coroutine_handle<Promise>::from_promise(me)),
288+
return ExtendedCoroutineHandle::ErrorHandle{
289+
finalAwaiter.await_suspend_promise(me),
287290
// finalAwaiter.await_suspend pops a frame
288291
me.getAsyncFrame().getParentFrame()};
289292
}
290-
return {coroutine_handle<Promise>::from_promise(me), nullptr};
293+
return std::nullopt;
291294
}
292295

293296
protected:
@@ -1003,6 +1006,10 @@ Task<drop_unit_t<T>> makeResultTask(Try<T> t) {
10031006
template <typename Promise, typename T>
10041007
inline Task<T>
10051008
detail::TaskPromiseCrtpBase<Promise, T>::get_return_object() noexcept {
1009+
// Watch out: When used with `TaskWrapper`, this relies on "practically safe"
1010+
// UB wherein this handle is only valid because `TaskPromise` and the true
1011+
// "wrapper promise" of the wrapper coro coincide in layout exactly.
1012+
// Documented in `TaskPromiseWrapperBase::is_promise_type_punning_safe`.
10061013
return Task<T>{
10071014
coroutine_handle<Promise>::from_promise(*static_cast<Promise*>(this))};
10081015
}

third-party/folly/src/folly/coro/TaskWrapper.h

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -169,8 +169,10 @@ class TaskPromiseWrapperBase {
169169
using TaskWrapperInnerPromise = Promise;
170170

171171
WrapperTask get_return_object() noexcept {
172-
// NB: See the function doc. It'd be nice to have the `static_assert` at
173-
// class scope, but the type is still incomplete at that point.
172+
// CRITICAL: This assert justifies why it is practically safe to rely on
173+
// the `from_promise` UB in `TaskPromiseCrtpBase::get_return_object`.
174+
//
175+
// PS The assert isn't at class scope, since the type would be incomplete.
174176
static_assert(is_promise_type_punning_safe());
175177
return WrapperTask{promise_.get_return_object()};
176178
}
@@ -231,6 +233,9 @@ class TaskPromiseWrapper
231233
public:
232234
template <typename U = T> // see "`co_return` with implicit ctor" test
233235
auto return_value(U&& value) {
236+
static_assert( // See `is_promise_type_punning_safe` for rationale
237+
require_sizeof<TaskPromiseWrapper> ==
238+
require_sizeof<TaskPromiseWrapperBase<T, WrapperTask, Promise>>);
234239
return this->promise_.return_value(std::forward<U>(value));
235240
}
236241
};
@@ -243,7 +248,12 @@ class TaskPromiseWrapper<void, WrapperTask, Promise>
243248
~TaskPromiseWrapper() = default;
244249

245250
public:
246-
void return_void() noexcept { this->promise_.return_void(); }
251+
void return_void() noexcept {
252+
static_assert( // See `is_promise_type_punning_safe` for rationale
253+
require_sizeof<TaskPromiseWrapper> ==
254+
require_sizeof<TaskPromiseWrapperBase<void, WrapperTask, Promise>>);
255+
this->promise_.return_void();
256+
}
247257
};
248258

249259
// Mixin for TaskWrapper.h configs for `Task` & `TaskWithExecutor` types

third-party/folly/src/folly/coro/ViaIfAsync.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -154,14 +154,14 @@ class ViaCoroutine {
154154

155155
folly::AsyncStackFrame& getLeafFrame() noexcept { return leafFrame_; }
156156

157-
static ExtendedCoroutineHandle::ErrorHandle getErrorHandle(
158-
promise_type& me, exception_wrapper& ex) {
157+
static std::optional<ExtendedCoroutineHandle::ErrorHandle>
158+
getErrorHandleImpl(promise_type& me, exception_wrapper& ex) {
159159
auto [handle, frame] = me.continuation_.getErrorHandle(ex);
160160
me.setContinuation(handle);
161161
if (frame && IsStackAware) {
162162
me.leafFrame_.setParentFrame(*frame);
163163
}
164-
return {coroutine_handle<promise_type>::from_promise(me), nullptr};
164+
return std::nullopt;
165165
}
166166
};
167167

0 commit comments

Comments
 (0)