Skip to content

Commit

Permalink
Add MaybeValid() to base::Callback
Browse files Browse the repository at this point in the history
This is a thread-safe validity check for WeakPtr-based Callbacks using the recently added WeakPtr::MaybeValid().

Bug: 730693
Change-Id: I174efb30bb16d2776e33ec64d48a913943e770a0
Reviewed-on: https://chromium-review.googlesource.com/1144208
Commit-Queue: Nicolas Ouellet-Payeur <nicolaso@chromium.org>
Reviewed-by: Jeremy Roman <jbroman@chromium.org>
Reviewed-by: Taiju Tsuiki <tzik@chromium.org>
Reviewed-by: Gabriel Charette <gab@chromium.org>
Cr-Commit-Position: refs/heads/master@{#579061}
  • Loading branch information
Nicolas Ouellet-payeur authored and Commit Bot committed Jul 30, 2018
1 parent 4531d6f commit 40f8e9a
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 11 deletions.
51 changes: 45 additions & 6 deletions base/bind_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -736,21 +736,42 @@ std::enable_if_t<!FunctorTraits<Functor>::is_nullable, bool> IsNull(

// Used by ApplyCancellationTraits below.
template <typename Functor, typename BoundArgsTuple, size_t... indices>
bool ApplyCancellationTraitsImpl(const Functor& functor,
const BoundArgsTuple& bound_args,
std::index_sequence<indices...>) {
bool ApplyCancellationTraitsIsCancelledImpl(const Functor& functor,
const BoundArgsTuple& bound_args,
std::index_sequence<indices...>) {
return CallbackCancellationTraits<Functor, BoundArgsTuple>::IsCancelled(
functor, std::get<indices>(bound_args)...);
}

// Relays |base| to corresponding CallbackCancellationTraits<>::Run(). Returns
// true if the callback |base| represents is canceled.
template <typename BindStateType>
bool ApplyCancellationTraits(const BindStateBase* base) {
bool ApplyCancellationTraitsIsCancelled(const BindStateBase* base) {
const BindStateType* storage = static_cast<const BindStateType*>(base);
static constexpr size_t num_bound_args =
std::tuple_size<decltype(storage->bound_args_)>::value;
return ApplyCancellationTraitsImpl(
return ApplyCancellationTraitsIsCancelledImpl(
storage->functor_, storage->bound_args_,
std::make_index_sequence<num_bound_args>());
};

// Used by ApplyCancellationTraits below.
template <typename Functor, typename BoundArgsTuple, size_t... indices>
bool ApplyCancellationTraitsMaybeValidImpl(const Functor& functor,
const BoundArgsTuple& bound_args,
std::index_sequence<indices...>) {
return CallbackCancellationTraits<Functor, BoundArgsTuple>::MaybeValid(
functor, std::get<indices>(bound_args)...);
}

// Relays |base| to corresponding CallbackCancellationTraits<>::Run(). Returns
// false if the callback |base| represents is guaranteed to be cancelled.
template <typename BindStateType>
bool ApplyCancellationTraitsMaybeValid(const BindStateBase* base) {
const BindStateType* storage = static_cast<const BindStateType*>(base);
static constexpr size_t num_bound_args =
std::tuple_size<decltype(storage->bound_args_)>::value;
return ApplyCancellationTraitsMaybeValidImpl(
storage->functor_, storage->bound_args_,
std::make_index_sequence<num_bound_args>());
};
Expand Down Expand Up @@ -788,7 +809,8 @@ struct BindState final : BindStateBase {
ForwardBoundArgs&&... bound_args)
: BindStateBase(invoke_func,
&Destroy,
&ApplyCancellationTraits<BindState>),
&ApplyCancellationTraitsIsCancelled<BindState>,
&ApplyCancellationTraitsMaybeValid<BindState>),
functor_(std::forward<ForwardFunctor>(functor)),
bound_args_(std::forward<ForwardBoundArgs>(bound_args)...) {
DCHECK(!IsNull(functor_));
Expand Down Expand Up @@ -951,6 +973,13 @@ struct CallbackCancellationTraits<
const Args&...) {
return !receiver;
}

template <typename Receiver, typename... Args>
static bool MaybeValid(const Functor&,
const Receiver& receiver,
const Args&...) {
return receiver.MaybeValid();
}
};

// Specialization for a nested bind.
Expand All @@ -963,6 +992,11 @@ struct CallbackCancellationTraits<OnceCallback<Signature>,
static bool IsCancelled(const Functor& functor, const BoundArgs&...) {
return functor.IsCancelled();
}

template <typename Functor>
static bool MaybeValid(const Functor& functor, const BoundArgs&...) {
return functor.MaybeValid();
}
};

template <typename Signature, typename... BoundArgs>
Expand All @@ -974,6 +1008,11 @@ struct CallbackCancellationTraits<RepeatingCallback<Signature>,
static bool IsCancelled(const Functor& functor, const BoundArgs&...) {
return functor.IsCancelled();
}

template <typename Functor>
static bool MaybeValid(const Functor& functor, const BoundArgs&...) {
return functor.MaybeValid();
}
};

// Returns a RunType of bound functor.
Expand Down
17 changes: 14 additions & 3 deletions base/callback_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ bool ReturnFalse(const BindStateBase*) {
return false;
}

bool ReturnTrue(const BindStateBase*) {
return true;
}

} // namespace

