Skip to content

Commit

Permalink
<future>: Make packaged_task accept move-only functors (#4946)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <stl@microsoft.com>
  • Loading branch information
frederick-vs-ja and StephanTLavavej authored Oct 12, 2024
1 parent 23dc7e3 commit d08e31c
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 121 deletions.
32 changes: 30 additions & 2 deletions stl/inc/functional
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,10 @@ public:
private:
_Mybase* _Copy(void* _Where) const override {
auto& _Myax = _Mypair._Get_first();
if constexpr (_Is_large<_Func_impl>) {
if constexpr (!is_copy_constructible_v<_Callable>) { // used exclusively for packaged_task
(void) _Myax;
_CSTD abort(); // shouldn't be called, see GH-3888
} else if constexpr (_Is_large<_Func_impl>) {
_Myalty _Rebound(_Myax);
_Alloc_construct_ptr<_Myalty> _Constructor{_Rebound};
_Constructor._Allocate();
Expand Down Expand Up @@ -854,7 +857,9 @@ public:

private:
_Mybase* _Copy(void* _Where) const override {
if constexpr (_Is_large<_Func_impl_no_alloc>) {
if constexpr (!is_copy_constructible_v<_Callable>) { // used exclusively for packaged_task
_CSTD abort(); // shouldn't be called, see GH-3888
} else if constexpr (_Is_large<_Func_impl_no_alloc>) {
return _STD _Global_new<_Func_impl_no_alloc>(_Callee);
} else {
return ::new (_Where) _Func_impl_no_alloc(_Callee);
Expand Down Expand Up @@ -1070,6 +1075,10 @@ _NON_MEMBER_CALL(_GET_FUNCTION_IMPL_NOEXCEPT, X1, X2, X3)
#undef _GET_FUNCTION_IMPL_NOEXCEPT
#endif // defined(__cpp_noexcept_function_type)

struct _Secret_copyability_ignoring_tag { // used exclusively for packaged_task
explicit _Secret_copyability_ignoring_tag() = default;
};

_EXPORT_STD template <class _Fty>
class function : public _Get_function_impl<_Fty>::type { // wrapper for callable objects
private:
Expand All @@ -1086,6 +1095,15 @@ public:

template <class _Fx, typename _Mybase::template _Enable_if_callable_t<_Fx, function> = 0>
function(_Fx&& _Func) {
static_assert(is_copy_constructible_v<decay_t<_Fx>>,
"The target function object type must be copy constructible (N4988 [func.wrap.func.con]/10.1).");
this->_Reset(_STD forward<_Fx>(_Func));
}

template <class _SecretTag, class _Fx,
enable_if_t<is_same_v<_SecretTag, _Secret_copyability_ignoring_tag>, int> = 0,
typename _Mybase::template _Enable_if_callable_t<_Fx, function> = 0>
explicit function(_SecretTag, _Fx&& _Func) { // used exclusively for packaged_task
this->_Reset(_STD forward<_Fx>(_Func));
}

Expand All @@ -1103,6 +1121,16 @@ public:

template <class _Fx, class _Alloc, typename _Mybase::template _Enable_if_callable_t<_Fx, function> = 0>
function(allocator_arg_t, const _Alloc& _Ax, _Fx&& _Func) {
static_assert(is_copy_constructible_v<decay_t<_Fx>>,
"The target function object type must be copy constructible (N4140 [func.wrap.func.con]/7).");
this->_Reset_alloc(_STD forward<_Fx>(_Func), _Ax);
}

template <class _SecretTag, class _Fx, class _Alloc,
enable_if_t<is_same_v<_SecretTag, _Secret_copyability_ignoring_tag>, int> = 0,
typename _Mybase::template _Enable_if_callable_t<_Fx, function> = 0>
explicit function(_SecretTag, allocator_arg_t, const _Alloc& _Ax, _Fx&& _Func) {
// used exclusively for packaged_task
this->_Reset_alloc(_STD forward<_Fx>(_Func), _Ax);
}
#endif // _HAS_FUNCTION_ALLOCATOR_SUPPORT
Expand Down
182 changes: 64 additions & 118 deletions stl/inc/future
Original file line number Diff line number Diff line change
Expand Up @@ -459,118 +459,64 @@ void _State_deleter<_Ty, _Derived, _Alloc>::_Delete(_Associated_state<_Ty>* _Sta
_STD _Delete_plain_internal(_Del_alloc, this);
}

template <class>
class _Packaged_state;

template <class _Ret, class... _ArgTypes>
class _Packaged_state<_Ret(_ArgTypes...)>
: public _Associated_state<_Ret> { // class for managing associated asynchronous state for packaged_task
public:
using _Mybase = _Associated_state<_Ret>;
using _Mydel = typename _Mybase::_Mydel;

template <class _Fty2>
_Packaged_state(_Fty2&& _Fnarg) : _Fn(_STD forward<_Fty2>(_Fnarg)) {}

#if _HAS_FUNCTION_ALLOCATOR_SUPPORT
template <class _Fty2, class _Alloc>
_Packaged_state(_Fty2&& _Fnarg, const _Alloc& _Al, _Mydel* _Dp)
: _Mybase(_Dp), _Fn(allocator_arg, _Al, _STD forward<_Fty2>(_Fnarg)) {}
#endif // _HAS_FUNCTION_ALLOCATOR_SUPPORT

void _Call_deferred(_ArgTypes... _Args) { // set deferred call
_TRY_BEGIN
// call function object and catch exceptions
this->_Set_value(_Fn(_STD forward<_ArgTypes>(_Args)...), true);
_CATCH_ALL
// function object threw exception; record result
this->_Set_exception(_STD current_exception(), true);
_CATCH_END
}

void _Call_immediate(_ArgTypes... _Args) { // call function object
_TRY_BEGIN
// call function object and catch exceptions
this->_Set_value(_Fn(_STD forward<_ArgTypes>(_Args)...), false);
_CATCH_ALL
// function object threw exception; record result
this->_Set_exception(_STD current_exception(), false);
_CATCH_END
}
template <class _Fret>
struct _P_arg_type { // type for functions returning T
using type = _Fret;
};

const auto& _Get_fn() const {
return _Fn;
}
template <class _Fret>
struct _P_arg_type<_Fret&> { // type for functions returning reference to T
using type = _Fret*;
};

private:
function<_Ret(_ArgTypes...)> _Fn;
template <>
struct _P_arg_type<void> { // type for functions returning void
using type = int;
};

template <class>
class _Packaged_state;

// class for managing associated asynchronous state for packaged_task
template <class _Ret, class... _ArgTypes>
class _Packaged_state<_Ret&(_ArgTypes...)>
: public _Associated_state<_Ret*> { // class for managing associated asynchronous state for packaged_task
class _Packaged_state<_Ret(_ArgTypes...)> : public _Associated_state<typename _P_arg_type<_Ret>::type> {
public:
using _Mybase = _Associated_state<_Ret*>;
using _Mydel = typename _Mybase::_Mydel;

template <class _Fty2>
_Packaged_state(_Fty2&& _Fnarg) : _Fn(_STD forward<_Fty2>(_Fnarg)) {}

#if _HAS_FUNCTION_ALLOCATOR_SUPPORT
template <class _Fty2, class _Alloc>
_Packaged_state(_Fty2&& _Fnarg, const _Alloc& _Al, _Mydel* _Dp)
: _Mybase(_Dp), _Fn(allocator_arg, _Al, _STD forward<_Fty2>(_Fnarg)) {}
#endif // _HAS_FUNCTION_ALLOCATOR_SUPPORT
using _Mybase = _Associated_state<typename _P_arg_type<_Ret>::type>;
using _Mydel = typename _Mybase::_Mydel;
using _Function_type = function<_Ret(_ArgTypes...)>; // TRANSITION, ABI, should not use std::function

void _Call_deferred(_ArgTypes... _Args) { // set deferred call
_TRY_BEGIN
// call function object and catch exceptions
this->_Set_value(_STD addressof(_Fn(_STD forward<_ArgTypes>(_Args)...)), true);
_CATCH_ALL
// function object threw exception; record result
this->_Set_exception(_STD current_exception(), true);
_CATCH_END
}

void _Call_immediate(_ArgTypes... _Args) { // call function object
_TRY_BEGIN
// call function object and catch exceptions
this->_Set_value(_STD addressof(_Fn(_STD forward<_ArgTypes>(_Args)...)), false);
_CATCH_ALL
// function object threw exception; record result
this->_Set_exception(_STD current_exception(), false);
_CATCH_END
}
explicit _Packaged_state(const _Function_type& _Fnarg) : _Fn(_Fnarg) {}

const auto& _Get_fn() const {
return _Fn;
}
explicit _Packaged_state(_Function_type&& _Fnarg) noexcept : _Fn(_STD move(_Fnarg)) {}

private:
function<_Ret&(_ArgTypes...)> _Fn;
};
template <class _Fty2, enable_if_t<!is_same_v<_Remove_cvref_t<_Fty2>, _Function_type>, int> = 0>
explicit _Packaged_state(_Fty2&& _Fnarg) : _Fn(_Secret_copyability_ignoring_tag{}, _STD forward<_Fty2>(_Fnarg)) {}

template <class... _ArgTypes>
class _Packaged_state<void(_ArgTypes...)>
: public _Associated_state<int> { // class for managing associated asynchronous state for packaged_task
public:
using _Mybase = _Associated_state<int>;
using _Mydel = typename _Mybase::_Mydel;
#if _HAS_FUNCTION_ALLOCATOR_SUPPORT
template <class _Alloc>
_Packaged_state(const _Function_type& _Fnarg, const _Alloc& _Al, _Mydel* _Dp)
: _Mybase(_Dp), _Fn(allocator_arg, _Al, _Fnarg) {}

template <class _Fty2>
_Packaged_state(_Fty2&& _Fnarg) : _Fn(_STD forward<_Fty2>(_Fnarg)) {}
template <class _Alloc>
_Packaged_state(_Function_type&& _Fnarg, const _Alloc& _Al, _Mydel* _Dp)
: _Mybase(_Dp), _Fn(allocator_arg, _Al, _STD move(_Fnarg)) {}

#if _HAS_FUNCTION_ALLOCATOR_SUPPORT
template <class _Fty2, class _Alloc>
template <class _Fty2, class _Alloc, enable_if_t<!is_same_v<_Remove_cvref_t<_Fty2>, _Function_type>, int> = 0>
_Packaged_state(_Fty2&& _Fnarg, const _Alloc& _Al, _Mydel* _Dp)
: _Mybase(_Dp), _Fn(allocator_arg, _Al, _STD forward<_Fty2>(_Fnarg)) {}
: _Mybase(_Dp), _Fn(_Secret_copyability_ignoring_tag{}, allocator_arg, _Al, _STD forward<_Fty2>(_Fnarg)) {}
#endif // _HAS_FUNCTION_ALLOCATOR_SUPPORT

void _Call_deferred(_ArgTypes... _Args) { // set deferred call
_TRY_BEGIN
// call function object and catch exceptions
_Fn(_STD forward<_ArgTypes>(_Args)...);
this->_Set_value(1, true);
if constexpr (is_same_v<_Ret, void>) {
_Fn(_STD forward<_ArgTypes>(_Args)...);
this->_Set_value(1, true);
} else if constexpr (is_lvalue_reference_v<_Ret>) {
this->_Set_value(_STD addressof(_Fn(_STD forward<_ArgTypes>(_Args)...)), true);
} else {
this->_Set_value(_Fn(_STD forward<_ArgTypes>(_Args)...), true);
}
_CATCH_ALL
// function object threw exception; record result
this->_Set_exception(_STD current_exception(), true);
Expand All @@ -580,20 +526,29 @@ public:
void _Call_immediate(_ArgTypes... _Args) { // call function object
_TRY_BEGIN
// call function object and catch exceptions
_Fn(_STD forward<_ArgTypes>(_Args)...);
this->_Set_value(1, false);
if constexpr (is_same_v<_Ret, void>) {
_Fn(_STD forward<_ArgTypes>(_Args)...);
this->_Set_value(1, false);
} else if constexpr (is_lvalue_reference_v<_Ret>) {
this->_Set_value(_STD addressof(_Fn(_STD forward<_ArgTypes>(_Args)...)), false);
} else {
this->_Set_value(_Fn(_STD forward<_ArgTypes>(_Args)...), false);
}
_CATCH_ALL
// function object threw exception; record result
this->_Set_exception(_STD current_exception(), false);
_CATCH_END
}

const auto& _Get_fn() const {
const auto& _Get_fn() const& {
return _Fn;
}
auto&& _Get_fn() && noexcept {
return _STD move(_Fn);
}

private:
function<void(_ArgTypes...)> _Fn;
_Function_type _Fn;
};

template <class _Ty, class _Alloc>
Expand Down Expand Up @@ -1235,21 +1190,6 @@ void swap(promise<_Ty>& _Left, promise<_Ty>& _Right) noexcept {
_Left.swap(_Right);
}

template <class _Fret>
struct _P_arg_type { // type for functions returning T
using type = _Fret;
};

template <class _Fret>
struct _P_arg_type<_Fret&> { // type for functions returning reference to T
using type = _Fret*;
};

template <>
struct _P_arg_type<void> { // type for functions returning void
using type = int;
};

_EXPORT_STD template <class>
class packaged_task; // not defined

Expand All @@ -1266,7 +1206,10 @@ public:
packaged_task() = default;

template <class _Fty2, enable_if_t<!is_same_v<_Remove_cvref_t<_Fty2>, packaged_task>, int> = 0>
explicit packaged_task(_Fty2&& _Fnarg) : _MyPromise(new _MyStateType(_STD forward<_Fty2>(_Fnarg))) {}
explicit packaged_task(_Fty2&& _Fnarg) : _MyPromise(new _MyStateType(_STD forward<_Fty2>(_Fnarg))) {
static_assert(_Is_invocable_r<_Ret, decay_t<_Fty2>&, _ArgTypes...>::value, // per LWG-4154
"The function object must be callable with _ArgTypes... and return _Ret (N4988 [futures.task.members]/3).");
}

packaged_task(packaged_task&&) noexcept = default;

Expand All @@ -1275,7 +1218,10 @@ public:
#if _HAS_FUNCTION_ALLOCATOR_SUPPORT
template <class _Fty2, class _Alloc, enable_if_t<!is_same_v<_Remove_cvref_t<_Fty2>, packaged_task>, int> = 0>
packaged_task(allocator_arg_t, const _Alloc& _Al, _Fty2&& _Fnarg)
: _MyPromise(_STD _Make_packaged_state<_MyStateType>(_STD forward<_Fty2>(_Fnarg), _Al)) {}
: _MyPromise(_STD _Make_packaged_state<_MyStateType>(_STD forward<_Fty2>(_Fnarg), _Al)) {
static_assert(_Is_invocable_r<_Ret, decay_t<_Fty2>&, _ArgTypes...>::value, // per LWG-4154
"The function object must be callable with _ArgTypes... and return _Ret (N4140 [futures.task.members]/2).");
}
#endif // _HAS_FUNCTION_ALLOCATOR_SUPPORT

~packaged_task() noexcept {
Expand Down Expand Up @@ -1319,9 +1265,9 @@ public:
}

void reset() { // reset to newly constructed state
_MyStateManagerType& _State = _MyPromise._Get_state_for_set();
_MyStateType* _MyState = static_cast<_MyStateType*>(_State._Ptr());
_MyPromiseType _New_promise(new _MyStateType(_MyState->_Get_fn()));
_MyStateManagerType& _State_mgr = _MyPromise._Get_state_for_set();
_MyStateType& _MyState = *static_cast<_MyStateType*>(_State_mgr._Ptr());
_MyPromiseType _New_promise(new _MyStateType(_STD move(_MyState)._Get_fn()));
_MyPromise._Get_state()._Abandon();
_MyPromise._Swap(_New_promise);
}
Expand Down
11 changes: 11 additions & 0 deletions tests/std/tests/Dev10_561430_list_and_tree_leaks/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,17 @@ int main() {
assert(f.get() == 1234);
}

// Also test GH-321: "<future>: packaged_task can't be constructed from a move-only lambda"
{
packaged_task<int()> pt(allocator_arg, Mallocator<int>(), [uptr = make_unique<int>(172)] { return *uptr; });

future<int> f = pt.get_future();

pt();

assert(f.get() == 172);
}

{
int n = 4096;

Expand Down
15 changes: 15 additions & 0 deletions tests/std/tests/Dev11_0235721_async_and_packaged_task/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ void test_DevDiv_725337() {
int i = 1729;
auto ref_lambda = [&]() -> int& { return i; };

// GH-321: "<future>: packaged_task can't be constructed from a move-only lambda"
auto move_only_lambda = [uptr = make_unique<int>(42)] { return *uptr; };

{
packaged_task<int()> pt1([] { return 19937; });
future<int> f = pt1.get_future();
Expand All @@ -90,6 +93,18 @@ void test_DevDiv_725337() {
assert(f.get() == 19937);
}

{
packaged_task<int()> pt1(move(move_only_lambda));
future<int> f = pt1.get_future();
packaged_task<int()> pt2(move(pt1));
packaged_task<int()> pt3;
pt3 = move(pt2);
assert(f.wait_for(0s) == future_status::timeout);
pt3();
assert(f.wait_for(0s) == future_status::ready);
assert(f.get() == 42);
}

{
packaged_task<int&()> pt1(ref_lambda);
future<int&> f = pt1.get_future();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ void test_function() {

void test_packaged_task() {
packaged_task<void(validator)>{};
packaged_task<void(validator)>{nullptr};
packaged_task<void(validator)>{simple_identity{}};
packaged_task<void(validator)>{simple_large_identity{}};

Expand Down
Loading

0 comments on commit d08e31c

Please sign in to comment.