Skip to content

Commit

Permalink
Make tensorstore::internal::RateLimiter a pure virtual interface.
Browse files Browse the repository at this point in the history
Convert Start/Admit to use RateLimiterNode* and static_cast.

Moves the linked-list initialization into RateLimiter implementation classes.
Minor additional cleanups

PiperOrigin-RevId: 684170253
Change-Id: I1ac93e8791f2c605a5192e6a6350f34e8ade1219
  • Loading branch information
laramiel authored and copybara-github committed Oct 9, 2024
1 parent 2def8c1 commit 1650362
Show file tree
Hide file tree
Showing 12 changed files with 186 additions and 87 deletions.
12 changes: 9 additions & 3 deletions tensorstore/internal/rate_limiter/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,16 @@ tensorstore_cc_library(
name = "rate_limiter",
srcs = ["rate_limiter.cc"],
hdrs = ["rate_limiter.h"],
deps = ["//tensorstore/internal/container:intrusive_linked_list"],
)

tensorstore_cc_test(
name = "rate_limiter_test",
srcs = ["rate_limiter_test.cc"],
deps = [
"//tensorstore/internal/container:intrusive_linked_list",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/synchronization",
":rate_limiter",
"//tensorstore/internal:intrusive_ptr",
"@com_google_googletest//:gtest_main",
],
)

