From b7c284aaa3c4869327c2e80e11b2a8ab36e0d8bd Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Thu, 30 May 2019 11:54:30 +0800 Subject: [PATCH] Refactor redis callback handling (#4841) * Add CallbackReply * Fix * fix linting by format.sh * Fix linting * Address comments. * Fix --- src/ray/gcs/redis_context.cc | 140 +++++++++++-------- src/ray/gcs/redis_context.h | 36 ++++- src/ray/gcs/redis_module/ray_redis_module.cc | 2 +- src/ray/gcs/tables.cc | 29 ++-- 4 files changed, 132 insertions(+), 75 deletions(-) diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index 6b03fa735007..e0c5a6565412 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -20,7 +20,8 @@ namespace { /// A helper function to call the callback and delete it from the callback /// manager if necessary. -void ProcessCallback(int64_t callback_index, const std::string &data) { +void ProcessCallback(int64_t callback_index, + const ray::gcs::CallbackReply &callback_reply) { RAY_CHECK(callback_index >= 0) << "The callback index must be greater than 0, " << "but it actually is " << callback_index; auto callback_item = ray::gcs::RedisCallbackManager::instance().get(callback_index); @@ -31,7 +32,7 @@ void ProcessCallback(int64_t callback_index, const std::string &data) { } // Invoke the callback. if (callback_item.callback != nullptr) { - callback_item.callback(data); + callback_item.callback(callback_reply); } if (!callback_item.is_subscription) { // Delete the callback if it's not a subscription callback. @@ -45,74 +46,91 @@ namespace ray { namespace gcs { -// This is a global redis callback which will be registered for every -// asynchronous redis call. It dispatches the appropriate callback -// that was registered with the RedisCallbackManager. -void GlobalRedisCallback(void *c, void *r, void *privdata) { - if (r == nullptr) { - return; +CallbackReply::CallbackReply(redisReply *redis_reply) { + RAY_CHECK(nullptr != redis_reply); + RAY_CHECK(redis_reply->type != REDIS_REPLY_ERROR) << "Got an error in redis reply: " + << redis_reply->str; + this->redis_reply_ = redis_reply; +} + +bool CallbackReply::IsNil() const { return REDIS_REPLY_NIL == redis_reply_->type; } + +int64_t CallbackReply::ReadAsInteger() const { + RAY_CHECK(REDIS_REPLY_INTEGER == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + return static_cast(redis_reply_->integer); +} + +std::string CallbackReply::ReadAsString() const { + RAY_CHECK(REDIS_REPLY_STRING == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + return std::string(redis_reply_->str, redis_reply_->len); +} + +Status CallbackReply::ReadAsStatus() const { + RAY_CHECK(REDIS_REPLY_STATUS == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + const std::string status_str(redis_reply_->str, redis_reply_->len); + if ("OK" == status_str) { + return Status::OK(); } - int64_t callback_index = reinterpret_cast(privdata); - redisReply *reply = reinterpret_cast(r); + + return Status::RedisError(status_str); +} + +std::string CallbackReply::ReadAsPubsubData() const { + RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type) << "Unexpected type: " + << redis_reply_->type; + std::string data = ""; - // Parse the response. - switch (reply->type) { - case (REDIS_REPLY_NIL): { - // Do not add any data for a nil response. - } break; - case (REDIS_REPLY_STRING): { - data = std::string(reply->str, reply->len); - } break; - case (REDIS_REPLY_STATUS): { - } break; - case (REDIS_REPLY_ERROR): { - RAY_LOG(FATAL) << "Redis error: " << reply->str; - } break; - case (REDIS_REPLY_INTEGER): { - data = std::to_string(reply->integer); - break; + // Parse the published message. + redisReply *message_type = redis_reply_->element[0]; + if (strcmp(message_type->str, "subscribe") == 0) { + // If the message is for the initial subscription call, return the empty + // string as a response to signify that subscription was successful. + } else if (strcmp(message_type->str, "message") == 0) { + // If the message is from a PUBLISH, make sure the data is nonempty. + redisReply *message = redis_reply_->element[redis_reply_->elements - 1]; + // data is a notification message. + data = std::string(message->str, message->len); + RAY_CHECK(!data.empty()) << "Empty message received on subscribe channel."; + } else { + RAY_LOG(FATAL) << "This is not a pubsub reply: data=" << message_type->str; } - default: - RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string " - << reply->str; + + return data; +} + +void CallbackReply::ReadAsStringArray(std::vector *array) const { + RAY_CHECK(nullptr != array) << "Argument `array` must not be nullptr."; + RAY_CHECK(REDIS_REPLY_ARRAY == redis_reply_->type); + + const auto array_size = static_cast(redis_reply_->elements); + if (array_size > 0) { + auto *entry = redis_reply_->element[0]; + const bool is_pubsub_reply = + strcmp(entry->str, "subscribe") == 0 || strcmp(entry->str, "message") == 0; + RAY_CHECK(!is_pubsub_reply) << "Subpub reply cannot be read as a string array."; + } + + array->resize(array_size); + for (size_t i = 0; i < array_size; ++i) { + auto *entry = redis_reply_->element[i]; + RAY_CHECK(REDIS_REPLY_STRING == entry->type) << "Unexcepted type: " << entry->type; + array->push_back(std::string(entry->str, entry->len)); } - ProcessCallback(callback_index, data); } -void SubscribeRedisCallback(void *c, void *r, void *privdata) { +// This is a global redis callback which will be registered for every +// asynchronous redis call. It dispatches the appropriate callback +// that was registered with the RedisCallbackManager. +void GlobalRedisCallback(void *c, void *r, void *privdata) { if (r == nullptr) { return; } int64_t callback_index = reinterpret_cast(privdata); redisReply *reply = reinterpret_cast(r); - std::string data = ""; - // Parse the response. - switch (reply->type) { - case (REDIS_REPLY_ARRAY): { - // Parse the published message. - redisReply *message_type = reply->element[0]; - if (strcmp(message_type->str, "subscribe") == 0) { - // If the message is for the initial subscription call, return the empty - // string as a response to signify that subscription was successful. - } else if (strcmp(message_type->str, "message") == 0) { - // If the message is from a PUBLISH, make sure the data is nonempty. - redisReply *message = reply->element[reply->elements - 1]; - auto notification = std::string(message->str, message->len); - RAY_CHECK(!notification.empty()) << "Empty message received on subscribe channel"; - data = notification; - } else { - RAY_LOG(FATAL) << "Fatal redis error during subscribe" << message_type->str; - } - - } break; - case (REDIS_REPLY_ERROR): { - RAY_LOG(FATAL) << "Redis error: " << reply->str; - } break; - default: - RAY_LOG(FATAL) << "Fatal redis error of type " << reply->type << " and with string " - << reply->str; - } - ProcessCallback(callback_index, data); + ProcessCallback(callback_index, CallbackReply(reply)); } int64_t RedisCallbackManager::add(const RedisCallback &function, bool is_subscription) { @@ -259,13 +277,13 @@ Status RedisContext::SubscribeAsync(const ClientID &client_id, // Subscribe to all messages. std::string redis_command = "SUBSCRIBE %d"; status = redisAsyncCommand( - subscribe_context_, reinterpret_cast(&SubscribeRedisCallback), + subscribe_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel); } else { // Subscribe only to messages sent to this client. std::string redis_command = "SUBSCRIBE %d:%b"; status = redisAsyncCommand( - subscribe_context_, reinterpret_cast(&SubscribeRedisCallback), + subscribe_context_, reinterpret_cast(&GlobalRedisCallback), reinterpret_cast(callback_index), redis_command.c_str(), pubsub_channel, client_id.data(), client_id.size()); } diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 93a343464892..b82915374b0a 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -24,9 +24,43 @@ struct aeEventLoop; namespace ray { namespace gcs { + +/// A simple reply wrapper for redis reply. +class CallbackReply { + public: + explicit CallbackReply(redisReply *redis_reply); + + /// Whether this reply is `nil` type reply. + bool IsNil() const; + + /// Read this reply data as an integer. + int64_t ReadAsInteger() const; + + /// Read this reply data as a string. + /// + /// Note that this will return an empty string if + /// the type of this reply is `nil` or `status`. + std::string ReadAsString() const; + + /// Read this reply data as a status. + Status ReadAsStatus() const; + + /// Read this reply data as a pub-sub data. + std::string ReadAsPubsubData() const; + + /// Read this reply data as a string array. + /// + /// \param array Since the return-value may be large, + /// make it as an output parameter. + void ReadAsStringArray(std::vector *array) const; + + private: + redisReply *redis_reply_; +}; + /// Every callback should take in a vector of the results from the Redis /// operation. -using RedisCallback = std::function; +using RedisCallback = std::function; void GlobalRedisCallback(void *c, void *r, void *privdata); diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index b9891e8cae32..0014778896cd 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -351,7 +351,7 @@ int TableAppend_DoWrite(RedisModuleCtx *ctx, RedisModuleString **argv, int argc, // The requested index did not match the current length of the log. Return // an error message as a string. static const char *reply = "ERR entry exists"; - RedisModule_ReplyWithStringBuffer(ctx, reply, strlen(reply)); + RedisModule_ReplyWithSimpleString(ctx, reply); return REDISMODULE_ERR; } } diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index 3d4708940d1a..ccf05f2b5151 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -41,10 +41,11 @@ template Status Log::Append(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_appends_++; - auto callback = [this, id, dataT, done](const std::string &data) { - // If data is not empty, then Redis failed to append the entry. - RAY_CHECK(data.empty()) << "TABLE_APPEND command failed: " << data; - + auto callback = [this, id, dataT, done](const CallbackReply &reply) { + const auto status = reply.ReadAsStatus(); + // Failed to append the entry. + RAY_CHECK(status.ok()) << "Failed to execute command TABLE_APPEND:" + << status.ToString(); if (done != nullptr) { (done)(client_, id, *dataT); } @@ -62,8 +63,9 @@ Status Log::AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; - auto callback = [this, id, dataT, done, failure](const std::string &data) { - if (data.empty()) { + auto callback = [this, id, dataT, done, failure](const CallbackReply &reply) { + const auto status = reply.ReadAsStatus(); + if (status.ok()) { if (done != nullptr) { (done)(client_, id, *dataT); } @@ -85,10 +87,11 @@ template Status Log::Lookup(const DriverID &driver_id, const ID &id, const Callback &lookup) { num_lookups_++; - auto callback = [this, id, lookup](const std::string &data) { + auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { std::vector results; - if (!data.empty()) { + if (!reply.IsNil()) { + const auto data = reply.ReadAsString(); auto root = flatbuffers::GetRoot(data.data()); RAY_CHECK(from_flatbuf(*root->id()) == id); for (size_t i = 0; i < root->entries()->size(); i++) { @@ -125,7 +128,9 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien const SubscriptionCallback &done) { RAY_CHECK(subscribe_callback_index_ == -1) << "Client called Subscribe twice on the same table"; - auto callback = [this, subscribe, done](const std::string &data) { + auto callback = [this, subscribe, done](const CallbackReply &reply) { + const auto data = reply.ReadAsPubsubData(); + if (data.empty()) { // No notification data is provided. This is the callback for the // initial subscription request. @@ -231,7 +236,7 @@ template Status Table::Add(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const std::string &data) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { (done)(client_, id, *dataT); } @@ -296,7 +301,7 @@ template Status Set::Add(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_adds_++; - auto callback = [this, id, dataT, done](const std::string &data) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { (done)(client_, id, *dataT); } @@ -313,7 +318,7 @@ template Status Set::Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &dataT, const WriteCallback &done) { num_removes_++; - auto callback = [this, id, dataT, done](const std::string &data) { + auto callback = [this, id, dataT, done](const CallbackReply &reply) { if (done != nullptr) { (done)(client_, id, *dataT); }