void BindStateBaseRefCountTraits::Destruct(const BindStateBase* bind_state) {
Expand All @@ -23,15 +27,17 @@ void BindStateBaseRefCountTraits::Destruct(const BindStateBase* bind_state) {

BindStateBase::BindStateBase(InvokeFuncStorage polymorphic_invoke,
void (*destructor)(const BindStateBase*))
: BindStateBase(polymorphic_invoke, destructor, &ReturnFalse) {
: BindStateBase(polymorphic_invoke, destructor, &ReturnFalse, &ReturnTrue) {
}

BindStateBase::BindStateBase(InvokeFuncStorage polymorphic_invoke,
void (*destructor)(const BindStateBase*),
bool (*is_cancelled)(const BindStateBase*))
bool (*is_cancelled)(const BindStateBase*),
bool (*maybe_valid)(const BindStateBase*))
: polymorphic_invoke_(polymorphic_invoke),
destructor_(destructor),
is_cancelled_(is_cancelled) {}
is_cancelled_(is_cancelled),
maybe_valid_(maybe_valid) {}

CallbackBase& CallbackBase::operator=(CallbackBase&& c) noexcept = default;
CallbackBase::CallbackBase(const CallbackBaseCopyable& c)
Expand Down Expand Up @@ -61,6 +67,11 @@ bool CallbackBase::IsCancelled() const {
return bind_state_->IsCancelled();
}

bool CallbackBase::MaybeValid() const {
DCHECK(bind_state_);
return bind_state_->MaybeValid();
}

bool CallbackBase::EqualsInternal(const CallbackBase& other) const {
return bind_state_ == other.bind_state_;
}
Expand Down
15 changes: 14 additions & 1 deletion base/callback_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class BASE_EXPORT BindStateBase
void (*destructor)(const BindStateBase*));
BindStateBase(InvokeFuncStorage polymorphic_invoke,
void (*destructor)(const BindStateBase*),
bool (*is_cancelled)(const BindStateBase*));
bool (*is_cancelled)(const BindStateBase*),
bool (*maybe_valid)(const BindStateBase*));

~BindStateBase() = default;

Expand All @@ -76,6 +77,8 @@ class BASE_EXPORT BindStateBase
return is_cancelled_(this);
}

bool MaybeValid() const { return maybe_valid_(this); }

// In C++, it is safe to cast function pointers to function pointers of
// another type. It is not okay to use void*. We create a InvokeFuncStorage
// that that can store our function pointer, and then cast it back to
Expand All @@ -84,7 +87,9 @@ class BASE_EXPORT BindStateBase

// Pointer to a function that will properly destroy |this|.
void (*destructor_)(const BindStateBase*);

bool (*is_cancelled_)(const BindStateBase*);
bool (*maybe_valid_)(const BindStateBase*);

