Skip to content

Commit

Permalink
Make GCS Client thread-safe. (ray-project#5413)
Browse files Browse the repository at this point in the history
  • Loading branch information
micafan authored and raulchen committed Aug 17, 2019
1 parent bb31620 commit 47aa2b1
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 99 deletions.
2 changes: 1 addition & 1 deletion src/ray/gcs/actor_state_accessor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ActorStateAccessorTest : public ::testing::Test {
void GenTestData() { GenActorData(); }

void GenActorData() {
for (size_t i = 0; i < 2; ++i) {
for (size_t i = 0; i < 100; ++i) {
std::shared_ptr<ActorTableData> actor = std::make_shared<ActorTableData>();
actor->set_max_reconstructions(1);
actor->set_remaining_reconstructions(1);
Expand Down
10 changes: 6 additions & 4 deletions src/ray/gcs/asio.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
#include "ray/util/logging.h"

RedisAsioClient::RedisAsioClient(boost::asio::io_service &io_service,
redisAsyncContext *async_context)
: async_context_(async_context),
ray::gcs::RedisAsyncContext &redis_async_context)
: redis_async_context_(redis_async_context),
socket_(io_service),
read_requested_(false),
write_requested_(false),
read_in_progress_(false),
write_in_progress_(false) {
redisAsyncContext *async_context = redis_async_context_.GetRawRedisAsyncContext();

// gives access to c->fd
redisContext *c = &(async_context->c);

Expand Down Expand Up @@ -47,7 +49,7 @@ void RedisAsioClient::operate() {
void RedisAsioClient::handle_read(boost::system::error_code error_code) {
RAY_CHECK(!error_code || error_code == boost::asio::error::would_block);
read_in_progress_ = false;
redisAsyncHandleRead(async_context_);
redis_async_context_.RedisAsyncHandleRead();

if (error_code == boost::asio::error::would_block) {
operate();
Expand All @@ -57,7 +59,7 @@ void RedisAsioClient::handle_read(boost::system::error_code error_code) {
void RedisAsioClient::handle_write(boost::system::error_code error_code) {
RAY_CHECK(!error_code || error_code == boost::asio::error::would_block);
write_in_progress_ = false;
redisAsyncHandleWrite(async_context_);
redis_async_context_.RedisAsyncHandleWrite();

if (error_code == boost::asio::error::would_block) {
operate();
Expand Down
7 changes: 5 additions & 2 deletions src/ray/gcs/asio.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@
#include <boost/asio/error.hpp>
#include <boost/bind.hpp>

#include "ray/gcs/redis_async_context.h"
#include "ray/thirdparty/hiredis/async.h"
#include "ray/thirdparty/hiredis/hiredis.h"

class RedisAsioClient {
public:
RedisAsioClient(boost::asio::io_service &io_service, redisAsyncContext *ac);
RedisAsioClient(boost::asio::io_service &io_service,
ray::gcs::RedisAsyncContext &redis_async_context);

void operate();

Expand All @@ -47,7 +49,8 @@ class RedisAsioClient {
void cleanup();

private:
redisAsyncContext *async_context_;
ray::gcs::RedisAsyncContext &redis_async_context_;

boost::asio::ip::tcp::socket socket_;
// Hiredis wanted to add a read operation to the event loop
// but the read might not have happened yet
Expand Down
4 changes: 2 additions & 2 deletions src/ray/gcs/asio_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ void GetCallback(redisAsyncContext *c, void *r, void *privdata) {
redisReply *reply = reinterpret_cast<redisReply *>(r);
ASSERT_TRUE(reply != nullptr);
ASSERT_TRUE(std::string(reinterpret_cast<char *>(reply->str)) == "test");
redisAsyncDisconnect(c);
io_service.stop();
}

TEST(RedisAsioTest, TestRedisCommands) {
redisAsyncContext *ac = redisAsyncConnect("127.0.0.1", 6379);
ASSERT_TRUE(ac->err == 0);
ray::gcs::RedisAsyncContext redis_async_context(ac);

RedisAsioClient client(io_service, ac);
RedisAsioClient client(io_service, redis_async_context);

redisAsyncSetConnectCallback(ac, ConnectCallback);
redisAsyncSetDisconnectCallback(ac, DisconnectCallback);
Expand Down
87 changes: 87 additions & 0 deletions src/ray/gcs/redis_async_context.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#include "ray/gcs/redis_async_context.h"

extern "C" {
#include "ray/thirdparty/hiredis/async.h"
#include "ray/thirdparty/hiredis/hiredis.h"
}

namespace ray {

namespace gcs {

RedisAsyncContext::RedisAsyncContext(redisAsyncContext *redis_async_context)
: redis_async_context_(redis_async_context) {
RAY_CHECK(redis_async_context_ != nullptr);
}

RedisAsyncContext::~RedisAsyncContext() {
if (redis_async_context_ != nullptr) {
redisAsyncFree(redis_async_context_);
redis_async_context_ = nullptr;
}
}

redisAsyncContext *RedisAsyncContext::GetRawRedisAsyncContext() {
return redis_async_context_;
}

void RedisAsyncContext::ResetRawRedisAsyncContext() {
// Reset redis_async_context_ to nullptr because hiredis has released this context.
redis_async_context_ = nullptr;
}

void RedisAsyncContext::RedisAsyncHandleRead() {
// `redisAsyncHandleRead` is already thread-safe, so no lock here.
redisAsyncHandleRead(redis_async_context_);
}

void RedisAsyncContext::RedisAsyncHandleWrite() {
// `redisAsyncHandleWrite` will mutate `redis_async_context_`, use a lock to protect
// it.
std::lock_guard<std::mutex> lock(mutex_);
redisAsyncHandleWrite(redis_async_context_);
}

Status RedisAsyncContext::RedisAsyncCommand(redisCallbackFn *fn, void *privdata,
const char *format, ...) {
va_list ap;
va_start(ap, format);

int ret_code = 0;
{
// `redisvAsyncCommand` will mutate `redis_async_context_`, use a lock to protect it.
std::lock_guard<std::mutex> lock(mutex_);
ret_code = redisvAsyncCommand(redis_async_context_, fn, privdata, format, ap);
}

va_end(ap);

if (ret_code == REDIS_ERR) {
return Status::RedisError(std::string(redis_async_context_->errstr));
}
RAY_CHECK(ret_code == REDIS_OK);
return Status::OK();
}

Status RedisAsyncContext::RedisAsyncCommandArgv(redisCallbackFn *fn, void *privdata,
int argc, const char **argv,
const size_t *argvlen) {
int ret_code = 0;
{
// `redisAsyncCommandArgv` will mutate `redis_async_context_`, use a lock to protect
// it.
std::lock_guard<std::mutex> lock(mutex_);
ret_code =
redisAsyncCommandArgv(redis_async_context_, fn, privdata, argc, argv, argvlen);
}

if (ret_code == REDIS_ERR) {
return Status::RedisError(std::string(redis_async_context_->errstr));
}
RAY_CHECK(ret_code == REDIS_OK);
return Status::OK();
}

} // namespace gcs

} // namespace ray
73 changes: 73 additions & 0 deletions src/ray/gcs/redis_async_context.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#ifndef RAY_GCS_REDIS_ASYNC_CONTEXT_H
#define RAY_GCS_REDIS_ASYNC_CONTEXT_H

#include <stdarg.h>
#include <mutex>
#include "ray/common/status.h"

extern "C" {
#include "ray/thirdparty/hiredis/async.h"
#include "ray/thirdparty/hiredis/hiredis.h"
}

namespace ray {

namespace gcs {

/// \class RedisAsyncContext
/// RedisAsyncContext class is a wrapper of hiredis `asyncRedisContext`, providing
/// C++ style and thread-safe API.
class RedisAsyncContext {
public:
explicit RedisAsyncContext(redisAsyncContext *redis_async_context);

~RedisAsyncContext();

/// Get the raw 'redisAsyncContext' pointer.
///
/// \return redisAsyncContext *
redisAsyncContext *GetRawRedisAsyncContext();

/// Reset the raw 'redisAsyncContext' pointer to nullptr.
void ResetRawRedisAsyncContext();

/// Perform command 'redisAsyncHandleRead'. Thread-safe.
void RedisAsyncHandleRead();

/// Perform command 'redisAsyncHandleWrite'. Thread-safe.
void RedisAsyncHandleWrite();

/// Perform command 'redisvAsyncCommand'. Thread-safe.
///
/// \param fn Callback that will be called after the command finishes.
/// \param privdata User-defined pointer.
/// \param format Command format.
/// \param ... Command list.
/// \return Status
Status RedisAsyncCommand(redisCallbackFn *fn, void *privdata, const char *format, ...);

/// Perform command 'redisAsyncCommandArgv'. Thread-safe.
///
/// \param fn Callback that will be called after the command finishes.
/// \param privdata User-defined pointer.
/// \param argc Number of arguments.
/// \param argv Array with arguments.
/// \param argvlen Array with each argument's length.
/// \return Status
Status RedisAsyncCommandArgv(redisCallbackFn *fn, void *privdata, int argc,
const char **argv, const size_t *argvlen);

private:
/// This mutex is used to protect `redis_async_context`.
/// NOTE(micafan): All the `redisAsyncContext`-related functions only manipulate memory
/// data and don't actually do any IO operations. So the perf impact of adding the lock
/// should be minimum.
std::mutex mutex_;
redisAsyncContext *redis_async_context_{nullptr};
};

} // namespace gcs

} // namespace ray

#endif // RAY_GCS_REDIS_ASYNC_CONTEXT_H
Loading

0 comments on commit 47aa2b1

Please sign in to comment.