Skip to content

Commit

Permalink
Merge pull request #183 from techleeksnap/sharedfuture
Browse files Browse the repository at this point in the history
Add SharedFuture
  • Loading branch information
li-feng-sc authored Aug 1, 2024
2 parents a97afc7 + ce6add1 commit b604359
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder
build --cxxopt=-std=c++17 --cxxopt=-fcoroutines-ts --host_cxxopt=-std=c++17 --host_cxxopt=-fcoroutines-ts --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder
9 changes: 5 additions & 4 deletions support-lib/cpp/Future.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ class Future {
return true;
}

template<typename ConcretePromise>
struct PromiseTypeBase {
Promise<T> _promise;
std::optional<djinni::expected<T, std::exception_ptr>> _result{};
Expand All @@ -379,7 +378,9 @@ class Future {
constexpr bool await_ready() const noexcept {
return false;
}
bool await_suspend(detail::CoroutineHandle<ConcretePromise> finished) const noexcept {
template <typename P>
bool await_suspend(detail::CoroutineHandle<P> finished) const noexcept {
static_assert(std::is_convertible_v<P*, PromiseTypeBase*>);
auto& promise_type = finished.promise();
if (*promise_type._result) {
if constexpr (std::is_void_v<T>) {
Expand All @@ -406,7 +407,7 @@ class Future {
}
};

struct PromiseType: PromiseTypeBase<PromiseType>{
struct PromiseType: PromiseTypeBase {
template <typename V, typename = std::enable_if_t<std::is_convertible_v<V, T>>>
void return_value(V&& value) {
this->_result.emplace(std::forward<V>(value));
Expand All @@ -424,7 +425,7 @@ class Future {

#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT)
template<>
struct Future<void>::PromiseType : PromiseTypeBase<PromiseType> {
struct Future<void>::PromiseType : PromiseTypeBase {
void return_void() {
_result.emplace();
}
Expand Down
162 changes: 162 additions & 0 deletions support-lib/cpp/SharedFuture.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/**
* Copyright 2021 Snap, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "Future.hpp"

#if defined(DJINNI_FUTURE_HAS_COROUTINE_SUPPORT)

#include <memory>
#include <optional>
#include <type_traits>
#include <variant>
#include <vector>

namespace djinni {

// SharedFuture is a wrapper around djinni::Future to allow multiple consumers (i.e. like std::shared_future)
// The API is designed to be similar to djinni::Future.
template<typename T>
class SharedFuture {
public:
// Create SharedFuture from Future. Runtime error if the future is already consumed.
explicit SharedFuture(Future<T>&& future);

// Transform into Future<T>.
Future<T> toFuture() const {
if (await_ready()) {
co_return await_resume(); // return stored value directly
}
co_return co_await SharedFuture(*this); // retain copy during coroutine suspension
}

void wait() const {
waitIgnoringExceptions().wait();
}

decltype(auto) get() const {
wait();
return await_resume();
}

template <typename Func>
using ResultT = std::remove_cv_t<std::remove_reference_t<std::invoke_result_t<Func, const SharedFuture<T>&>>>;

// Transform the result of this future into a new future. The behavior is same as Future::then except that
// it doesn't consume the future, and can be called multiple times.
template<typename Func>
Future<ResultT<Func>> then(Func transform) const {
auto cpy = SharedFuture(*this); // retain copy during coroutine suspension
co_await cpy.waitIgnoringExceptions();
co_return transform(cpy);
}

// Same as above but returns SharedFuture.
template<typename Func>
SharedFuture<ResultT<Func>> thenShared(Func transform) const {
return SharedFuture<ResultT<Func>>(then(std::move(transform)));
}

// -- coroutine support implementation only; not intended externally --

bool await_ready() const {
std::scoped_lock lock(_sharedStates->mutex);
return _sharedStates->storedValue.has_value();
}

decltype(auto) await_resume() const {
if (!*_sharedStates->storedValue) {
std::rethrow_exception(_sharedStates->storedValue->error());
}
if constexpr (!std::is_void_v<T>) {
return const_cast<const T &>(_sharedStates->storedValue->value());
}
}

bool await_suspend(detail::CoroutineHandle<> h) const;

struct Promise : public Future<T>::promise_type {
SharedFuture<T> get_return_object() noexcept {
return SharedFuture(Future<T>::promise_type::get_return_object());
}
};
using promise_type = Promise;

private:
Future<void> waitIgnoringExceptions() const {
try {
co_await *this;
} catch (...) {
// Ignore exceptions.
}
}

struct SharedStates {
std::recursive_mutex mutex;
std::optional<djinni::expected<T, std::exception_ptr>> storedValue = std::nullopt;
std::vector<detail::CoroutineHandle<>> coroutineHandles;
};
// Use a shared_ptr to allow copying SharedFuture.
std::shared_ptr<SharedStates> _sharedStates = std::make_shared<SharedStates>();
};

// CTAD deduction guide to construct from Future directly.
template<typename T>
SharedFuture(Future<T>&&) -> SharedFuture<T>;

// ------------------ Implementation ------------------

template<typename T>
SharedFuture<T>::SharedFuture(Future<T>&& future) {
// `future` will invoke all continuations when it is ready.
future.then([sharedStates = _sharedStates](auto futureResult) {
std::vector toCall = [&] {
std::scoped_lock lock(sharedStates->mutex);
try {
if constexpr (std::is_void_v<T>) {
futureResult.get();
sharedStates->storedValue.emplace();
} else {
sharedStates->storedValue = futureResult.get();
}
} catch (...) {
sharedStates->storedValue = make_unexpected(std::current_exception());
}
return std::move(sharedStates->coroutineHandles);
}();
for (auto& handle : toCall) {
handle();
}
});
}

template<typename T>
bool SharedFuture<T>::await_suspend(detail::CoroutineHandle<> h) const {
{
std::unique_lock lock(_sharedStates->mutex);
if (!_sharedStates->storedValue) {
_sharedStates->coroutineHandles.push_back(std::move(h));
return true;
}
}
h();
return true;
}

} // namespace djinni

#endif
1 change: 1 addition & 0 deletions test-suite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ objc_library(
copts = [
"-ObjC++",
"-std=c++17",
"-fcoroutines-ts"
],
srcs = glob([
"generated-src/objc/**/*.mm",
Expand Down
91 changes: 91 additions & 0 deletions test-suite/handwritten-src/objc/tests/DBSharedFutureTest.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#import <Foundation/Foundation.h>
#import <XCTest/XCTest.h>

#include "../../../support-lib/cpp/SharedFuture.hpp"

@interface DBSharedFutureTest : XCTestCase
@end

@implementation DBSharedFutureTest

#ifdef DJINNI_FUTURE_HAS_COROUTINE_SUPPORT

- (void)setUp
{
[super setUp];
}

- (void)tearDown
{
[super tearDown];
}

- (void)testCreateFuture
{
djinni::SharedFuture<int> resolvedInt(djinni::Promise<int>::resolve(42));
XCTAssertEqual(resolvedInt.get(), 42);

djinni::Promise<NSString*> strPromise;
djinni::SharedFuture futureString(strPromise.getFuture());

strPromise.setValue(@"foo");
XCTAssertEqualObjects(futureString.get(), @"foo");
}

- (void)testThen
{
djinni::Promise<int> intPromise;
djinni::SharedFuture<int> futureInt(intPromise.getFuture());

auto transformedInt = futureInt.thenShared([](const auto& resolved) { return 2 * resolved.get(); });

intPromise.setValue(42);
XCTAssertEqual(transformedInt.get(), 84);

// Also verify multiple consumers and chaining.
auto transformedString = futureInt.thenShared([](const auto& resolved) { return std::to_string(resolved.get()); });
auto futurePlusOneTimesTwo = futureInt.then([](auto resolved) { return resolved.get() + 1; }).then([](auto resolved) {
return 2 * resolved.get();
});
auto futureStringLen = transformedString.then([](auto resolved) { return resolved.get().length(); });

XCTAssertEqual(transformedString.get(), std::string("42"));
XCTAssertEqual(futurePlusOneTimesTwo.get(), (42 + 1) * 2);
XCTAssertEqual(futureStringLen.get(), 2);

XCTAssertEqual(futureInt.get(), 42);

auto voidFuture = transformedString.thenShared([](auto) {});
voidFuture.wait();

auto intFuture2 = voidFuture.thenShared([](auto) { return 43; });
XCTAssertEqual(intFuture2.get(), 43);
}

- (void)testException
{
// Also verify exception handling.
djinni::Promise<int> intPromise;
djinni::SharedFuture<int> futureInt(intPromise.getFuture());

intPromise.setException(std::runtime_error("mocked"));

XCTAssertThrows(futureInt.get());

auto thenResult = futureInt.then([](auto resolved) { return resolved.get(); });
XCTAssertThrows(thenResult.get());

auto withExceptionHandling = futureInt.thenShared([](const auto& resolved) {
try {
return resolved.get();
} catch (...) {
return -1;
}
});
withExceptionHandling.wait();
XCTAssertEqual(withExceptionHandling.get(), -1);
}

#endif

@end

0 comments on commit b604359

Please sign in to comment.