Skip to content

Commit

Permalink
Support providing custom rate-limiters
Browse files Browse the repository at this point in the history
Summary:
As title, I also noticed there were no unit tests for RateLimiter, so I added
some for MaxConcurrentRateLimiter

Reviewed By: Gownta

Differential Revision: D50399284

fbshipit-source-id: 2477aed15e54582cbd798ae84a421bc05bb5a16e
  • Loading branch information
Aaryaman Sagar authored and facebook-github-bot committed Nov 14, 2023
1 parent fa07253 commit ced26fe
Show file tree
Hide file tree
Showing 10 changed files with 293 additions and 131 deletions.
84 changes: 84 additions & 0 deletions folly/experimental/channels/MaxConcurrentRateLimiter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <folly/experimental/channels/MaxConcurrentRateLimiter.h>

namespace folly {
namespace channels {

class MaxConcurrentRateLimiter::Token : public RateLimiter::Token {
public:
explicit Token(
std::shared_ptr<MaxConcurrentRateLimiter> maxConcurrentRateLimiter)
: maxConcurrentRateLimiter_{std::move(maxConcurrentRateLimiter)} {}

Token(Token&&) = default;
Token& operator=(Token&&) = default;
Token(const Token&) = delete;
Token& operator=(const Token&) = delete;

~Token() override {
if (maxConcurrentRateLimiter_) {
maxConcurrentRateLimiter_->release();
}
}

private:
std::shared_ptr<MaxConcurrentRateLimiter> maxConcurrentRateLimiter_;
};

std::shared_ptr<MaxConcurrentRateLimiter> MaxConcurrentRateLimiter::create(
size_t maxConcurrent) {
return std::shared_ptr<MaxConcurrentRateLimiter>(
new MaxConcurrentRateLimiter(maxConcurrent));
}

MaxConcurrentRateLimiter::MaxConcurrentRateLimiter(size_t maxConcurrent)
: maxConcurrent_(maxConcurrent) {}

void MaxConcurrentRateLimiter::executeWhenReady(
folly::Function<void(std::unique_ptr<RateLimiter::Token>)> func,
Executor::KeepAlive<SequencedExecutor> executor) {
auto state = state_.wlock();
if (state->running < maxConcurrent_) {
CHECK(state->queue.empty());
state->running++;
executor->add(
[func = std::move(func),
token = std::make_unique<MaxConcurrentRateLimiter::Token>(
std::static_pointer_cast<MaxConcurrentRateLimiter>(
shared_from_this()))]() mutable { func(std::move(token)); });
} else {
state->queue.enqueue(QueueItem{std::move(func), std::move(executor)});
}
}

void MaxConcurrentRateLimiter::release() {
auto state = state_.wlock();
if (!state->queue.empty()) {
auto queueItem = state->queue.dequeue();
queueItem.executor->add(
[func = std::move(queueItem.func),
token = std::make_unique<MaxConcurrentRateLimiter::Token>(
std::static_pointer_cast<MaxConcurrentRateLimiter>(
shared_from_this()))]() mutable { func(std::move(token)); });
} else {
state->running--;
}
}

} // namespace channels
} // namespace folly
56 changes: 56 additions & 0 deletions folly/experimental/channels/MaxConcurrentRateLimiter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <folly/Synchronized.h>
#include <folly/concurrency/UnboundedQueue.h>
#include <folly/executors/SequencedExecutor.h>
#include <folly/experimental/channels/RateLimiter.h>

