Skip to content

Commit

Permalink
[Pubsub] reduce memory usage for channels that do not require total m…
Browse files Browse the repository at this point in the history
…emory cap (ray-project#23985)

In ray-project@a1e06f6, memory bound was added for each subscribed entity in the publisher. It adds two extra `std::deque` per subscribed entity, which turns out to cost a lot more memory when there are a large number of `ObjectRef`s: ray-project#23853 (comment)

This PR avoids the extra memory usage for entities in channels unlikely to grow too large, i.e. all channels except those for logs and error info. Subscribed entity memory usage no longer shows up in the memory profile when there are 1M object refs. Raw data: [profile006.pb.gz](https://github.com/ray-project/ray/files/8508547/profile006.pb.gz)
  • Loading branch information
mwtian authored Apr 20, 2022
1 parent 2169007 commit 34fb092
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 32 deletions.
3 changes: 2 additions & 1 deletion src/ray/core_worker/test/reference_count_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ using SubscriptionFailureCallbackMap =
// static maps are used to simulate distirubted environment.
static SubscriptionCallbackMap subscription_callback_map;
static SubscriptionFailureCallbackMap subscription_failure_callback_map;
static pubsub::pub_internal::SubscriptionIndex directory;
static pubsub::pub_internal::SubscriptionIndex directory(
rpc::ChannelType::WORKER_OBJECT_LOCATIONS_CHANNEL);

static std::string GenerateID(UniqueID publisher_id, UniqueID subscriber_id) {
return publisher_id.Binary() + subscriber_id.Binary();
Expand Down
58 changes: 44 additions & 14 deletions src/ray/pubsub/publisher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,18 @@ namespace pubsub {

namespace pub_internal {

bool EntityState::Publish(const rpc::PubMessage &pub_message) {
bool BasicEntityState::Publish(const rpc::PubMessage &pub_message) {
if (subscribers_.empty()) {
return false;
}
const auto msg = std::make_shared<rpc::PubMessage>(pub_message);
for (auto &[id, subscriber] : subscribers_) {
subscriber->QueueMessage(msg);
}
return true;
}

bool CappedEntityState::Publish(const rpc::PubMessage &pub_message) {
if (subscribers_.empty()) {
return false;
}
Expand Down Expand Up @@ -90,24 +101,32 @@ const absl::flat_hash_map<SubscriberID, SubscriberState *> &EntityState::Subscri
return subscribers_;
}

SubscriptionIndex::SubscriptionIndex(rpc::ChannelType channel_type)
: channel_type_(channel_type), subscribers_to_all_(CreateEntityState()) {}

bool SubscriptionIndex::Publish(const rpc::PubMessage &pub_message) {
const bool publish_to_all = subscribers_to_all_.Publish(pub_message);
const bool publish_to_all = subscribers_to_all_->Publish(pub_message);
bool publish_to_entity = false;
auto it = entities_.find(pub_message.key_id());
if (it != entities_.end()) {
publish_to_entity = it->second.Publish(pub_message);
publish_to_entity = it->second->Publish(pub_message);
}
return publish_to_all || publish_to_entity;
}

bool SubscriptionIndex::AddEntry(const std::string &key_id, SubscriberState *subscriber) {
if (key_id.empty()) {
return subscribers_to_all_.AddSubscriber(subscriber);
return subscribers_to_all_->AddSubscriber(subscriber);
}

auto &subscribing_key_ids = subscribers_to_key_id_[subscriber->id()];
const bool key_added = subscribing_key_ids.emplace(key_id).second;
const bool subscriber_added = entities_[key_id].AddSubscriber(subscriber);

auto sub_it = entities_.find(key_id);
if (sub_it == entities_.end()) {
sub_it = entities_.emplace(key_id, CreateEntityState()).first;
}
const bool subscriber_added = sub_it->second->AddSubscriber(subscriber);

RAY_CHECK(key_added == subscriber_added);
return key_added;
Expand All @@ -116,14 +135,14 @@ bool SubscriptionIndex::AddEntry(const std::string &key_id, SubscriberState *sub
std::vector<SubscriberID> SubscriptionIndex::GetSubscriberIdsByKeyId(
const std::string &key_id) const {
std::vector<SubscriberID> subscribers;
if (!subscribers_to_all_.Subscribers().empty()) {
for (const auto &[sub_id, sub] : subscribers_to_all_.Subscribers()) {
if (!subscribers_to_all_->Subscribers().empty()) {
for (const auto &[sub_id, sub] : subscribers_to_all_->Subscribers()) {
subscribers.push_back(sub_id);
}
}
auto it = entities_.find(key_id);
if (it != entities_.end()) {
for (const auto &[sub_id, sub] : it->second.Subscribers()) {
for (const auto &[sub_id, sub] : it->second->Subscribers()) {
subscribers.push_back(sub_id);
}
}
Expand All @@ -132,7 +151,7 @@ std::vector<SubscriberID> SubscriptionIndex::GetSubscriberIdsByKeyId(

bool SubscriptionIndex::EraseSubscriber(const SubscriberID &subscriber_id) {
// Erase subscriber of all keys.
if (subscribers_to_all_.RemoveSubscriber(subscriber_id)) {
if (subscribers_to_all_->RemoveSubscriber(subscriber_id)) {
return true;
}

Expand All @@ -149,7 +168,7 @@ bool SubscriptionIndex::EraseSubscriber(const SubscriberID &subscriber_id) {
if (entity_it == entities_.end()) {
continue;
}
auto &entity = entity_it->second;
auto &entity = *entity_it->second;
entity.RemoveSubscriber(subscriber_id);
if (entity.Subscribers().empty()) {
entities_.erase(entity_it);
Expand All @@ -163,7 +182,7 @@ bool SubscriptionIndex::EraseEntry(const std::string &key_id,
const SubscriberID &subscriber_id) {
// Erase the subscriber of all keys.
if (key_id.empty()) {
return subscribers_to_all_.RemoveSubscriber(subscriber_id);
return subscribers_to_all_->RemoveSubscriber(subscriber_id);
}

// Erase keys from the subscriber of individual keys.
Expand All @@ -176,7 +195,7 @@ bool SubscriptionIndex::EraseEntry(const std::string &key_id,
if (object_it == objects.end()) {
auto it = entities_.find(key_id);
if (it != entities_.end()) {
RAY_CHECK(!it->second.Subscribers().contains(subscriber_id));
RAY_CHECK(!it->second->Subscribers().contains(subscriber_id));
}
return false;
}
Expand All @@ -189,7 +208,7 @@ bool SubscriptionIndex::EraseEntry(const std::string &key_id,
auto entity_it = entities_.find(key_id);
// If code reaches this line, that means the object id was in the index.
RAY_CHECK(entity_it != entities_.end());
auto &entity = entity_it->second;
auto &entity = *entity_it->second;
// If code reaches this line, that means the subscriber id was in the index.
RAY_CHECK(entity.RemoveSubscriber(subscriber_id));
if (entity.Subscribers().empty()) {
Expand All @@ -203,7 +222,7 @@ bool SubscriptionIndex::HasKeyId(const std::string &key_id) const {
}

bool SubscriptionIndex::HasSubscriber(const SubscriberID &subscriber_id) const {
if (subscribers_to_all_.Subscribers().contains(subscriber_id)) {
if (subscribers_to_all_->Subscribers().contains(subscriber_id)) {
return true;
}
return subscribers_to_key_id_.contains(subscriber_id);
Expand All @@ -213,6 +232,17 @@ bool SubscriptionIndex::CheckNoLeaks() const {
return entities_.empty() && subscribers_to_key_id_.empty();
}

std::unique_ptr<EntityState> SubscriptionIndex::CreateEntityState() {
switch (channel_type_) {
case rpc::ChannelType::RAY_ERROR_INFO_CHANNEL:
case rpc::ChannelType::RAY_LOG_CHANNEL: {
return std::make_unique<CappedEntityState>();
}
default:
return std::make_unique<BasicEntityState>();
}
}

void SubscriberState::ConnectToSubscriber(const rpc::PubsubLongPollingRequest &request,
rpc::PubsubLongPollingReply *reply,
rpc::SendReplyCallback send_reply_callback) {
Expand Down
53 changes: 44 additions & 9 deletions src/ray/pubsub/publisher.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ class SubscriberState;
/// State for an entity / topic in a pub/sub channel.
class EntityState {
public:
virtual ~EntityState() = default;

/// Publishes the message to subscribers of the entity.
/// Returns true if there are subscribers, returns false otherwise.
bool Publish(const rpc::PubMessage &pub_message);
virtual bool Publish(const rpc::PubMessage &pub_message) = 0;

/// Manages the set of subscribers of this entity.
bool AddSubscriber(SubscriberState *subscriber);
Expand All @@ -54,6 +56,36 @@ class EntityState {
/// Gets the current set of subscribers, keyed by subscriber IDs.
const absl::flat_hash_map<SubscriberID, SubscriberState *> &Subscribers() const;

protected:
// Subscribers of this entity.
// The underlying SubscriberState is owned by Publisher.
absl::flat_hash_map<SubscriberID, SubscriberState *> subscribers_;
};

/// The two implementations of EntityState are BasicEntityState and CappedEntityState.
///
/// BasicEntityState is the simplest. It is used by default.
///
/// CappedEntityState implements a total size cap on the buffered messages. It helps
/// protect certain channels from using too much memory, e.g. channels for logs and
/// error infos. However each CappedEntityState takes up more space than the
/// BasicEntityState, so it is unsuitable when there can be a large number of entities.
/// i.e. CappedEntityState is not suitable for the WORKER_OBJECT_* channels. It is
/// not very benefitial for actor and node info channels either, since only GCS publishes
/// to these channels with small, bounded-size messages.

/// Publishes the message to all subscribers, without size cap on buffered messages.
class BasicEntityState : public EntityState {
public:
bool Publish(const rpc::PubMessage &pub_message) override;
};

/// Publishes the message to all subscribers, and enforce a total size cap on buffered
/// messages.
class CappedEntityState : public EntityState {
public:
bool Publish(const rpc::PubMessage &pub_message) override;

private:
// Tracks inflight messages. The messages have shared ownership by
// individual subscribers, and get deleted after no subscriber has
Expand All @@ -63,19 +95,18 @@ class EntityState {
std::queue<int64_t> message_sizes_;
// Total size of inflight messages.
int64_t total_size_ = 0;

// Subscribers of this entity.
// The underlying SubscriberState is owned by Publisher.
absl::flat_hash_map<SubscriberID, SubscriberState *> subscribers_;
};

/// Per-channel two-way index for subscribers and the keys they subscribe to.
/// Also supports subscribers to all keys in the channel.
class SubscriptionIndex {
public:
SubscriptionIndex() = default;
SubscriptionIndex(rpc::ChannelType channel_type);
~SubscriptionIndex() = default;

SubscriptionIndex(SubscriptionIndex &&) noexcept = default;
SubscriptionIndex &operator=(SubscriptionIndex &&) noexcept = default;

/// Publishes the message to relevant subscribers.
/// Returns true if there are subscribers listening on the entity key of the message,
/// returns false otherwise.
Expand Down Expand Up @@ -112,10 +143,14 @@ class SubscriptionIndex {
bool CheckNoLeaks() const;

private:
std::unique_ptr<EntityState> CreateEntityState();

// Type of channel this index is for.
rpc::ChannelType channel_type_;
// Collection of subscribers that subscribe to all entities of the channel.
EntityState subscribers_to_all_;
std::unique_ptr<EntityState> subscribers_to_all_;
// Mapping from subscribed entity id -> entity state.
absl::flat_hash_map<std::string, EntityState> entities_;
absl::flat_hash_map<std::string, std::unique_ptr<EntityState>> entities_;
// Mapping from subscriber IDs -> subscribed key ids.
// Reverse index of key_id_to_subscribers_.
absl::flat_hash_map<SubscriberID, absl::flat_hash_set<std::string>>
Expand Down Expand Up @@ -283,7 +318,7 @@ class Publisher : public PublisherInterface {
publish_batch_size_(publish_batch_size) {
// Insert index map for each channel.
for (auto type : channels) {
subscription_index_map_.emplace(type, pub_internal::SubscriptionIndex());
subscription_index_map_.emplace(type, type);
}

periodical_runner_->RunFnPeriodically([this] { CheckDeadSubscribers(); },
Expand Down
16 changes: 8 additions & 8 deletions src/ray/pubsub/test/publisher_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ TEST_F(PublisherTest, TestSubscriptionIndexSingeNodeSingleObject) {
/// Test single node id & object id
///
/// oid1 -> [nid1]
SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
subscription_index.AddEntry(oid.Binary(), subscriber);
const auto &subscribers_from_index =
subscription_index.GetSubscriberIdsByKeyId(oid.Binary());
Expand All @@ -127,7 +127,7 @@ TEST_F(PublisherTest, TestSubscriptionIndexMultiNodeSingleObject) {
/// Test single object id & multi nodes
///
/// oid1 -> [nid1~nid5]
SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
const auto oid = ObjectID::FromRandom();
absl::flat_hash_set<NodeID> empty_set;
subscribers_map_.emplace(oid, empty_set);
Expand Down Expand Up @@ -177,7 +177,7 @@ TEST_F(PublisherTest, TestSubscriptionIndexErase) {
///
/// oid1 -> [nid1~nid5]
/// oid2 -> [nid1~nid5]
SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
int total_entries = 6;
int entries_to_delete_at_each_time = 3;
auto oid = ObjectID::FromRandom();
Expand Down Expand Up @@ -226,7 +226,7 @@ TEST_F(PublisherTest, TestSubscriptionIndexEraseMultiSubscribers) {
///
/// Test erase the duplicated entries with multi subscribers.
///
SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
auto oid = ObjectID::FromRandom();
auto oid2 = ObjectID::FromRandom();
absl::flat_hash_set<NodeID> empty_set;
Expand All @@ -250,7 +250,7 @@ TEST_F(PublisherTest, TestSubscriptionIndexEraseSubscriber) {
///
/// Test erase subscriber.
///
SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
auto oid = ObjectID::FromRandom();
auto &subscribers = subscribers_map_[oid];
std::vector<SubscriberID> subscriber_ids;
Expand Down Expand Up @@ -282,7 +282,7 @@ TEST_F(PublisherTest, TestSubscriptionIndexIdempotency) {
auto *subscriber = CreateSubscriber();
auto subscriber_id = subscriber->id();
auto oid = ObjectID::FromRandom();
SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);

// Add the same entry many times.
for (int i = 0; i < 5; i++) {
Expand Down Expand Up @@ -1040,7 +1040,7 @@ class ScopedEntityBufferMaxBytes {
TEST_F(PublisherTest, TestMaxBufferSizePerEntity) {
ScopedEntityBufferMaxBytes max_bytes(10000);

SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
auto job_id = JobID::FromInt(1234);
auto *subscriber = CreateSubscriber();
// Subscribe to job_id.
Expand Down Expand Up @@ -1082,7 +1082,7 @@ TEST_F(PublisherTest, TestMaxBufferSizePerEntity) {
TEST_F(PublisherTest, TestMaxBufferSizeAllEntities) {
ScopedEntityBufferMaxBytes max_bytes(10000);

SubscriptionIndex subscription_index;
SubscriptionIndex subscription_index(rpc::ChannelType::RAY_ERROR_INFO_CHANNEL);
auto *subscriber = CreateSubscriber();
// Subscribe to all entities.
subscription_index.AddEntry("", subscriber);
Expand Down

0 comments on commit 34fb092

Please sign in to comment.