DISALLOW_COPY_AND_ASSIGN(BindStateBase);
};
Expand All @@ -110,8 +115,16 @@ class BASE_EXPORT CallbackBase {

// Returns true if the callback invocation will be nop due to an cancellation.
// It's invalid to call this on uninitialized callback.
//
// Must be called on the Callback's destination sequence.
bool IsCancelled() const;

// If this returns false, the callback invocation will be a nop due to a
// cancellation. This may(!) still return true, even on a cancelled callback.
//
// This function is thread-safe.
bool MaybeValid() const;

// Returns the Callback into an uninitialized state.
void Reset();

Expand Down
62 changes: 61 additions & 1 deletion base/callback_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "base/callback_helpers.h"
#include "base/callback_internal.h"
#include "base/memory/ref_counted.h"
#include "base/test/test_timeouts.h"
#include "base/threading/thread.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace base {
Expand All @@ -22,7 +24,8 @@ void NopInvokeFunc() {}
// chance of colliding with another instantiation and breaking the
// one-definition-rule.
struct FakeBindState : internal::BindStateBase {
FakeBindState() : BindStateBase(&NopInvokeFunc, &Destroy, &IsCancelled) {}
FakeBindState()
: BindStateBase(&NopInvokeFunc, &Destroy, &IsCancelled, &MaybeValid) {}

private:
~FakeBindState() = default;
Expand All @@ -32,6 +35,7 @@ struct FakeBindState : internal::BindStateBase {
static bool IsCancelled(const internal::BindStateBase*) {
return false;
}
static bool MaybeValid(const internal::BindStateBase*) { return true; }
};

namespace {
Expand Down Expand Up @@ -152,6 +156,62 @@ TEST_F(CallbackTest, NullAfterMoveRun) {
ASSERT_FALSE(cb3);
}

TEST_F(CallbackTest, MaybeValidReturnsTrue) {
Callback<void()> cb(BindRepeating([]() {}));
// By default, MaybeValid() just returns true all the time.
EXPECT_TRUE(cb.MaybeValid());
cb.Run();
EXPECT_TRUE(cb.MaybeValid());
}

// WeakPtr detection in BindRepeating() requires a method, not just any
// function.
class ClassWithAMethod {
public:
void TheMethod() {}
};

TEST_F(CallbackTest, MaybeValidInvalidateWeakPtrsOnSameSequence) {
ClassWithAMethod obj;
WeakPtrFactory<ClassWithAMethod> factory(&obj);
WeakPtr<ClassWithAMethod> ptr = factory.GetWeakPtr();

Callback<void()> cb(BindRepeating(&ClassWithAMethod::TheMethod, ptr));
EXPECT_TRUE(cb.MaybeValid());

factory.InvalidateWeakPtrs();
// MaybeValid() should be false because InvalidateWeakPtrs() was called on
// the same thread.
EXPECT_FALSE(cb.MaybeValid());
}

TEST_F(CallbackTest, MaybeValidInvalidateWeakPtrsOnOtherSequence) {
ClassWithAMethod obj;
WeakPtrFactory<ClassWithAMethod> factory(&obj);
WeakPtr<ClassWithAMethod> ptr = factory.GetWeakPtr();

Callback<void()> cb(BindRepeating(&ClassWithAMethod::TheMethod, ptr));
EXPECT_TRUE(cb.MaybeValid());

Thread other_thread("other_thread");
other_thread.StartAndWaitForTesting();
other_thread.task_runner()->PostTask(
FROM_HERE,
BindOnce(
[](Callback<void()> cb) {
// Check that MaybeValid() _eventually_ returns false.
const TimeDelta timeout = TestTimeouts::tiny_timeout();
const TimeTicks begin = TimeTicks::Now();
while (cb.MaybeValid() && (TimeTicks::Now() - begin) < timeout)
PlatformThread::YieldCurrentThread();
EXPECT_FALSE(cb.MaybeValid());
},
cb));
factory.InvalidateWeakPtrs();
// |other_thread|'s destructor will join, ensuring we wait for the task to be
// run.
}

class CallbackOwner : public base::RefCounted<CallbackOwner> {
public:
explicit CallbackOwner(bool* deleted) {
Expand Down
3 changes: 3 additions & 0 deletions third_party/blink/renderer/platform/heap/persistent.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class PersistentBase {
return *raw_;
}
explicit operator bool() const { return raw_; }
// TODO(https://crbug.com/653394): Consider returning a thread-safe best
// guess of validity.
bool MaybeValid() const { return true; }
operator T*() const {
CheckPointer();
return raw_;
Expand Down
8 changes: 8 additions & 0 deletions third_party/blink/renderer/platform/web_task_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ struct CallbackCancellationTraits<
const blink::TaskHandle& handle) {
return !handle.IsActive();
}

static bool MaybeValid(RunnerMethodType,
const base::WeakPtr<blink::TaskHandle::Runner>&,
const blink::TaskHandle& handle) {
// TODO(https://crbug.com/653394): Consider returning a thread-safe best
// guess of validity.
return true;
}
};

} // namespace base
Expand Down
9 changes: 9 additions & 0 deletions third_party/blink/renderer/platform/wtf/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ class ThreadCheckingCallbackWrapper<CallbackType, R(Args...)> {

bool IsCancelled() const { return callback_.IsCancelled(); }

bool MaybeValid() const { return callback_.MaybeValid(); }

private:
static R RunInternal(base::RepeatingCallback<R(Args...)>* callback,
Args&&... args) {
Expand Down Expand Up @@ -285,6 +287,13 @@ struct CallbackCancellationTraits<
const RunArgs&...) {
return receiver->IsCancelled();
}

template <typename Functor, typename Receiver, typename... RunArgs>
static bool MaybeValid(const Functor&,
const Receiver& receiver,
const RunArgs&...) {
return receiver->MaybeValid();
}
};

} // namespace base
Expand Down

0 comments on commit 40f8e9a

Please sign in to comment.