Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

protect raylet against bad messages #4003

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/includes/ray_config.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ cdef extern from "ray/ray_config.h" nogil:
@staticmethod
RayConfig &instance()

int64_t ray_protocol_version() const
int64_t ray_cookie() const

int64_t handler_warning_timeout_ms() const

Expand Down
4 changes: 2 additions & 2 deletions python/ray/includes/ray_config.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ from ray.includes.ray_config cimport RayConfig

cdef class Config:
@staticmethod
def ray_protocol_version():
return RayConfig.instance().ray_protocol_version()
def ray_cookie():
return RayConfig.instance().ray_cookie()

@staticmethod
def handler_warning_timeout_ms():
Expand Down
65 changes: 57 additions & 8 deletions src/ray/common/client_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <stdio.h>
#include <boost/bind.hpp>
#include <sstream>

#include "ray/ray_config.h"
#include "ray/raylet/format/node_manager_generated.h"
Expand Down Expand Up @@ -101,8 +102,8 @@ ray::Status ServerConnection<T>::WriteMessage(int64_t type, int64_t length,
bytes_written_ += length;

std::vector<boost::asio::const_buffer> message_buffers;
auto write_version = RayConfig::instance().ray_protocol_version();
message_buffers.push_back(boost::asio::buffer(&write_version, sizeof(write_version)));
auto write_cookie = RayConfig::instance().ray_cookie();
message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie)));
message_buffers.push_back(boost::asio::buffer(&type, sizeof(type)));
message_buffers.push_back(boost::asio::buffer(&length, sizeof(length)));
message_buffers.push_back(boost::asio::buffer(message, length));
Expand All @@ -117,7 +118,7 @@ void ServerConnection<T>::WriteMessageAsync(
bytes_written_ += length;

auto write_buffer = std::unique_ptr<AsyncWriteBuffer>(new AsyncWriteBuffer());
write_buffer->write_version = RayConfig::instance().ray_protocol_version();
write_buffer->write_cookie = RayConfig::instance().ray_cookie();
write_buffer->write_type = type;
write_buffer->write_length = length;
write_buffer->write_message.resize(length);
Expand Down Expand Up @@ -147,8 +148,8 @@ void ServerConnection<T>::DoAsyncWrites() {
std::vector<boost::asio::const_buffer> message_buffers;
int num_messages = 0;
for (const auto &write_buffer : async_write_queue_) {
message_buffers.push_back(boost::asio::buffer(&write_buffer->write_version,
sizeof(write_buffer->write_version)));
message_buffers.push_back(boost::asio::buffer(&write_buffer->write_cookie,
sizeof(write_buffer->write_cookie)));
message_buffers.push_back(
boost::asio::buffer(&write_buffer->write_type, sizeof(write_buffer->write_type)));
message_buffers.push_back(boost::asio::buffer(&write_buffer->write_length,
Expand Down Expand Up @@ -202,6 +203,7 @@ ClientConnection<T>::ClientConnection(
const std::string &debug_label,
const std::vector<std::string> &message_type_enum_names, int64_t error_message_type)
: ServerConnection<T>(std::move(socket)),
client_id_(ClientID::nil()),
message_handler_(message_handler),
debug_label_(debug_label),
message_type_enum_names_(message_type_enum_names),
Expand All @@ -222,7 +224,7 @@ void ClientConnection<T>::ProcessMessages() {
// Wait for a message header from the client. The message header includes the
// ray cookie, the message type, and the length of the message.
std::vector<boost::asio::mutable_buffer> header;
header.push_back(boost::asio::buffer(&read_version_, sizeof(read_version_)));
header.push_back(boost::asio::buffer(&read_cookie_, sizeof(read_cookie_)));
header.push_back(boost::asio::buffer(&read_type_, sizeof(read_type_)));
header.push_back(boost::asio::buffer(&read_length_, sizeof(read_length_)));
boost::asio::async_read(
Expand All @@ -241,8 +243,12 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
return;
}

// If there was no error, make sure the protocol version matches.
RAY_CHECK(read_version_ == RayConfig::instance().ray_protocol_version());
// If there was no error, make sure the ray cookie matches.
if (!CheckRayCookie()) {
ServerConnection<T>::Close();
return;
}

// Resize the message buffer to match the received length.
read_message_.resize(read_length_);
ServerConnection<T>::bytes_read_ += read_length_;
Expand All @@ -253,6 +259,49 @@ void ClientConnection<T>::ProcessMessageHeader(const boost::system::error_code &
shared_ClientConnection_from_this(), boost::asio::placeholders::error));
}

/// Validate the cookie read from the header of the message being received.
///
/// On a mismatch, a description of the offending message is logged. If the
/// connection already carries a non-nil client ID the peer is a known client,
/// which indicates a bug, so the description is logged at FATAL severity.
/// Otherwise the data came from an unidentified connection and only a warning
/// is logged, so that stray local programs sending bad data cannot crash the
/// raylet.
///
/// \return True if the received cookie equals the configured ray cookie.
template <class T>
bool ClientConnection<T>::CheckRayCookie() {
  const bool cookie_matches = (read_cookie_ == RayConfig::instance().ray_cookie());
  if (cookie_matches) {
    return true;
  }

  // Build the diagnostic once; it is used by whichever log severity applies.
  std::ostringstream description;
  description << " ray cookie mismatch for received message. ";
  description << "received cookie: " << read_cookie_;
  description << ", debug label: " << debug_label_;
  description << ", remote client ID: " << client_id_;
  const std::string endpoint = RemoteEndpointInfo();
  if (!endpoint.empty()) {
    description << ", remote endpoint info: " << endpoint;
  }

  if (client_id_.is_nil()) {
    // Unknown sender: report it and let the caller stop processing the connection.
    RAY_LOG(WARNING) << description.str();
  } else {
    // Known client sending a bad cookie: this indicates a bug.
    RAY_LOG(FATAL) << description.str();
  }
  return false;
}

/// Generic implementation of RemoteEndpointInfo, used for every socket type
/// that has no dedicated specialization: it reports no endpoint information.
///
/// \return Always an empty string.
template <class T>
std::string ClientConnection<T>::RemoteEndpointInfo() {
  return std::string{};
}

/// TCP specialization: describe the remote peer as "<address>:<port>".
///
/// This is called while reporting a message with a bad cookie, so it must not
/// itself bring down the process. The throwing overload of remote_endpoint()
/// raises boost::system::system_error when the endpoint cannot be obtained
/// (for example because the peer has already disconnected), so the error_code
/// overload is used instead.
///
/// \return "<address>:<port>" of the remote endpoint, or an empty string if the
/// endpoint cannot be determined.
template <>
std::string ClientConnection<boost::asio::ip::tcp>::RemoteEndpointInfo() {
  boost::system::error_code error;
  const auto remote_endpoint =
      ServerConnection<boost::asio::ip::tcp>::socket_.remote_endpoint(error);
  if (error) {
    // Callers treat an empty string as "no endpoint information available".
    return std::string();
  }
  return remote_endpoint.address().to_string() + ":" +
         std::to_string(remote_endpoint.port());
}

template <class T>
void ClientConnection<T>::ProcessMessage(const boost::system::error_code &error) {
if (error) {
Expand Down
15 changes: 13 additions & 2 deletions src/ray/common/client_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ServerConnection : public std::enable_shared_from_this<ServerConnection<T>

/// A message that is queued for writing asynchronously.
struct AsyncWriteBuffer {
int64_t write_version;
int64_t write_cookie;
int64_t write_type;
uint64_t write_length;
std::vector<uint8_t> write_message;
Expand Down Expand Up @@ -184,6 +184,17 @@ class ClientConnection : public ServerConnection<T> {
/// Process an error from reading the message header, then process the
/// message from the client.
void ProcessMessage(const boost::system::error_code &error);
/// Check if the ray cookie in a received message is correct. Note, if the cookie
/// is wrong and the remote endpoint is known, raylet process will crash. If the remote
/// endpoint is unknown, this method will only print a warning.
///
/// \return If the cookie is correct.
bool CheckRayCookie();
/// Return information about IP and port for the remote endpoint. For local connection
/// this returns an empty string.
///
/// \return Information of remote endpoint.
std::string RemoteEndpointInfo();

/// The ClientID of the remote client.
ClientID client_id_;
Expand All @@ -197,7 +208,7 @@ class ClientConnection : public ServerConnection<T> {
/// The value for disconnect client message.
int64_t error_message_type_;
/// Buffers for the current message being read from the client.
int64_t read_version_;
int64_t read_cookie_;
int64_t read_type_;
uint64_t read_length_;
std::vector<uint8_t> read_message_;
Expand Down
8 changes: 6 additions & 2 deletions src/ray/ray_config_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
// 1. You must update the file "ray/python/ray/includes/ray_config.pxd".
// 2. You must update the file "ray/python/ray/includes/ray_config.pxi".

/// In theory, this is used to detect Ray version mismatches.
RAY_CONFIG(int64_t, ray_protocol_version, 0x0000000000000000);
/// A magic cookie written at the start of every message header and checked by
/// the receiver, so that data which does not come from a Ray process can be
/// detected. The value carries the ASCII bytes "RAY" (0x52 0x41 0x59) in its
/// three most significant bytes. It is used instead of zero because a random
/// program could plausibly send an int64_t that is zero, but is much less
/// likely to send this particular magic number.
RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000);

/// The duration that a single handler on the event loop can take before a
/// warning is logged that the handler is taking too long.
Expand Down
43 changes: 43 additions & 0 deletions src/ray/raylet/client_connection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ class ClientConnectionTest : public ::testing::Test {
boost::asio::local::connect_pair(in_, out_);
}

/// Write a message whose header carries a deliberately wrong ray cookie.
///
/// \param conn The connection to write the message to.
/// \param type The message type to put in the header.
/// \param length The size of the message body in bytes.
/// \param message Pointer to the message body.
/// \return The status returned by writing the buffers to the connection.
ray::Status WriteBadMessage(std::shared_ptr<ray::LocalClientConnection> conn,
                            int64_t type, int64_t length, const uint8_t *message) {
  std::vector<boost::asio::const_buffer> message_buffers;
  // The cookie must be a full int64_t so that the header has the same layout
  // as a regular message (cookie, type, length). With `auto` the literal would
  // be deduced as `int` and the remaining header fields would be misaligned.
  const int64_t write_cookie = 123456;  // incorrect cookie.
  message_buffers.push_back(boost::asio::buffer(&write_cookie, sizeof(write_cookie)));
  message_buffers.push_back(boost::asio::buffer(&type, sizeof(type)));
  message_buffers.push_back(boost::asio::buffer(&length, sizeof(length)));
  message_buffers.push_back(boost::asio::buffer(message, length));
  return conn->WriteBuffer(message_buffers);
}

protected:
boost::asio::io_service io_service_;
boost::asio::local::stream_protocol::socket in_;
Expand Down Expand Up @@ -147,6 +158,38 @@ TEST_F(ClientConnectionTest, CallbackWithSharedRefDoesNotLeakConnection) {
io_service_.run();
}

// A message with a wrong cookie, received on a connection that has no client
// ID set, must be dropped: the message handler is never invoked and the
// process does not crash.
TEST_F(ClientConnectionTest, ProcessBadMessage) {
  const uint8_t payload[5] = {1, 2, 3, 4, 5};
  int handled_count = 0;

  ClientHandler<boost::asio::local::stream_protocol> on_client =
      [](LocalClientConnection &) {};

  MessageHandler<boost::asio::local::stream_protocol> on_message =
      [&payload, &handled_count](std::shared_ptr<LocalClientConnection> /*client*/,
                                 int64_t /*message_type*/, const uint8_t *data) {
        ASSERT_TRUE(!std::memcmp(payload, data, 5));
        ++handled_count;
      };

  auto sender = LocalClientConnection::Create(on_client, on_message, std::move(in_),
                                              "writer", {}, error_message_type_);

  auto receiver = LocalClientConnection::Create(on_client, on_message, std::move(out_),
                                                "reader", {}, error_message_type_);

  // If client ID is set, bad message would crash the test.
  // receiver->SetClientID(UniqueID::from_random());

  // Send a message with an incorrect cookie, then let the receiver process it.
  // As long as the client ID is not set this must not crash.
  RAY_CHECK_OK(WriteBadMessage(sender, 0, 5, payload));
  receiver->ProcessMessages();
  io_service_.run();

  // The bad message must never have reached the handler.
  ASSERT_EQ(handled_count, 0);
}

} // namespace raylet

} // namespace ray
Expand Down
7 changes: 7 additions & 0 deletions src/ray/raylet/format/node_manager.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ enum MessageType:int {
PushProfileEventsRequest,
// Free the objects in objects store.
FreeObjectsInObjectStoreRequest,
// A node manager requests to connect to another node manager.
ConnectClient,
}

table TaskExecutionSpecification {
Expand Down Expand Up @@ -204,3 +206,8 @@ table FreeObjectsRequest {
// List of object ids we'll delete from object store.
object_ids: [string];
}

// Message a node manager sends after it opens a connection to another node
// manager. The receiving node manager sets the carried ID as the client ID of
// that connection.
table ConnectClient {
// ID of the connecting client.
client_id: string;
}
48 changes: 36 additions & 12 deletions src/ray/raylet/node_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
std::shared_ptr<gcs::AsyncGcsClient> gcs_client,
std::shared_ptr<ObjectDirectoryInterface> object_directory,
plasma::PlasmaClient &store_client)
: io_service_(io_service),
: client_id_(gcs_client->client_table().GetLocalClientId()),
io_service_(io_service),
object_manager_(object_manager),
store_client_(store_client),
gcs_client_(std::move(gcs_client)),
Expand Down Expand Up @@ -338,13 +339,8 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
}

// Establish a new NodeManager connection to this GCS client.
RAY_LOG(DEBUG) << "[ClientAdded] Trying to connect to client " << client_id << " at "
<< client_data.node_manager_address << ":"
<< client_data.node_manager_port;

boost::asio::ip::tcp::socket socket(io_service_);
auto status =
TcpConnect(socket, client_data.node_manager_address, client_data.node_manager_port);
auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address,
client_data.node_manager_port);
// A disconnected client has 2 entries in the client table (one for being
// inserted and one for being removed). When a new raylet starts, ClientAdded
// will be called with the disconnected client's first entry, which will cause
Expand All @@ -357,15 +353,38 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
return;
}

// The client is connected.
auto server_conn = TcpServerConnection::Create(std::move(socket));
remote_server_connections_.emplace(client_id, std::move(server_conn));

ResourceSet resources_total(client_data.resources_total_label,
client_data.resources_total_capacity);
cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total));
}

