Skip to content

fix code issues in object manager that are reported by scanning tool #3649

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/ray/common/client_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ void ServerConnection<T>::DoAsyncWrites() {
template <class T>
std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
ClientHandler<T> &client_handler, MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label) {
std::shared_ptr<ClientConnection<T>> self(
new ClientConnection(message_handler, std::move(socket), debug_label));
boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label,
int64_t error_message_type) {
std::shared_ptr<ClientConnection<T>> self(new ClientConnection(
message_handler, std::move(socket), debug_label, error_message_type));
// Let our manager process our new connection.
client_handler(*self);
return self;
Expand All @@ -197,10 +198,12 @@ std::shared_ptr<ClientConnection<T>> ClientConnection<T>::Create(
template <class T>
ClientConnection<T>::ClientConnection(MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket,
const std::string &debug_label)
const std::string &debug_label,
int64_t error_message_type)
: ServerConnection<T>(std::move(socket)),
message_handler_(message_handler),
debug_label_(debug_label) {}
debug_label_(debug_label),
error_message_type_(error_message_type) {}

template <class T>
const ClientID &ClientConnection<T>::GetClientId() {
Expand Down Expand Up @@ -230,7 +233,7 @@ template <class T>
void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &error) {
if (error) {
// If there was an error, disconnect the client.
read_type_ = static_cast<int64_t>(protocol::MessageType::DisconnectClient);
read_type_ = error_message_type_;
read_length_ = 0;
ProcessMessage(error);
return;
Expand All @@ -251,7 +254,7 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
template <class T>
void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error) {
if (error) {
read_type_ = static_cast<int64_t>(protocol::MessageType::DisconnectClient);
read_type_ = error_message_type_;
}

int64_t start_ms = current_time_ms();
Expand Down
7 changes: 5 additions & 2 deletions src/ray/common/client_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ class ClientConnection : public ServerConnection<T> {
/// \return std::shared_ptr<ClientConnection>.
static std::shared_ptr<ClientConnection<T>> Create(
ClientHandler<T> &new_client_handler, MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label);
boost::asio::basic_stream_socket<T> &&socket, const std::string &debug_label,
int64_t error_message_type);

std::shared_ptr<ClientConnection<T>> shared_ClientConnection_from_this() {
return std::static_pointer_cast<ClientConnection<T>>(shared_from_this());
Expand All @@ -169,7 +170,7 @@ class ClientConnection : public ServerConnection<T> {
/// A private constructor for a node client connection.
ClientConnection(MessageHandler<T> &message_handler,
boost::asio::basic_stream_socket<T> &&socket,
const std::string &debug_label);
const std::string &debug_label, int64_t error_message_type);
/// Process an error from the last operation, then process the message
/// header from the client.
void ProcessMessageHeader(const boost::system::error_code &error);
Expand All @@ -183,6 +184,8 @@ class ClientConnection : public ServerConnection<T> {
MessageHandler<T> message_handler_;
/// A label used for debug messages.
const std::string debug_label_;
/// The value for disconnect client message.
int64_t error_message_type_;
/// Buffers for the current message being read from the client.
int64_t read_version_;
int64_t read_type_;
Expand Down
1 change: 1 addition & 0 deletions src/ray/object_manager/format/object_manager.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ table ObjectInfo {

enum MessageType:int {
ConnectClient = 1,
DisconnectClient,
PushRequest,
PullRequest,
FreeRequest
Expand Down
1 change: 0 additions & 1 deletion src/ray/object_manager/object_buffer_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ void ObjectBufferPool::SealChunk(const ObjectID &object_id, const uint64_t chunk
CreateChunkState::REFERENCED);
create_buffer_state_[object_id].chunk_state[chunk_index] = CreateChunkState::SEALED;
create_buffer_state_[object_id].num_seals_remaining--;
RAY_CHECK(create_buffer_state_[object_id].num_seals_remaining >= 0);
RAY_LOG(DEBUG) << "SealChunk" << object_id << " "
<< create_buffer_state_[object_id].num_seals_remaining;
if (create_buffer_state_[object_id].num_seals_remaining == 0) {
Expand Down
15 changes: 8 additions & 7 deletions src/ray/object_manager/object_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -707,25 +707,26 @@ void ObjectManager::ProcessNewClient(TcpClientConnection &conn) {

void ObjectManager::ProcessClientMessage(std::shared_ptr<TcpClientConnection> &conn,
int64_t message_type, const uint8_t *message) {
switch (message_type) {
case static_cast<int64_t>(object_manager_protocol::MessageType::PushRequest): {
auto message_type_value =
static_cast<object_manager_protocol::MessageType>(message_type);
switch (message_type_value) {
case object_manager_protocol::MessageType::PushRequest: {
ReceivePushRequest(conn, message);
break;
}
case static_cast<int64_t>(object_manager_protocol::MessageType::PullRequest): {
case object_manager_protocol::MessageType::PullRequest: {
ReceivePullRequest(conn, message);
break;
}
case static_cast<int64_t>(object_manager_protocol::MessageType::ConnectClient): {
case object_manager_protocol::MessageType::ConnectClient: {
ConnectClient(conn, message);
break;
}
case static_cast<int64_t>(object_manager_protocol::MessageType::FreeRequest): {
case object_manager_protocol::MessageType::FreeRequest: {
ReceiveFreeRequest(conn, message);
break;
}
case static_cast<int64_t>(protocol::MessageType::DisconnectClient): {
// TODO(hme): Disconnect without depending on the node manager protocol.
case object_manager_protocol::MessageType::DisconnectClient: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the below TDO can be removed.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

DisconnectClient(conn, message);
break;
}
Expand Down
7 changes: 4 additions & 3 deletions src/ray/object_manager/test/object_manager_stress_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ class MockServer {
object_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection =
TcpClientConnection::Create(client_handler, message_handler,
std::move(object_manager_socket_), "object manager");
auto new_connection = TcpClientConnection::Create(
client_handler, message_handler, std::move(object_manager_socket_),
"object manager",
static_cast<int64_t>(object_manager::protocol::MessageType::DisconnectClient));
DoAcceptObjectManager();
}

Expand Down
7 changes: 4 additions & 3 deletions src/ray/object_manager/test/object_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,10 @@ class MockServer {
object_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection =
TcpClientConnection::Create(client_handler, message_handler,
std::move(object_manager_socket_), "object manager");
auto new_connection = TcpClientConnection::Create(
client_handler, message_handler, std::move(object_manager_socket_),
"object manager",
static_cast<int64_t>(object_manager::protocol::MessageType::DisconnectClient));
DoAcceptObjectManager();
}

Expand Down
26 changes: 14 additions & 12 deletions src/ray/raylet/client_connection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@ namespace raylet {

class ClientConnectionTest : public ::testing::Test {
public:
ClientConnectionTest() : io_service_(), in_(io_service_), out_(io_service_) {
ClientConnectionTest()
: io_service_(), in_(io_service_), out_(io_service_), error_message_type_(1) {
boost::asio::local::connect_pair(in_, out_);
}

protected:
boost::asio::io_service io_service_;
boost::asio::local::stream_protocol::socket in_;
boost::asio::local::stream_protocol::socket out_;
int64_t error_message_type_;
};

TEST_F(ClientConnectionTest, SimpleSyncWrite) {
Expand All @@ -37,11 +39,11 @@ TEST_F(ClientConnectionTest, SimpleSyncWrite) {
num_messages += 1;
};

auto conn1 = LocalClientConnection::Create(client_handler, message_handler,
std::move(in_), "conn1");
auto conn1 = LocalClientConnection::Create(
client_handler, message_handler, std::move(in_), "conn1", error_message_type_);

auto conn2 = LocalClientConnection::Create(client_handler, message_handler,
std::move(out_), "conn2");
auto conn2 = LocalClientConnection::Create(
client_handler, message_handler, std::move(out_), "conn2", error_message_type_);

RAY_CHECK_OK(conn1->WriteMessage(0, 5, arr));
RAY_CHECK_OK(conn2->WriteMessage(0, 5, arr));
Expand Down Expand Up @@ -83,11 +85,11 @@ TEST_F(ClientConnectionTest, SimpleAsyncWrite) {
}
};

auto writer = LocalClientConnection::Create(client_handler, noop_handler,
std::move(in_), "writer");
auto writer = LocalClientConnection::Create(
client_handler, noop_handler, std::move(in_), "writer", error_message_type_);

reader = LocalClientConnection::Create(client_handler, message_handler, std::move(out_),
"reader");
"reader", error_message_type_);

std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
RAY_CHECK_OK(status);
Expand All @@ -111,8 +113,8 @@ TEST_F(ClientConnectionTest, SimpleAsyncError) {
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {};

auto writer = LocalClientConnection::Create(client_handler, noop_handler,
std::move(in_), "writer");
auto writer = LocalClientConnection::Create(
client_handler, noop_handler, std::move(in_), "writer", error_message_type_);

std::function<void(const ray::Status &)> callback = [](const ray::Status &status) {
ASSERT_TRUE(!status.ok());
Expand All @@ -133,8 +135,8 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) {
std::shared_ptr<LocalClientConnection> client, int64_t message_type,
const uint8_t *message) {};

auto writer = LocalClientConnection::Create(client_handler, noop_handler,
std::move(in_), "writer");
auto writer = LocalClientConnection::Create(
client_handler, noop_handler, std::move(in_), "writer", error_message_type_);

std::function<void(const ray::Status &)> callback =
[writer](const ray::Status &status) {
Expand Down
15 changes: 9 additions & 6 deletions src/ray/raylet/raylet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
};
// Accept a new TCP client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(
client_handler, message_handler, std::move(node_manager_socket_), "node manager");
client_handler, message_handler, std::move(node_manager_socket_), "node manager",
static_cast<int64_t>(protocol::MessageType::DisconnectClient));
}
// We're ready to accept another client.
DoAcceptNodeManager();
Expand All @@ -122,9 +123,10 @@ void Raylet::HandleAcceptObjectManager(const boost::system::error_code &error) {
object_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new TCP client and dispatch it to the node manager.
auto new_connection =
TcpClientConnection::Create(client_handler, message_handler,
std::move(object_manager_socket_), "object manager");
auto new_connection = TcpClientConnection::Create(
client_handler, message_handler, std::move(object_manager_socket_),
"object manager",
static_cast<int64_t>(object_manager::protocol::MessageType::DisconnectClient));
DoAcceptObjectManager();
}

Expand All @@ -144,8 +146,9 @@ void Raylet::HandleAccept(const boost::system::error_code &error) {
node_manager_.ProcessClientMessage(client, message_type, message);
};
// Accept a new local client and dispatch it to the node manager.
auto new_connection = LocalClientConnection::Create(client_handler, message_handler,
std::move(socket_), "worker");
auto new_connection = LocalClientConnection::Create(
client_handler, message_handler, std::move(socket_), "worker",
static_cast<int64_t>(protocol::MessageType::DisconnectClient));
}
// We're ready to accept another client.
DoAccept();
Expand Down
8 changes: 5 additions & 3 deletions src/ray/raylet/worker_pool_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class WorkerPoolMock : public WorkerPool {

class WorkerPoolTest : public ::testing::Test {
public:
WorkerPoolTest() : worker_pool_(), io_service_() {}
WorkerPoolTest() : worker_pool_(), io_service_(), error_message_type_(1) {}

std::shared_ptr<Worker> CreateWorker(pid_t pid,
const Language &language = Language::PYTHON) {
Expand All @@ -46,14 +46,16 @@ class WorkerPoolTest : public ::testing::Test {
HandleMessage(client, message_type, message);
};
boost::asio::local::stream_protocol::socket socket(io_service_);
auto client = LocalClientConnection::Create(client_handler, message_handler,
std::move(socket), "worker");
auto client =
LocalClientConnection::Create(client_handler, message_handler, std::move(socket),
"worker", error_message_type_);
return std::shared_ptr<Worker>(new Worker(pid, language, client));
}

protected:
WorkerPoolMock worker_pool_;
boost::asio::io_service io_service_;
int64_t error_message_type_;

private:
void HandleNewClient(LocalClientConnection &){};
Expand Down