Skip to content

Commit

Permalink
Refactor redis callback handling (ray-project#4841)
Browse files Browse the repository at this point in the history
* Add CallbackReply

* Fix

* fix linting by format.sh

* Fix linting

* Address comments.

* Fix
  • Loading branch information
jovany-wang authored May 30, 2019
1 parent 3f4d37c commit b7c284a
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 75 deletions.
140 changes: 79 additions & 61 deletions src/ray/gcs/redis_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace {

/// A helper function to call the callback and delete it from the callback
/// manager if necessary.
void ProcessCallback(int64_t callback_index, const std::string &data) {
void ProcessCallback(int64_t callback_index,
const ray::gcs::CallbackReply &callback_reply) {
RAY_CHECK(callback_index >= 0) << "The callback index must be greater than 0, "
<< "but it actually is " << callback_index;
auto callback_item = ray::gcs::RedisCallbackManager::instance().get(callback_index);
Expand All @@ -31,7 +32,7 @@ void ProcessCallback(int64_t callback_index, const std::string &data) {
}
// Invoke the callback.
if (callback_item.callback != nullptr) {
callback_item.callback(data);
callback_item.callback(callback_reply);
}
if (!callback_item.is_subscription) {
// Delete the callback if it's not a subscription callback.
Expand All @@ -45,74 +46,91 @@ namespace ray {

namespace gcs {

// This is a global redis callback which will be registered for every
// asynchronous redis call. It dispatches the appropriate callback
// that was registered with the RedisCallbackManager.
void GlobalRedisCallback(void *c, void *r, void *privdata) {
if (r == nullptr) {
return;
CallbackReply::CallbackReply(redisReply *redis_reply) {
RAY_CHECK(nullptr != redis_reply);
RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) << "Got an error in redis reply: "
<< redis_reply->str;
this->redis_reply_ = redis_reply;
}

bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == redis_reply_->type; }

int64_t CallbackReply::ReadAsInteger() const {
RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) << "Unexpected type: "
<< redis_reply_->type;
return static_cast<int64_t>(redis_reply_->integer);
}

std::string CallbackReply::ReadAsString() const {
RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) << "Unexpected type: "
<< redis_reply_->type;
return std::string(redis_reply_->str, redis_reply_->len);
}

Status CallbackReply::ReadAsStatus() const {
RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) << "Unexpected type: "
<< redis_reply_->type;
const std::string status_str(redis_reply_->str, redis_reply_->len);
if ("OK" == status_str) {
return Status::OK();
}
int64_t callback_index = reinterpret_cast<int64_t>(privdata);
redisReply *reply = reinterpret_cast<redisReply *>(r);

return Status::RedisError(status_str);
}

std::string CallbackReply::ReadAsPubsubData() const {
RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) << "Unexpected type: "
<< redis_reply_->type;

std::string data = "";
// Parse the response.
switch (reply->type) {
case (REDIS_REPLY_NIL): {
// Do not add any data for a nil response.
} break;
case (REDIS_REPLY_STRING): {
data = std::string(reply->str, reply->len);
} break;
case (REDIS_REPLY_STATUS): {
} break;
case (REDIS_REPLY_ERROR): {
RAY_LOG(FATAL) << "Redis error: " << reply->str;
} break;
case (REDIS_REPLY_INTEGER): {
data = std::to_string(reply->integer);
break;
// Parse the published message.
redisReply *message_type = redis_reply_->element[0];
if (strcmp(message_type->str, "subscribe") == 0) {
// If the message is for the initial subscription call, return the empty
// string as a response to signify that subscription was successful.
} else if (strcmp(message_type->str, "message") == 0) {
// If the message is from a PUBLISH, make sure the data is nonempty.
redisReply *message = redis_reply_->element[redis_reply_->elements - 1];
// data is a notification message.
data = std::string(message->str, message->len);
RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel.";
} else {
RAY_LOG(FATAL) << "This is not a pubsub reply: data=" << message_type->str;
}
default:
RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string "
<< reply->str;

return data;
}

void CallbackReply::ReadAsStringArray(std::vector<std::string> *array) const {
RAY_CHECK(nullptr != array) << "Argument `array` must not be nullptr.";
RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type);

const auto array_size = static_cast<size_t>(redis_reply_->elements);
if (array_size > 0) {
auto *entry = redis_reply_->element[0];
const bool is_pubsub_reply =
strcmp(entry->str, "subscribe") == 0 || strcmp(entry->str, "message") == 0;
RAY_CHECK(!is_pubsub_reply) << "Subpub reply cannot be read as a string array.";
}

array->resize(array_size);
for (size_t i = 0; i < array_size; ++i) {
auto *entry = redis_reply_->element[i];
RAY_CHECK(REDIS_REPLY_STRING == entry->type) << "Unexcepted type: " << entry->type;
array->push_back(std::string(entry->str, entry->len));
}
ProcessCallback(callback_index, data);
}

void SubscribeRedisCallback(void *c, void *r, void *privdata) {
// This is a global redis callback which will be registered for every
// asynchronous redis call. It dispatches the appropriate callback
// that was registered with the RedisCallbackManager.
void GlobalRedisCallback(void *c, void *r, void *privdata) {
if (r == nullptr) {
return;
}
int64_t callback_index = reinterpret_cast<int64_t>(privdata);
redisReply *reply = reinterpret_cast<redisReply *>(r);
std::string data = "";
// Parse the response.
switch (reply->type) {
case (REDIS_REPLY_ARRAY): {
// Parse the published message.
redisReply *message_type = reply->element[0];
if (strcmp(message_type->str, "subscribe") == 0) {
// If the message is for the initial subscription call, return the empty
// string as a response to signify that subscription was successful.
} else if (strcmp(message_type->str, "message") == 0) {
// If the message is from a PUBLISH, make sure the data is nonempty.
redisReply *message = reply->element[reply->elements - 1];
auto notification = std::string(message->str, message->len);
RAY_CHECK(!notification.empty()) << "Empty message received on subscribe channel";
data = notification;
} else {
RAY_LOG(FATAL) << "Fatal redis error during subscribe" << message_type->str;
}

} break;
case (REDIS_REPLY_ERROR): {
RAY_LOG(FATAL) << "Redis error: " << reply->str;
} break;
default:
RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string "
<< reply->str;
}
ProcessCallback(callback_index, data);
ProcessCallback(callback_index, CallbackReply(reply));
}

int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription) {
Expand Down Expand Up @@ -259,13 +277,13 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id,
// Subscribe to all messages.
std::string redis_command = "SUBSCRIBE %d";
status = redisAsyncCommand(
subscribe_context_, reinterpret_cast<redisCallbackFn *>(&SubscribeRedisCallback),
subscribe_context_, reinterpret_cast<redisCallbackFn *>(&GlobalRedisCallback),
reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel);
} else {
// Subscribe only to messages sent to this client.
std::string redis_command = "SUBSCRIBE %d:%b";
status = redisAsyncCommand(
subscribe_context_, reinterpret_cast<redisCallbackFn *>(&SubscribeRedisCallback),
subscribe_context_, reinterpret_cast<redisCallbackFn *>(&GlobalRedisCallback),
reinterpret_cast<void *>(callback_index), redis_command.c_str(), pubsub_channel,
client_id.data(), client_id.size());
}
Expand Down
36 changes: 35 additions & 1 deletion src/ray/gcs/redis_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,43 @@ struct aeEventLoop;
namespace ray {

namespace gcs {

/// A simple reply wrapper for redis reply.
class CallbackReply {
public:
explicit CallbackReply(redisReply *redis_reply);

/// Whether this reply is `nil` type reply.
bool IsNil() const;

/// Read this reply data as an integer.
int64_t ReadAsInteger() const;

/// Read this reply data as a string.
///
/// Note that this will return an empty string if
/// the type of this reply is `nil` or `status`.
std::string ReadAsString() const;

/// Read this reply data as a status.
Status ReadAsStatus() const;

/// Read this reply data as a pub-sub data.
std::string ReadAsPubsubData() const;

/// Read this reply data as a string array.
///
/// \param array Since the return-value may be large,
/// make it as an output parameter.
void ReadAsStringArray(std::vector<std::string> *array) const;

private:
redisReply *redis_reply_;
};

/// Every callback should take in a vector of the results from the Redis
/// operation.
using RedisCallback = std::function<void(const std::string &)>;
using RedisCallback = std::function<void(const CallbackReply &)>;

void GlobalRedisCallback(void *c, void *r, void *privdata);

Expand Down
2 changes: 1 addition & 1 deletion src/ray/gcs/redis_module/ray_redis_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ int TableAppend_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc,
// The requested index did not match the current length of the log. Return
// an error message as a string.
static const char *reply = "ERR entry exists";
RedisModule_ReplyWithStringBuffer(ctx, reply, strlen(reply));
RedisModule_ReplyWithSimpleString(ctx, reply);
return REDISMODULE_ERR;
}
}
Expand Down
29 changes: 17 additions & 12 deletions src/ray/gcs/tables.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ template <typename ID, typename Data>
Status Log<ID, Data>::Append(const DriverID &driver_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
num_appends_++;
auto callback = [this, id, dataT, done](const std::string &data) {
// If data is not empty, then Redis failed to append the entry.
RAY_CHECK(data.empty()) << "TABLE_APPEND command failed: " << data;

auto callback = [this, id, dataT, done](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
// Failed to append the entry.
RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:"
<< status.ToString();
if (done != nullptr) {
(done)(client_, id, *dataT);
}
Expand All @@ -62,8 +63,9 @@ Status Log<ID, Data>::AppendAt(const DriverID &driver_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done,
const WriteCallback &failure, int log_length) {
num_appends_++;
auto callback = [this, id, dataT, done, failure](const std::string &data) {
if (data.empty()) {
auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) {
const auto status = reply.ReadAsStatus();
if (status.ok()) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
Expand All @@ -85,10 +87,11 @@ template <typename ID, typename Data>
Status Log<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
const Callback &lookup) {
num_lookups_++;
auto callback = [this, id, lookup](const std::string &data) {
auto callback = [this, id, lookup](const CallbackReply &reply) {
if (lookup != nullptr) {
std::vector<DataT> results;
if (!data.empty()) {
if (!reply.IsNil()) {
const auto data = reply.ReadAsString();
auto root = flatbuffers::GetRoot<GcsTableEntry>(data.data());
RAY_CHECK(from_flatbuf<ID>(*root->id()) == id);
for (size_t i = 0; i < root->entries()->size(); i++) {
Expand Down Expand Up @@ -125,7 +128,9 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
const SubscriptionCallback &done) {
RAY_CHECK(subscribe_callback_index_ == -1)
<< "Client called Subscribe twice on the same table";
auto callback = [this, subscribe, done](const std::string &data) {
auto callback = [this, subscribe, done](const CallbackReply &reply) {
const auto data = reply.ReadAsPubsubData();

if (data.empty()) {
// No notification data is provided. This is the callback for the
// initial subscription request.
Expand Down Expand Up @@ -231,7 +236,7 @@ template <typename ID, typename Data>
Status Table<ID, Data>::Add(const DriverID &driver_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, dataT, done](const std::string &data) {
auto callback = [this, id, dataT, done](const CallbackReply &reply) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
Expand Down Expand Up @@ -296,7 +301,7 @@ template <typename ID, typename Data>
Status Set<ID, Data>::Add(const DriverID &driver_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
num_adds_++;
auto callback = [this, id, dataT, done](const std::string &data) {
auto callback = [this, id, dataT, done](const CallbackReply &reply) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
Expand All @@ -313,7 +318,7 @@ template <typename ID, typename Data>
Status Set<ID, Data>::Remove(const DriverID &driver_id, const ID &id,
std::shared_ptr<DataT> &dataT, const WriteCallback &done) {
num_removes_++;
auto callback = [this, id, dataT, done](const std::string &data) {
auto callback = [this, id, dataT, done](const CallbackReply &reply) {
if (done != nullptr) {
(done)(client_, id, *dataT);
}
Expand Down

0 comments on commit b7c284a

Please sign in to comment.