/// Open a TCP connection to a remote node manager, announce this node's client
/// ID to it with a ConnectClient message, and remember the connection.
///
/// \param client_id The client ID of the remote node manager.
/// \param client_address The address of the remote node manager.
/// \param client_port The port of the remote node manager.
/// \return OK if the connection was established and the ConnectClient message
/// was written; otherwise the status of the step that failed.
ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id,
                                                  const std::string &client_address,
                                                  int32_t client_port) {
  RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at "
                << client_address << ":" << client_port;

  // Establish the TCP connection to the remote node manager.
  boost::asio::ip::tcp::socket tcp_socket(io_service_);
  RAY_RETURN_NOT_OK(TcpConnect(tcp_socket, client_address, client_port));
  auto connection = TcpServerConnection::Create(std::move(tcp_socket));

  // Serialize a ConnectClient message that carries this node's own client ID.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(protocol::CreateConnectClient(builder, to_flatbuf(builder, client_id_)));

  // Send the message synchronously.
  // TODO(swang): Make this a WriteMessageAsync.
  const auto message_type = static_cast<int64_t>(protocol::MessageType::ConnectClient);
  RAY_RETURN_NOT_OK(connection->WriteMessage(message_type, builder.GetSize(),
                                             builder.GetBufferPointer()));

  // Keep the connection only once the remote side has been told who we are.
  remote_server_connections_.emplace(client_id, std::move(connection));
  return ray::Status::OK();
}

