Skip to content

Commit

Permalink
add SharedFuture
Browse files Browse the repository at this point in the history
  • Loading branch information
techleeksnap committed Jul 19, 2024
1 parent a97afc7 commit ae855ee
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
build --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder
build --cxxopt=-std=c++17 --cxxopt=-fcoroutines-ts --host_cxxopt=-std=c++17 --host_cxxopt=-fcoroutines-ts --incompatible_java_common_parameters=false --define=android_dexmerger_tool=d8_dexmerger --define=android_incremental_dexing_tool=d8_dexbuilder --nouse_workers_with_dexbuilder
3 changes: 2 additions & 1 deletion support-lib/cpp/Future.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ class Future {
constexpr bool await_ready() const noexcept {
return false;
}
bool await_suspend(detail::CoroutineHandle<ConcretePromise> finished) const noexcept {
template <typename P>
bool await_suspend(detail::CoroutineHandle<P> finished) const noexcept {
auto& promise_type = finished.promise();
if (*promise_type._result) {
if constexpr (std::is_void_v<T>) {
Expand Down
138 changes: 138 additions & 0 deletions support-lib/cpp/SharedFuture.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/**
* Copyright 2021 Snap, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "Future.hpp"

#include <memory>
#include <optional>
#include <type_traits>
#include <variant>
#include <vector>

namespace djinni {

// SharedFuture is a wrapper around djinni::Future to allow multiple consumers (i.e. like std::shared_future)
// The API is designed to be similar to djinni::Future.
template<typename T>
class SharedFuture {
public:
// Create SharedFuture from Future. Runtime error if the future is already consumed.
explicit SharedFuture(Future<T>&& future);

// Transform into Future<T>.
Future<T> toFuture() const {
if (await_ready()) {
co_return await_resume(); // return stored value directly
}
co_return co_await SharedFuture(*this); // retain copy during coroutine suspension
}

void wait() const {
return [this]() -> Future<void> { co_await *this; }().wait();
}

decltype(auto) get() const {
wait();
return await_resume();
}

// Transform the result of this future into a new future. The behavior is same as Future::then except that
// it doesn't consume the future, and can be called multiple times.
template<typename Func>
SharedFuture<std::remove_cv_t<std::remove_reference_t<std::invoke_result_t<Func, T>>>> then(Func transform) const {
co_return transform(co_await SharedFuture(*this)); // retain copy during coroutine suspension
}

// Overload for T = void or `transform` takes no arugment.
template<typename Func, typename = std::enable_if_t<!std::is_invocable_v<Func, T>>>
SharedFuture<std::remove_cv_t<std::remove_reference_t<std::invoke_result_t<Func>>>> then(Func transform) const {
co_await SharedFuture(*this); // retain copy during coroutine suspension
co_return transform();
}

// -- coroutine support implementation only; not intended externally --

bool await_ready() const {
std::scoped_lock lock(_sharedStates->mutex);
return _sharedStates->storedValue.has_value();
}

decltype(auto) await_resume() const {
if constexpr (!std::is_void_v<T>) {
return *_sharedStates->storedValue;
}
}

bool await_suspend(detail::CoroutineHandle<> h) const;

struct Promise : public Future<T>::promise_type {
SharedFuture<T> get_return_object() noexcept {
return SharedFuture(Future<T>::promise_type::get_return_object());
}
};
using promise_type = Promise;

private:
struct SharedStates {
std::recursive_mutex mutex;
std::optional<std::conditional_t<std::is_void_v<T>, std::monostate, T>> storedValue = std::nullopt;
std::vector<detail::CoroutineHandle<>> coroutineHandles;
};
// Use a shared_ptr to allow copying SharedFuture.
std::shared_ptr<SharedStates> _sharedStates = std::make_shared<SharedStates>();
};

// CTAD deduction guide to construct from Future directly.
template<typename T>
SharedFuture(Future<T>&&) -> SharedFuture<T>;

// ------------------ Implementation ------------------

template<typename T>
SharedFuture<T>::SharedFuture(Future<T>&& future) {
// `future` will invoke all continuations when it is ready.
future.then([sharedStates = _sharedStates](auto futureResult) {
std::vector toCall = [&] {
std::scoped_lock lock(sharedStates->mutex);
if constexpr (std::is_void_v<T>) {
sharedStates->storedValue.emplace();
} else {
sharedStates->storedValue = futureResult.get();
}
return std::move(sharedStates->coroutineHandles);
}();
for (auto& handle : toCall) {
handle();
}
});
}

template<typename T>
bool SharedFuture<T>::await_suspend(detail::CoroutineHandle<> h) const {
{
std::unique_lock lock(_sharedStates->mutex);
if (!_sharedStates->storedValue) {
_sharedStates->coroutineHandles.push_back(std::move(h));
return true;
}
}
h();
return true;
}

} // namespace djinni
1 change: 1 addition & 0 deletions test-suite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ objc_library(
copts = [
"-ObjC++",
"-std=c++17",
"-fcoroutines-ts"
],
srcs = glob([
"generated-src/objc/**/*.mm",
Expand Down
61 changes: 61 additions & 0 deletions test-suite/handwritten-src/objc/tests/DBSharedFutureTest.mm
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#import <Foundation/Foundation.h>
#import <XCTest/XCTest.h>

#include "../../../support-lib/cpp/SharedFuture.hpp"

@interface DBSharedFutureTest : XCTestCase
@end

@implementation DBSharedFutureTest

- (void)setUp
{
[super setUp];
}

- (void)tearDown
{
[super tearDown];
}

- (void)testCreateFuture
{
djinni::SharedFuture<int> resolvedInt(djinni::Promise<int>::resolve(42));
XCTAssertEqual(resolvedInt.get(), 42);

djinni::Promise<NSString*> strPromise;
djinni::SharedFuture futureString(strPromise.getFuture());

strPromise.setValue(@"foo");
XCTAssertEqualObjects(futureString.get(), @"foo");
}

- (void)testThen
{
djinni::Promise<int> intPromise;
djinni::SharedFuture<int> futureInt(intPromise.getFuture());

auto transformedInt = futureInt.then([](int i) { return 2 * i; });

intPromise.setValue(42);
XCTAssertEqual(transformedInt.get(), 84);

// Also verify multiple consumers and chaining.
auto transformedString = futureInt.then([](int i) { return std::to_string(i); });
auto futurePlusOneTimesTwo = futureInt.then([](int i) { return i + 1; }).then([](int i) { return 2 * i; });
auto futureStringLen = transformedString.then([](const std::string& s) { return s.length(); }).toFuture();

XCTAssertEqual(transformedString.get(), std::string("42"));
XCTAssertEqual(futurePlusOneTimesTwo.get(), (42 + 1) * 2);
XCTAssertEqual(futureStringLen.get(), 2);

XCTAssertEqual(futureInt.get(), 42);

auto voidFuture = transformedString.then([]() {});
voidFuture.wait();

auto intFuture2 = voidFuture.then([]() { return 43; });
XCTAssertEqual(intFuture2.get(), 43);
}

@end

0 comments on commit ae855ee

Please sign in to comment.