Expand Down
31 changes: 23 additions & 8 deletions tensorstore/internal/rate_limiter/admission_queue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,15 @@ namespace tensorstore {
namespace internal {

AdmissionQueue::AdmissionQueue(size_t limit)
: limit_(limit == 0 ? std::numeric_limits<size_t>::max() : limit) {}
: limit_(limit == 0 ? std::numeric_limits<size_t>::max() : limit) {
internal::intrusive_linked_list::Initialize(RateLimiterNodeAccessor{},
&head_);
}

AdmissionQueue::~AdmissionQueue() {
absl::MutexLock l(&mutex_);
assert(head_.next_ == &head_);
}

void AdmissionQueue::Admit(RateLimiterNode* node, RateLimiterNode::StartFn fn) {
assert(node->next_ == nullptr);
Expand All @@ -37,11 +45,12 @@ void AdmissionQueue::Admit(RateLimiterNode* node, RateLimiterNode::StartFn fn) {

{
absl::MutexLock lock(&mutex_);
if (in_flight_++ >= limit_) {
if (in_flight_ + 1 > limit_) {
internal::intrusive_linked_list::InsertBefore(RateLimiterNodeAccessor{},
&head_, node);
return;
}
in_flight_++;
}

RunStartFunction(node);
Expand All @@ -50,18 +59,24 @@ void AdmissionQueue::Admit(RateLimiterNode* node, RateLimiterNode::StartFn fn) {
void AdmissionQueue::Finish(RateLimiterNode* node) {
assert(node->next_ == nullptr);

absl::MutexLock lock(&mutex_);
in_flight_--;

// Typically this loop will admit only a single node at a time.
RateLimiterNode* next_node = nullptr;
{
absl::MutexLock lock(&mutex_);
in_flight_--;
while (true) {
next_node = head_.next_;
if (next_node == &head_) return;
if (in_flight_ + 1 > limit_) return;
in_flight_++;
internal::intrusive_linked_list::Remove(RateLimiterNodeAccessor{},
next_node);
}

// Next node gets a chance to run after clearing admission queue state.
RunStartFunction(next_node);
// Next node gets a chance to run after clearing admission queue state.
mutex_.Unlock();
RunStartFunction(next_node);
mutex_.Lock();
}
}

} // namespace internal
Expand Down
5 changes: 4 additions & 1 deletion tensorstore/internal/rate_limiter/admission_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class AdmissionQueue : public RateLimiter {
public:
/// Construct an AdmissionQueue with `limit` parallelism.
AdmissionQueue(size_t limit);
~AdmissionQueue() override = default;
~AdmissionQueue() override;

size_t limit() const { return limit_; }
size_t in_flight() const {
Expand All @@ -58,6 +58,9 @@ class AdmissionQueue : public RateLimiter {

private:
const size_t limit_;

mutable absl::Mutex mutex_;
RateLimiterNode head_ ABSL_GUARDED_BY(mutex_);
size_t in_flight_ ABSL_GUARDED_BY(mutex_) = 0;
};

Expand Down
34 changes: 21 additions & 13 deletions tensorstore/internal/rate_limiter/admission_queue_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,36 @@

namespace {

using ::tensorstore::Executor;
using ::tensorstore::ExecutorTask;
using ::tensorstore::internal::AdmissionQueue;
using ::tensorstore::internal::adopt_object_ref;
using ::tensorstore::internal::AtomicReferenceCount;
using ::tensorstore::internal::IntrusivePtr;
using ::tensorstore::internal::MakeIntrusivePtr;
using ::tensorstore::internal::RateLimiter;
using ::tensorstore::internal::RateLimiterNode;

struct Node : public RateLimiterNode, public AtomicReferenceCount<Node> {
AdmissionQueue* queue_;
/// This class holds a reference count on itself while held by a RateLimiter,
/// and upon start will call the `task_` function.
struct Task : public RateLimiterNode, public AtomicReferenceCount<Task> {
RateLimiter* rate_limiter_;
ExecutorTask task_;

Node(AdmissionQueue* queue, ExecutorTask task)
: queue_(queue), task_(std::move(task)) {}
Task(RateLimiter* rate_limiter, ExecutorTask task)
: rate_limiter_(rate_limiter), task_(std::move(task)) {
// NOTE: Do not call Admit in the constructor as the task may complete
// and try to delete self before MakeIntrusivePtr completes.
}

~Task() { rate_limiter_->Finish(this); }

~Node() { queue_->Finish(this); }
void Admit() {
intrusive_ptr_increment(this); // adopted by RateLimiterTask::Start.
rate_limiter_->Admit(this, &Task::Start);
}

static void Start(void* task) {
IntrusivePtr<Node> self(reinterpret_cast<Node*>(task), adopt_object_ref);
static void Start(RateLimiterNode* task) {
IntrusivePtr<Task> self(static_cast<Task*>(task),
tensorstore::internal::adopt_object_ref);
std::move(self->task_)();
}
};
Expand All @@ -59,10 +69,8 @@ TEST(AdmissionQueueTest, Basic) {

{
for (int i = 0; i < 100; i++) {
auto node = MakeIntrusivePtr<Node>(&queue, [&done] { done++; });

intrusive_ptr_increment(node.get()); // adopted by Node::Start.
queue.Admit(node.get(), &Node::Start);
auto task = MakeIntrusivePtr<Task>(&queue, [&done] { done++; });
task->Admit();
}
}

Expand Down
16 changes: 2 additions & 14 deletions tensorstore/internal/rate_limiter/rate_limiter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,9 @@

#include <cassert>

#include "absl/synchronization/mutex.h"
#include "tensorstore/internal/container/intrusive_linked_list.h"

namespace tensorstore {
namespace internal {

RateLimiter::RateLimiter() {
absl::MutexLock l(&mutex_);
internal::intrusive_linked_list::Initialize(RateLimiterNodeAccessor{},
&head_);
}

RateLimiter::~RateLimiter() {
absl::MutexLock l(&mutex_);
assert(head_.next_ == &head_);
}

void RateLimiter::RunStartFunction(RateLimiterNode* node) {
// Next node gets a chance to run after clearing admission queue state.
RateLimiterNode::StartFn fn = node->start_fn_;
Expand All @@ -53,6 +39,8 @@ void NoRateLimiter::Admit(RateLimiterNode* node, RateLimiterNode::StartFn fn) {

void NoRateLimiter::Finish(RateLimiterNode* node) {
assert(node->next_ == nullptr);
assert(node->prev_ == nullptr);
assert(node->start_fn_ == nullptr);
}

} // namespace internal
Expand Down
25 changes: 9 additions & 16 deletions tensorstore/internal/rate_limiter/rate_limiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,20 @@
#ifndef TENSORSTORE_INTERNAL_RATE_LIMITER_RATE_LIMITER_H_
#define TENSORSTORE_INTERNAL_RATE_LIMITER_RATE_LIMITER_H_

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorstore/internal/container/intrusive_linked_list.h"

namespace tensorstore {
namespace internal {

// RateLimiter is an interface which supports rate-limiting for an operation.
// Pending operations use the `RateLimiterNode` base class, and are managed
// via `RateLimiter::Admit` and `RateLimiter::Finish` calls.
//
// Generally, a RateLimiterNode will also be reference counted, however neither
// the RateLimiterNode nor the RateLimiter class manage any reference counts.
// Callers should manage reference counts externally.
//
/// RateLimiter is an interface which supports rate-limiting for an operation.
/// Pending operations use the `RateLimiterNode` base class, and are managed
/// via `RateLimiter::Admit` and `RateLimiter::Finish` calls.
///
/// Generally, a RateLimiterNode will also be reference counted, however neither
/// the RateLimiterNode nor the RateLimiter class manage any reference counts.
/// Callers should manage reference counts externally.
struct RateLimiterNode {
using StartFn = void (*)(void*);
using StartFn = void (*)(RateLimiterNode*);

RateLimiterNode* next_ = nullptr;
RateLimiterNode* prev_ = nullptr;
Expand All @@ -44,8 +41,7 @@ using RateLimiterNodeAccessor = internal::intrusive_linked_list::MemberAccessor<
/// RateLimiter interface.
class RateLimiter {
public:
RateLimiter();
virtual ~RateLimiter();
virtual ~RateLimiter() = default;

/// Add a task to the rate limiter. Will arrange for `fn(node)` to be called
/// at some (possible future) point.
Expand All @@ -56,9 +52,6 @@ class RateLimiter {

protected:
static void RunStartFunction(RateLimiterNode* node);

mutable absl::Mutex mutex_;
RateLimiterNode head_ ABSL_GUARDED_BY(mutex_);
};

/// RateLimiter interface.
Expand Down
76 changes: 76 additions & 0 deletions tensorstore/internal/rate_limiter/rate_limiter_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2024 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "tensorstore/internal/rate_limiter/rate_limiter.h"

#include <atomic>
#include <cstddef>
#include <type_traits>

#include <gtest/gtest.h>
#include "tensorstore/internal/intrusive_ptr.h"

namespace {

using ::tensorstore::internal::AtomicReferenceCount;
using ::tensorstore::internal::IntrusivePtr;
using ::tensorstore::internal::MakeIntrusivePtr;
using ::tensorstore::internal::NoRateLimiter;
using ::tensorstore::internal::RateLimiter;
using ::tensorstore::internal::RateLimiterNode;

/// This class holds a reference count on itself while held by a RateLimiter,
/// and upon start will call the `task_` function.
template <typename Fn>
struct RateLimiterTask : public AtomicReferenceCount<RateLimiterTask<Fn>>,
public RateLimiterNode {
RateLimiter* rate_limiter_;
Fn task_;

RateLimiterTask(RateLimiter* rate_limiter, Fn task)
: rate_limiter_(rate_limiter), task_(std::move(task)) {
// NOTE: Do not call Admit in the constructor as the task may complete
// and try to delete self before MakeIntrusivePtr completes.
}

~RateLimiterTask() { rate_limiter_->Finish(this); }

void Admit() {
intrusive_ptr_increment(this); // adopted by RateLimiterTask::Start.
rate_limiter_->Admit(this, &RateLimiterTask::Start);
}

static void Start(RateLimiterNode* task) {
IntrusivePtr<RateLimiterTask> self(static_cast<RateLimiterTask<Fn>*>(task),
tensorstore::internal::adopt_object_ref);
std::move(self->task_)();
}
};

TEST(AdmissionQueueTest, Basic) {
NoRateLimiter queue;
std::atomic<size_t> done{0};

auto increment = [&done] { done++; };
using Node = RateLimiterTask<std::remove_reference_t<decltype(increment)>>;

{
for (int i = 0; i < 100; i++) {
MakeIntrusivePtr<Node>(&queue, increment)->Admit();
}
}

EXPECT_EQ(100, done);
}
} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ struct Node : public RateLimiterNode, public AtomicReferenceCount<Node> {

~Node() { queue_->Finish(this); }

static void Start(void* task) {
IntrusivePtr<Node> self(reinterpret_cast<Node*>(task), adopt_object_ref);
static void Start(RateLimiterNode* task) {
IntrusivePtr<Node> self(static_cast<Node*>(task), adopt_object_ref);
std::move(self->task_)();
}
};
Expand Down
11 changes: 9 additions & 2 deletions tensorstore/internal/rate_limiter/token_bucket_rate_limiter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,29 @@ TokenBucketRateLimiter::TokenBucketRateLimiter(double max_tokens)
max_tokens_(max_tokens),
start_time_(clock_()),
last_update_(start_time_),
allow_schedule_at_(true) {}
allow_schedule_at_(true) {
internal::intrusive_linked_list::Initialize(RateLimiterNodeAccessor{},
&head_);
}

TokenBucketRateLimiter::TokenBucketRateLimiter(
double max_tokens, std::function<absl::Time()> clock)
: clock_(std::move(clock)),
max_tokens_(max_tokens),
start_time_(clock_()),
last_update_(start_time_),
allow_schedule_at_(false) {}
allow_schedule_at_(false) {
internal::intrusive_linked_list::Initialize(RateLimiterNodeAccessor{},
&head_);
}

TokenBucketRateLimiter::~TokenBucketRateLimiter() {
absl::MutexLock l(&mutex_);
mutex_.Await(absl::Condition(
+[](TokenBucketRateLimiter* self) ABSL_EXCLUSIVE_LOCKS_REQUIRED(
self->mutex_) { return !self->scheduled_; },
this));
assert(head_.next_ == &head_);
}

void TokenBucketRateLimiter::Admit(RateLimiterNode* node,
Expand Down
3 changes: 3 additions & 0 deletions tensorstore/internal/rate_limiter/token_bucket_rate_limiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class TokenBucketRateLimiter : public RateLimiter {
void PerformWork() ABSL_LOCKS_EXCLUDED(mutex_);
void PerformWorkLocked() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_);

mutable absl::Mutex mutex_;
RateLimiterNode head_ ABSL_GUARDED_BY(mutex_);

// Intermediate state values.
std::function<absl::Time()> clock_;
const double max_tokens_;
Expand Down
Loading

0 comments on commit 1650362

Please sign in to comment.