void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// TODO(swang): If we receive a notification for our own death, clean up and
// exit immediately.
Expand Down Expand Up @@ -1007,6 +1026,11 @@ void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_cl
<< protocol::EnumNameMessageType(message_type_value) << "("
<< message_type << ") from node manager";
switch (message_type_value) {
case protocol::MessageType::ConnectClient: {
auto message = flatbuffers::GetRoot<protocol::ConnectClient>(message_data);
auto client_id = from_flatbuf(*message->client_id());
node_manager_client.SetClientID(client_id);
} break;
case protocol::MessageType::ForwardTaskRequest: {
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
TaskID task_id = from_flatbuf(*message->task_id());
Expand Down
12 changes: 12 additions & 0 deletions src/ray/raylet/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,18 @@ class NodeManager {
void HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
bool intentional_disconnect);

/// Connect to a remote node manager and send it a ConnectClient message
/// carrying this node's client ID.
///
/// \param client_id The client ID for the remote node manager.
/// \param client_address The IP address for the remote node manager.
/// \param client_port The listening port for the remote node manager.
/// \return Status::OK if the connection was established and the ConnectClient
/// message was written; otherwise the error status of the step that failed.
ray::Status ConnectRemoteNodeManager(const ClientID &client_id,
const std::string &client_address,
int32_t client_port);

// GCS client ID for this node.
ClientID client_id_;
boost::asio::io_service &io_service_;
ObjectManager &object_manager_;
/// A Plasma object store client. This is used exclusively for creating new
Expand Down
10 changes: 5 additions & 5 deletions src/ray/raylet/raylet_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,12 @@ ray::Status RayletConnection::Disconnect() {

ray::Status RayletConnection::ReadMessage(MessageType type,
std::unique_ptr<uint8_t[]> &message) {
int64_t version;
int64_t cookie;
int64_t type_field;
int64_t length;
int closed = read_bytes(conn_, (uint8_t *)&version, sizeof(version));
int closed = read_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie));
if (closed) goto disconnected;
RAY_CHECK(version == RayConfig::instance().ray_protocol_version());
RAY_CHECK(cookie == RayConfig::instance().ray_cookie());
closed = read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field));
if (closed) goto disconnected;
closed = read_bytes(conn_, (uint8_t *)&length, sizeof(length));
Expand Down Expand Up @@ -175,13 +175,13 @@ ray::Status RayletConnection::ReadMessage(MessageType type,
ray::Status RayletConnection::WriteMessage(MessageType type,
flatbuffers::FlatBufferBuilder *fbb) {
std::unique_lock<std::mutex> guard(write_mutex_);
int64_t version = RayConfig::instance().ray_protocol_version();
int64_t cookie = RayConfig::instance().ray_cookie();
int64_t length = fbb ? fbb->GetSize() : 0;
uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr;
int64_t type_field = static_cast<int64_t>(type);
auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly.");
int closed;
closed = write_bytes(conn_, (uint8_t *)&version, sizeof(version));
closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie));
if (closed) return io_error;
closed = write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field));
if (closed) return io_error;
Expand Down
Loading