[core] Add 1s timeout in RPC to CoreWorkerService.NumPendingTasks in GcsJobManager::HandleGetAllJobInfo (ray-project#46335)

Signed-off-by: Ruiyang Wang <rywang014@gmail.com>
rynewang authored Jul 11, 2024
1 parent 170d108 commit 26b9464
Showing 5 changed files with 53 additions and 23 deletions.
28 changes: 25 additions & 3 deletions python/ray/tests/test_state_api.py
@@ -3598,12 +3598,34 @@ def f(signal):
all_job_info = client.get_all_job_info()
assert len(all_job_info) == 1
assert job_id in all_job_info
assert client.get_all_job_info()[job_id].is_running_tasks is True
assert all_job_info[job_id].is_running_tasks is True


if __name__ == "__main__":
import sys
def test_hang_driver_has_no_is_running_task(monkeypatch, ray_start_cluster):
    """
    When JobInfoGcsService.GetAllJobInfo is called, GCS sends a
    CoreWorkerService.NumPendingTasks RPC to every driver to populate
    "is_running_tasks". Our driver is too slow to serve that RPC, so GCS should
    time out the RPC and leave the field unset.
    """
    cluster = ray_start_cluster
    cluster.add_node(num_cpus=10)
    address = cluster.address

    monkeypatch.setenv(
        "RAY_testing_asio_delay_us",
        "CoreWorkerService.grpc_server.NumPendingTasks=2000000:2000000",
    )
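    # The injected delay above stalls the driver's NumPendingTasks handler for 2 s,
    # well past the 1 s timeout GCS now applies, so the RPC times out and
    # is_running_tasks stays unset.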
    ray.init(address=address)

    client = ray.worker.global_worker.gcs_client
    my_job_id = ray.worker.global_worker.current_job_id
    all_job_info = client.get_all_job_info()
    assert list(all_job_info.keys()) == [my_job_id]
    assert not all_job_info[my_job_id].HasField("is_running_tasks")


if __name__ == "__main__":
    if os.environ.get("PARALLEL_CI"):
        sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
    else:
6 changes: 4 additions & 2 deletions src/mock/ray/rpc/worker/core_worker_client.h
@@ -32,7 +32,8 @@ class MockCoreWorkerClientInterface : public ray::pubsub::MockSubscriberClientIn
MOCK_METHOD(void,
NumPendingTasks,
(std::unique_ptr<NumPendingTasksRequest> request,
const ClientCallback<NumPendingTasksReply> &callback),
const ClientCallback<NumPendingTasksReply> &callback,
int64_t timeout_ms),
(override));
MOCK_METHOD(void,
DirectActorCallArgWaitComplete,
@@ -133,7 +134,8 @@ class MockCoreWorkerClientConfigurableRunningTasks
: num_running_tasks_(num_running_tasks) {}

void NumPendingTasks(std::unique_ptr<NumPendingTasksRequest> request,
const ClientCallback<NumPendingTasksReply> &callback) override {
const ClientCallback<NumPendingTasksReply> &callback,
int64_t timeout_ms = -1) override {
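// The mock ignores timeout_ms and invokes the callback synchronously with the
// configured count.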
NumPendingTasksReply reply;
reply.set_num_pending_tasks(num_running_tasks_);
callback(Status::OK(), reply);
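For reference, a hedged sketch of calling the updated three-parameter NumPendingTasks through the mock above; the include path, namespace, and exact constructor signature are assumptions rather than text from this diff:

#include <cassert>
#include <memory>
#include "mock/ray/rpc/worker/core_worker_client.h"  // assumed include path

void ExerciseMock() {
  // Mock configured to report three running tasks.
  ray::rpc::MockCoreWorkerClientConfigurableRunningTasks client(3);
  client.NumPendingTasks(
      std::make_unique<ray::rpc::NumPendingTasksRequest>(),
      [](const ray::Status &status, const ray::rpc::NumPendingTasksReply &reply) {
        // The mock replies synchronously, so the callback runs before this call returns.
        assert(status.ok());
        assert(reply.num_pending_tasks() == 3);
      },
      /*timeout_ms=*/1000);  // parameter added by this commit; defaults to -1 (no timeout)
}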
25 changes: 16 additions & 9 deletions src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -204,6 +204,7 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
job_data_key_to_indices[job_data_key].push_back(i);
}

JobID job_id = data.first;
WorkerID worker_id = WorkerID::FromBinary(data.second.driver_address().worker_id());

// If job is not dead, get is_running_tasks from the core worker for the driver.
@@ -217,23 +218,29 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
// Get is_running_tasks from the core worker for the driver.
auto client = core_worker_clients_.GetOrConnect(data.second.driver_address());
auto request = std::make_unique<rpc::NumPendingTasksRequest>();
RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id;
constexpr int64_t kNumPendingTasksRequestTimeoutMs = 1000;
RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id
<< ", timeout " << kNumPendingTasksRequestTimeoutMs << " ms.";
client->NumPendingTasks(
std::move(request),
[worker_id, reply, i, num_processed_jobs, try_send_reply](
[job_id, worker_id, reply, i, num_processed_jobs, try_send_reply](
const Status &status,
const rpc::NumPendingTasksReply &num_pending_tasks_reply) {
RAY_LOG(DEBUG) << "Received NumPendingTasksReply from worker " << worker_id;
RAY_LOG(DEBUG).WithField(worker_id)
<< "Received NumPendingTasksReply from worker.";
if (!status.ok()) {
RAY_LOG(WARNING) << "Failed to get is_running_tasks from core worker: "
<< status.ToString();
RAY_LOG(WARNING).WithField(job_id).WithField(worker_id)
<< "Failed to get num_pending_tasks from core worker: " << status
<< ", is_running_tasks is unset.";
reply->mutable_job_info_list(i)->clear_is_running_tasks();
} else {
bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0;
reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks);
}
bool is_running_tasks = num_pending_tasks_reply.num_pending_tasks() > 0;
reply->mutable_job_info_list(i)->set_is_running_tasks(is_running_tasks);
(*num_processed_jobs)++;
;
try_send_reply();
});
},
kNumPendingTasksRequestTimeoutMs);
}
i++;
}
3 changes: 2 additions & 1 deletion src/ray/protobuf/gcs.proto
@@ -703,7 +703,8 @@ message JobTableData {
// The optional JobInfo from the Ray Job API.
optional JobsAPIInfo job_info = 10;
// Whether this job has running tasks.
bool is_running_tasks = 11;
// In GetAllJobInfo, if GCS cannot reach the driver within the timeout, this field is left unset.
optional bool is_running_tasks = 11;
// Address of the driver that started this job.
Address driver_address = 12;
}
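Because the field is now declared proto3 optional, a reader of JobTableData can tell "driver unreachable" (unset) apart from "no running tasks" (false). A minimal consumer-side sketch, assuming the generated C++ bindings for ray.rpc.JobTableData; the include path and accessor names follow protobuf's usual generation rules and are not part of this diff:

#include <optional>
// Generated from src/ray/protobuf/gcs.proto; the exact include path may differ per build.
#include "src/ray/protobuf/gcs.pb.h"

// Returns nullopt when GCS could not reach the driver within the 1 s timeout,
// otherwise whether the job still has running tasks.
std::optional<bool> GetIsRunningTasks(const ray::rpc::JobTableData &job) {
  if (!job.has_is_running_tasks()) {
    return std::nullopt;
  }
  return job.is_running_tasks();
}

The new Python test above relies on the same presence semantics through HasField("is_running_tasks").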
14 changes: 6 additions & 8 deletions src/ray/rpc/worker/core_worker_client.h
@@ -106,7 +106,8 @@ class CoreWorkerClientInterface : public pubsub::SubscriberClientInterface {
/// \param[in] callback The callback function that handles reply.
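/// \param[in] timeout_ms RPC timeout in milliseconds; -1 (the default) means no timeout.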
/// \return if the rpc call succeeds
virtual void NumPendingTasks(std::unique_ptr<NumPendingTasksRequest> request,
const ClientCallback<NumPendingTasksReply> &callback) {}
const ClientCallback<NumPendingTasksReply> &callback,
int64_t timeout_ms = -1) {}

/// Notify a wait has completed for direct actor call arguments.
///
@@ -392,13 +393,10 @@ class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
}

void NumPendingTasks(std::unique_ptr<NumPendingTasksRequest> request,
const ClientCallback<NumPendingTasksReply> &callback) override {
INVOKE_RPC_CALL(CoreWorkerService,
NumPendingTasks,
*request,
callback,
grpc_client_,
/*method_timeout_ms*/ -1);
const ClientCallback<NumPendingTasksReply> &callback,
int64_t timeout_ms = -1) override {
INVOKE_RPC_CALL(
CoreWorkerService, NumPendingTasks, *request, callback, grpc_client_, timeout_ms);
}

/// Send as many pending tasks as possible. This method is thread-safe.