namespace folly {
namespace channels {

class MaxConcurrentRateLimiter : public RateLimiter {
public:
static std::shared_ptr<MaxConcurrentRateLimiter> create(size_t maxConcurrent);

void executeWhenReady(
folly::Function<void(std::unique_ptr<Token>)> func,
Executor::KeepAlive<SequencedExecutor> executor) override;

private:
class Token;
friend class Token;

explicit MaxConcurrentRateLimiter(size_t maxConcurrent);
void release();

struct QueueItem {
folly::Function<void(std::unique_ptr<Token>)> func;
Executor::KeepAlive<SequencedExecutor> executor;
};

struct State {
USPSCQueue<QueueItem, false /* MayBlock */, 6 /* LgSegmentSize */> queue;
size_t running{0};
};

const size_t maxConcurrent_;
folly::Synchronized<State> state_;
};
} // namespace channels
} // namespace folly
2 changes: 1 addition & 1 deletion folly/experimental/channels/MultiplexChannel-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ class MultiplexChannelProcessor : public IChannelCallback {
if (rateLimiter != nullptr) {
rateLimiter->executeWhenReady(
[this, func = std::move(func), executor = multiplexer_.getExecutor()](
RateLimiter::Token token) mutable {
std::unique_ptr<RateLimiter::Token> token) mutable {
folly::coro::co_invoke(
[this,
token = std::move(token),
Expand Down
65 changes: 0 additions & 65 deletions folly/experimental/channels/RateLimiter.cpp

This file was deleted.

68 changes: 31 additions & 37 deletions folly/experimental/channels/RateLimiter.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,45 @@

#pragma once

#include <folly/Synchronized.h>
#include <folly/concurrency/UnboundedQueue.h>
#include <folly/Function.h>
#include <folly/executors/SequencedExecutor.h>

namespace folly {
namespace channels {

/**
* A rate-limiter used by the channels framework to limit the number of
* in-flight requests.
*
* A default implementation is provided in MaxConcurrentRateLimiter.h but users
* can provide custom rate-limiters.
*/
class RateLimiter : public std::enable_shared_from_this<RateLimiter> {
public:
static std::shared_ptr<RateLimiter> create(size_t maxConcurrent);

class Token {
public:
explicit Token(std::shared_ptr<RateLimiter> rateLimiter);
~Token();

Token(const Token&) = delete;
Token& operator=(const Token&) = delete;
Token(Token&&) = default;
Token& operator=(Token&&) = default;

private:
std::shared_ptr<RateLimiter> rateLimiter_;
};

using QueuedFunc = folly::Function<void(Token)>;

void executeWhenReady(
QueuedFunc func, Executor::KeepAlive<SequencedExecutor> executor);

private:
explicit RateLimiter(size_t maxConcurrent);

struct QueueItem {
QueuedFunc func;
Executor::KeepAlive<SequencedExecutor> executor;
};

struct State {
USPSCQueue<QueueItem, false /* MayBlock */, 6 /* LgSegmentSize */> queue;
size_t running{0};
};
class Token;
virtual ~RateLimiter() = default;

/**
* Executes the given function when there is capacity available in the
* rate-limiter.
*
* The function is considered finished when the token is destroyed.
*/
virtual void executeWhenReady(
folly::Function<void(std::unique_ptr<Token>)> function,
Executor::KeepAlive<SequencedExecutor> executor) = 0;
};

const size_t maxConcurrent_;
folly::Synchronized<State> state_;
/**
* A token on destruction signals termination of the user provided function. So
* it's expected that a derived class override the destructor to provide the
* desired functionality. Or piggyback on destruction of the compiler generated
* overridden destructor.
*/
class RateLimiter::Token {
public:
virtual ~Token() = default;
};

} // namespace channels
} // namespace folly
50 changes: 26 additions & 24 deletions folly/experimental/channels/Transform-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,16 @@ class TransformProcessorBase : public IChannelCallback {

template <typename ReceiverType>
void startTransform(ReceiverType receiver) {
executeWhenReady(
[=, receiver = std::move(receiver)](RateLimiter::Token token) mutable {
runOperationWithSenderCancellation(
transformer_.getExecutor(),
this->sender_,
false /* alreadyStartedWaiting */,
this /* channelCallbackToRestore */,
startTransformImpl(std::move(receiver)),
std::move(token));
});
executeWhenReady([=, receiver = std::move(receiver)](
std::unique_ptr<RateLimiter::Token> token) mutable {
runOperationWithSenderCancellation(
transformer_.getExecutor(),
this->sender_,
false /* alreadyStartedWaiting */,
this /* channelCallbackToRestore */,
startTransformImpl(std::move(receiver)),
std::move(token));
});
}

protected:
Expand All @@ -98,7 +98,7 @@ class TransformProcessorBase : public IChannelCallback {
* sender).
*/
void consume(ChannelBridgeBase* bridge) override {
executeWhenReady([=](RateLimiter::Token token) {
executeWhenReady([=](std::unique_ptr<RateLimiter::Token> token) {
if (bridge == receiver_.get()) {
// We have received new values from the input receiver.
CHECK_NE(getReceiverState(), ChannelState::CancellationProcessed);
Expand All @@ -125,7 +125,7 @@ class TransformProcessorBase : public IChannelCallback {
* listening to.
*/
void canceled(ChannelBridgeBase* bridge) override {
executeWhenReady([=](RateLimiter::Token token) {
executeWhenReady([=](std::unique_ptr<RateLimiter::Token> token) {
if (bridge == receiver_.get()) {
// We previously cancelled the input receiver (because the consumer of
// the output receiver stopped consuming). Process the cancellation for
Expand Down Expand Up @@ -274,14 +274,15 @@ class TransformProcessorBase : public IChannelCallback {
return detail::getSenderState(sender_.get());
}

void executeWhenReady(folly::Function<void(RateLimiter::Token)> func) {
void executeWhenReady(
folly::Function<void(std::unique_ptr<RateLimiter::Token>)> func) {
auto rateLimiter = transformer_.getRateLimiter();
if (rateLimiter != nullptr) {
rateLimiter->executeWhenReady(
std::move(func), transformer_.getExecutor());
} else {
transformer_.getExecutor()->add([func = std::move(func)]() mutable {
func(RateLimiter::Token(nullptr));
func(std::unique_ptr<RateLimiter::Token>(nullptr));
});
}
}
Expand Down Expand Up @@ -348,16 +349,17 @@ class ResumableTransformProcessor : public TransformProcessorBase<
using Base::Base;

void initialize(InitializeArg initializeArg) {
this->executeWhenReady([=, initializeArg = std::move(initializeArg)](
RateLimiter::Token token) mutable {
runOperationWithSenderCancellation(
this->transformer_.getExecutor(),
this->sender_,
false /* currentlyWaiting */,
this /* channelCallbackToRestore */,
initializeImpl(std::move(initializeArg)),
std::move(token));
});
this->executeWhenReady(
[=, initializeArg = std::move(initializeArg)](
std::unique_ptr<RateLimiter::Token> token) mutable {
runOperationWithSenderCancellation(
this->transformer_.getExecutor(),
this->sender_,
false /* currentlyWaiting */,
this /* channelCallbackToRestore */,
initializeImpl(std::move(initializeArg)),
std::move(token));
});
}

private:
Expand Down
2 changes: 1 addition & 1 deletion folly/experimental/channels/detail/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ void runOperationWithSenderCancellation(
bool alreadyStartedWaiting,
IChannelCallback* channelCallbackToRestore,
folly::coro::Task<void> operation,
RateLimiter::Token token) noexcept {
std::unique_ptr<RateLimiter::Token> token) noexcept {
if (alreadyStartedWaiting && (!sender || !sender->cancelSenderWait())) {
// The output receiver was cancelled before starting this operation
// (indicating that the channel callback already ran).
Expand Down
Loading

0 comments on commit ced26fe

Please sign in to comment.