diff --git a/src/ray/gcs/actor_state_accessor_test.cc b/src/ray/gcs/actor_state_accessor_test.cc index 7f3bb8cf1f1a..c6b3d45c0d5a 100644 --- a/src/ray/gcs/actor_state_accessor_test.cc +++ b/src/ray/gcs/actor_state_accessor_test.cc @@ -44,7 +44,7 @@ class ActorStateAccessorTest : public ::testing::Test { void GenTestData() { GenActorData(); } void GenActorData() { - for (size_t i = 0; i < 2; ++i) { + for (size_t i = 0; i < 100; ++i) { std::shared_ptr actor = std::make_shared(); actor->set_max_reconstructions(1); actor->set_remaining_reconstructions(1); diff --git a/src/ray/gcs/asio.cc b/src/ray/gcs/asio.cc index a3d564683842..505a3c51c6f8 100644 --- a/src/ray/gcs/asio.cc +++ b/src/ray/gcs/asio.cc @@ -3,13 +3,15 @@ #include "ray/util/logging.h" RedisAsioClient::RedisAsioClient(boost::asio::io_service &io_service, - redisAsyncContext *async_context) - : async_context_(async_context), + ray::gcs::RedisAsyncContext &redis_async_context) + : redis_async_context_(redis_async_context), socket_(io_service), read_requested_(false), write_requested_(false), read_in_progress_(false), write_in_progress_(false) { + redisAsyncContext *async_context = redis_async_context_.GetRawRedisAsyncContext(); + // gives access to c->fd redisContext *c = &(async_context->c); @@ -47,7 +49,7 @@ void RedisAsioClient::operate() { void RedisAsioClient::handle_read(boost::system::error_code error_code) { RAY_CHECK(!error_code || error_code == boost::asio::error::would_block); read_in_progress_ = false; - redisAsyncHandleRead(async_context_); + redis_async_context_.RedisAsyncHandleRead(); if (error_code == boost::asio::error::would_block) { operate(); @@ -57,7 +59,7 @@ void RedisAsioClient::handle_read(boost::system::error_code error_code) { void RedisAsioClient::handle_write(boost::system::error_code error_code) { RAY_CHECK(!error_code || error_code == boost::asio::error::would_block); write_in_progress_ = false; - redisAsyncHandleWrite(async_context_); + redis_async_context_.RedisAsyncHandleWrite(); if (error_code == boost::asio::error::would_block) { operate(); diff --git a/src/ray/gcs/asio.h b/src/ray/gcs/asio.h index bbb66a9814aa..4056281938c4 100644 --- a/src/ray/gcs/asio.h +++ b/src/ray/gcs/asio.h @@ -29,12 +29,14 @@ #include #include +#include "ray/gcs/redis_async_context.h" #include "ray/thirdparty/hiredis/async.h" #include "ray/thirdparty/hiredis/hiredis.h" class RedisAsioClient { public: - RedisAsioClient(boost::asio::io_service &io_service, redisAsyncContext *ac); + RedisAsioClient(boost::asio::io_service &io_service, + ray::gcs::RedisAsyncContext &redis_async_context); void operate(); @@ -47,7 +49,8 @@ class RedisAsioClient { void cleanup(); private: - redisAsyncContext *async_context_; + ray::gcs::RedisAsyncContext &redis_async_context_; + boost::asio::ip::tcp::socket socket_; // Hiredis wanted to add a read operation to the event loop // but the read might not have happened yet diff --git a/src/ray/gcs/asio_test.cc b/src/ray/gcs/asio_test.cc index 3901766e2b28..7d7828d4d3fb 100644 --- a/src/ray/gcs/asio_test.cc +++ b/src/ray/gcs/asio_test.cc @@ -23,15 +23,15 @@ void GetCallback(redisAsyncContext *c, void *r, void *privdata) { redisReply *reply = reinterpret_cast(r); ASSERT_TRUE(reply != nullptr); ASSERT_TRUE(std::string(reinterpret_cast(reply->str)) == "test"); - redisAsyncDisconnect(c); io_service.stop(); } TEST(RedisAsioTest, TestRedisCommands) { redisAsyncContext *ac = redisAsyncConnect("127.0.0.1", 6379); ASSERT_TRUE(ac->err == 0); + ray::gcs::RedisAsyncContext redis_async_context(ac); - RedisAsioClient client(io_service, ac); + RedisAsioClient client(io_service, redis_async_context); redisAsyncSetConnectCallback(ac, ConnectCallback); redisAsyncSetDisconnectCallback(ac, DisconnectCallback); diff --git a/src/ray/gcs/redis_async_context.cc b/src/ray/gcs/redis_async_context.cc new file mode 100644 index 000000000000..51bb159da035 --- /dev/null +++ b/src/ray/gcs/redis_async_context.cc @@ -0,0 +1,87 @@ +#include "ray/gcs/redis_async_context.h" + +extern "C" { +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" +} + +namespace ray { + +namespace gcs { + +RedisAsyncContext::RedisAsyncContext(redisAsyncContext *redis_async_context) + : redis_async_context_(redis_async_context) { + RAY_CHECK(redis_async_context_ != nullptr); +} + +RedisAsyncContext::~RedisAsyncContext() { + if (redis_async_context_ != nullptr) { + redisAsyncFree(redis_async_context_); + redis_async_context_ = nullptr; + } +} + +redisAsyncContext *RedisAsyncContext::GetRawRedisAsyncContext() { + return redis_async_context_; +} + +void RedisAsyncContext::ResetRawRedisAsyncContext() { + // Reset redis_async_context_ to nullptr because hiredis has released this context. + redis_async_context_ = nullptr; +} + +void RedisAsyncContext::RedisAsyncHandleRead() { + // `redisAsyncHandleRead` is already thread-safe, so no lock here. + redisAsyncHandleRead(redis_async_context_); +} + +void RedisAsyncContext::RedisAsyncHandleWrite() { + // `redisAsyncHandleWrite` will mutate `redis_async_context_`, use a lock to protect + // it. + std::lock_guard lock(mutex_); + redisAsyncHandleWrite(redis_async_context_); +} + +Status RedisAsyncContext::RedisAsyncCommand(redisCallbackFn *fn, void *privdata, + const char *format, ...) { + va_list ap; + va_start(ap, format); + + int ret_code = 0; + { + // `redisvAsyncCommand` will mutate `redis_async_context_`, use a lock to protect it. + std::lock_guard lock(mutex_); + ret_code = redisvAsyncCommand(redis_async_context_, fn, privdata, format, ap); + } + + va_end(ap); + + if (ret_code == REDIS_ERR) { + return Status::RedisError(std::string(redis_async_context_->errstr)); + } + RAY_CHECK(ret_code == REDIS_OK); + return Status::OK(); +} + +Status RedisAsyncContext::RedisAsyncCommandArgv(redisCallbackFn *fn, void *privdata, + int argc, const char **argv, + const size_t *argvlen) { + int ret_code = 0; + { + // `redisAsyncCommandArgv` will mutate `redis_async_context_`, use a lock to protect + // it. + std::lock_guard lock(mutex_); + ret_code = + redisAsyncCommandArgv(redis_async_context_, fn, privdata, argc, argv, argvlen); + } + + if (ret_code == REDIS_ERR) { + return Status::RedisError(std::string(redis_async_context_->errstr)); + } + RAY_CHECK(ret_code == REDIS_OK); + return Status::OK(); +} + +} // namespace gcs + +} // namespace ray diff --git a/src/ray/gcs/redis_async_context.h b/src/ray/gcs/redis_async_context.h new file mode 100644 index 000000000000..c52c330b8666 --- /dev/null +++ b/src/ray/gcs/redis_async_context.h @@ -0,0 +1,73 @@ +#ifndef RAY_GCS_REDIS_ASYNC_CONTEXT_H +#define RAY_GCS_REDIS_ASYNC_CONTEXT_H + +#include +#include +#include "ray/common/status.h" + +extern "C" { +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" +} + +namespace ray { + +namespace gcs { + +/// \class RedisAsyncContext +/// RedisAsyncContext class is a wrapper of hiredis `asyncRedisContext`, providing +/// C++ style and thread-safe API. +class RedisAsyncContext { + public: + explicit RedisAsyncContext(redisAsyncContext *redis_async_context); + + ~RedisAsyncContext(); + + /// Get the raw 'redisAsyncContext' pointer. + /// + /// \return redisAsyncContext * + redisAsyncContext *GetRawRedisAsyncContext(); + + /// Reset the raw 'redisAsyncContext' pointer to nullptr. + void ResetRawRedisAsyncContext(); + + /// Perform command 'redisAsyncHandleRead'. Thread-safe. + void RedisAsyncHandleRead(); + + /// Perform command 'redisAsyncHandleWrite'. Thread-safe. + void RedisAsyncHandleWrite(); + + /// Perform command 'redisvAsyncCommand'. Thread-safe. + /// + /// \param fn Callback that will be called after the command finishes. + /// \param privdata User-defined pointer. + /// \param format Command format. + /// \param ... Command list. + /// \return Status + Status RedisAsyncCommand(redisCallbackFn *fn, void *privdata, const char *format, ...); + + /// Perform command 'redisAsyncCommandArgv'. Thread-safe. + /// + /// \param fn Callback that will be called after the command finishes. + /// \param privdata User-defined pointer. + /// \param argc Number of arguments. + /// \param argv Array with arguments. + /// \param argvlen Array with each argument's length. + /// \return Status + Status RedisAsyncCommandArgv(redisCallbackFn *fn, void *privdata, int argc, + const char **argv, const size_t *argvlen); + + private: + /// This mutex is used to protect `redis_async_context`. + /// NOTE(micafan): All the `redisAsyncContext`-related functions only manipulate memory + /// data and don't actually do any IO operations. So the perf impact of adding the lock + /// should be minimum. + std::mutex mutex_; + redisAsyncContext *redis_async_context_{nullptr}; +}; + +} // namespace gcs + +} // namespace ray + +#endif // RAY_GCS_REDIS_ASYNC_CONTEXT_H diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index e45395a58966..1c339099c4e1 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -135,17 +135,21 @@ void GlobalRedisCallback(void *c, void *r, void *privdata) { int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription) { auto start_time = current_sys_time_us(); + + std::lock_guard lock(mutex_); callback_items_.emplace(num_callbacks_, CallbackItem(function, is_subscription, start_time)); return num_callbacks_++; } RedisCallbackManager::CallbackItem &RedisCallbackManager::get(int64_t callback_index) { + std::lock_guard lock(mutex_); RAY_CHECK(callback_items_.find(callback_index) != callback_items_.end()); return callback_items_[callback_index]; } void RedisCallbackManager::remove(int64_t callback_index) { + std::lock_guard lock(mutex_); callback_items_.erase(callback_index); } @@ -158,12 +162,6 @@ RedisContext::~RedisContext() { if (context_) { redisFree(context_); } - if (async_context_) { - redisAsyncFree(async_context_); - } - if (subscribe_context_) { - redisAsyncFree(subscribe_context_); - } } Status AuthenticateRedis(redisContext *context, const std::string &password) { @@ -190,13 +188,16 @@ Status AuthenticateRedis(redisAsyncContext *context, const std::string &password void RedisAsyncContextDisconnectCallback(const redisAsyncContext *context, int status) { RAY_LOG(WARNING) << "Redis async context disconnected. Status: " << status; - reinterpret_cast(context->data) - ->AsyncDisconnectCallback(context, status); + // Reset raw 'redisAsyncContext' to nullptr because hiredis will release this context. + reinterpret_cast(context->data)->ResetRawRedisAsyncContext(); } -void SetDisconnectCallback(RedisContext *redis_context, redisAsyncContext *context) { - context->data = redis_context; - redisAsyncSetDisconnectCallback(context, RedisAsyncContextDisconnectCallback); +void SetDisconnectCallback(RedisAsyncContext *redis_async_context) { + redisAsyncContext *raw_redis_async_context = + redis_async_context->GetRawRedisAsyncContext(); + raw_redis_async_context->data = redis_async_context; + redisAsyncSetDisconnectCallback(raw_redis_async_context, + RedisAsyncContextDisconnectCallback); } template @@ -228,8 +229,8 @@ Status ConnectWithRetries(const std::string &address, int port, Status RedisContext::Connect(const std::string &address, int port, bool sharding, const std::string &password = "") { RAY_CHECK(!context_); - RAY_CHECK(!async_context_); - RAY_CHECK(!subscribe_context_); + RAY_CHECK(!redis_async_context_); + RAY_CHECK(!async_redis_subscribe_context_); RAY_CHECK_OK(ConnectWithRetries(address, port, redisConnect, &context_)); RAY_CHECK_OK(AuthenticateRedis(context_, password)); @@ -240,29 +241,24 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding freeReplyObject(reply); // Connect to async context - RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &async_context_)); - SetDisconnectCallback(this, async_context_); - RAY_CHECK_OK(AuthenticateRedis(async_context_, password)); + redisAsyncContext *async_context = nullptr; + RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &async_context)); + RAY_CHECK_OK(AuthenticateRedis(async_context, password)); + redis_async_context_.reset(new RedisAsyncContext(async_context)); + SetDisconnectCallback(redis_async_context_.get()); // Connect to subscribe context - RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &subscribe_context_)); - SetDisconnectCallback(this, subscribe_context_); - RAY_CHECK_OK(AuthenticateRedis(subscribe_context_, password)); + redisAsyncContext *subscribe_context = nullptr; + RAY_CHECK_OK(ConnectWithRetries(address, port, redisAsyncConnect, &subscribe_context)); + RAY_CHECK_OK(AuthenticateRedis(subscribe_context, password)); + async_redis_subscribe_context_.reset(new RedisAsyncContext(subscribe_context)); + SetDisconnectCallback(async_redis_subscribe_context_.get()); return Status::OK(); } -Status RedisContext::AttachToEventLoop(aeEventLoop *loop) { - if (redisAeAttach(loop, async_context_) != REDIS_OK || - redisAeAttach(loop, subscribe_context_) != REDIS_OK) { - return Status::RedisError("could not attach redis event loop"); - } else { - return Status::OK(); - } -} - Status RedisContext::RunArgvAsync(const std::vector &args) { - RAY_CHECK(async_context_); + RAY_CHECK(redis_async_context_); // Build the arguments. std::vector argv; std::vector argc; @@ -271,13 +267,9 @@ Status RedisContext::RunArgvAsync(const std::vector &args) { argc.push_back(args[i].size()); } // Run the Redis command. - int status; - status = redisAsyncCommandArgv(async_context_, nullptr, nullptr, args.size(), - argv.data(), argc.data()); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - return Status::OK(); + Status status = redis_async_context_->RedisAsyncCommandArgv( + nullptr, nullptr, args.size(), argv.data(), argc.data()); + return status; } Status RedisContext::SubscribeAsync(const ClientID &client_id, @@ -286,40 +278,28 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, int64_t *out_callback_index) { RAY_CHECK(pubsub_channel != TablePubsub::NO_PUBLISH) << "Client requested subscribe on a table that does not support pubsub"; - RAY_CHECK(subscribe_context_); + RAY_CHECK(async_redis_subscribe_context_); int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, true); RAY_CHECK(out_callback_index != nullptr); *out_callback_index = callback_index; - int status = 0; + Status status = Status::OK(); if (client_id.IsNil()) { // Subscribe to all messages. std::string redis_command = "SUBSCRIBE %d"; - status = redisAsyncCommand( - subscribe_context_, reinterpret_cast(&GlobalRedisCallback), + status = async_redis_subscribe_context_->RedisAsyncCommand( + reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel); } else { // Subscribe only to messages sent to this client. std::string redis_command = "SUBSCRIBE %d:%b"; - status = redisAsyncCommand( - subscribe_context_, reinterpret_cast(&GlobalRedisCallback), + status = async_redis_subscribe_context_->RedisAsyncCommand( + reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel, client_id.Data(), client_id.Size()); } - if (status == REDIS_ERR) { - return Status::RedisError(std::string(subscribe_context_->errstr)); - } - return Status::OK(); -} - -void RedisContext::AsyncDisconnectCallback(const redisAsyncContext *context, int status) { - if (context == async_context_) { - async_context_ = nullptr; - } - if (context == subscribe_context_) { - subscribe_context_ = nullptr; - } + return status; } } // namespace gcs diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 0c81bc7f802e..68df76411ee2 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -3,12 +3,14 @@ #include #include +#include #include #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/util/logging.h" +#include "ray/gcs/redis_async_context.h" #include "ray/protobuf/gcs.pb.h" extern "C" { @@ -101,18 +103,20 @@ class RedisCallbackManager { ~RedisCallbackManager() {} + std::mutex mutex_; + int64_t num_callbacks_ = 0; std::unordered_map callback_items_; }; class RedisContext { public: - RedisContext() - : context_(nullptr), async_context_(nullptr), subscribe_context_(nullptr) {} + RedisContext() : context_(nullptr) {} + ~RedisContext(); + Status Connect(const std::string &address, int port, bool sharding, const std::string &password); - Status AttachToEventLoop(aeEventLoop *loop); /// Run an operation on some table key. /// @@ -150,29 +154,25 @@ class RedisContext { Status SubscribeAsync(const ClientID &client_id, const TablePubsub pubsub_channel, const RedisCallback &redisCallback, int64_t *out_callback_index); - /// Called when an instance of redisAsyncContext is disconnected. - /// - /// \param context the redisAsyncContext instances - /// \param status The status code of disconnection - void AsyncDisconnectCallback(const redisAsyncContext *context, int status); - redisContext *sync_context() { RAY_CHECK(context_); return context_; } - redisAsyncContext *async_context() { - RAY_CHECK(async_context_); - return async_context_; + + RedisAsyncContext &async_context() { + RAY_CHECK(redis_async_context_); + return *redis_async_context_; + } + + RedisAsyncContext &subscribe_context() { + RAY_CHECK(async_redis_subscribe_context_); + return *async_redis_subscribe_context_; } - redisAsyncContext *subscribe_context() { - RAY_CHECK(subscribe_context_); - return subscribe_context_; - }; private: redisContext *context_; - redisAsyncContext *async_context_; - redisAsyncContext *subscribe_context_; + std::unique_ptr redis_async_context_; + std::unique_ptr async_redis_subscribe_context_; }; template @@ -180,40 +180,32 @@ Status RedisContext::RunAsync(const std::string &command, const ID &id, const vo size_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length) { - RAY_CHECK(async_context_); + RAY_CHECK(redis_async_context_); int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); + Status status = Status::OK(); if (length > 0) { if (log_length >= 0) { std::string redis_command = command + " %d %d %b %b %d"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), + status = redis_async_context_->RedisAsyncCommand( + reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), prefix, pubsub_channel, id.Data(), id.Size(), data, length, log_length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } } else { std::string redis_command = command + " %d %d %b %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), + status = redis_async_context_->RedisAsyncCommand( + reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), prefix, pubsub_channel, id.Data(), id.Size(), data, length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } } } else { RAY_CHECK(log_length == -1); std::string redis_command = command + " %d %d %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), + status = redis_async_context_->RedisAsyncCommand( + reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), prefix, pubsub_channel, id.Data(), id.Size()); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } } - return Status::OK(); + return status; } } // namespace gcs