From 29fa5cddefd567b64f35a1617eb79783e6ebd659 Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Wed, 26 Jun 2024 10:06:26 -0700 Subject: [PATCH] [Core] ray list tasks filter state and name on gcs side (#46270) Signed-off-by: Jiajun Yao --- python/ray/tests/test_state_api.py | 32 +++++--- python/ray/util/state/common.py | 39 +++++----- python/ray/util/state/state_manager.py | 4 + src/ray/gcs/gcs_server/gcs_task_manager.cc | 30 +++++++- .../gcs_server/test/gcs_task_manager_test.cc | 76 ++++++++++++++++--- src/ray/gcs/pb_util.h | 39 ++-------- src/ray/protobuf/gcs.proto | 15 +--- src/ray/protobuf/gcs_service.proto | 2 + 8 files changed, 150 insertions(+), 87 deletions(-) diff --git a/python/ray/tests/test_state_api.py b/python/ray/tests/test_state_api.py index eecd3bcd7eb9..d687199d024a 100644 --- a/python/ray/tests/test_state_api.py +++ b/python/ray/tests/test_state_api.py @@ -250,8 +250,8 @@ def generate_task_event( ) state_updates = TaskStateUpdate( node_id=node_id, + state_ts={state: 1}, ) - setattr(state_updates, TaskStatus.Name(state).lower() + "_ts", 1) return TaskEvents( task_id=id, job_id=job_id, @@ -1007,10 +1007,12 @@ async def test_api_manager_list_tasks_events(state_api_manager): second = int(1e9) state_updates = TaskStateUpdate( node_id=node_id.binary(), - pending_args_avail_ts=current, - submitted_to_worker_ts=current + second, - running_ts=current + (2 * second), - finished_ts=current + (3 * second), + state_ts={ + TaskStatus.PENDING_ARGS_AVAIL: current, + TaskStatus.SUBMITTED_TO_WORKER: current + second, + TaskStatus.RUNNING: current + (2 * second), + TaskStatus.FINISHED: current + (3 * second), + }, ) """ @@ -1056,9 +1058,11 @@ async def test_api_manager_list_tasks_events(state_api_manager): """ state_updates = TaskStateUpdate( node_id=node_id.binary(), - pending_args_avail_ts=current, - submitted_to_worker_ts=current + second, - running_ts=current + (2 * second), + state_ts={ + TaskStatus.PENDING_ARGS_AVAIL: current, + TaskStatus.SUBMITTED_TO_WORKER: current + second, + TaskStatus.RUNNING: current + (2 * second), + }, ) events = TaskEvents( task_id=id, @@ -1077,8 +1081,10 @@ async def test_api_manager_list_tasks_events(state_api_manager): Test None of start & end time is updated. """ state_updates = TaskStateUpdate( - pending_args_avail_ts=current, - submitted_to_worker_ts=current + second, + state_ts={ + TaskStatus.PENDING_ARGS_AVAIL: current, + TaskStatus.SUBMITTED_TO_WORKER: current + second, + }, ) events = TaskEvents( task_id=id, @@ -2424,7 +2430,11 @@ def verify(): for task in tasks: assert task["job_id"] == job_id - tasks = list_tasks(filters=[("name", "=", "f_0")]) + tasks = list_tasks(filters=[("name", "=", "f_0")], limit=1) + assert len(tasks) == 1 + + # using limit to make sure state filtering is done on the gcs side + tasks = list_tasks(filters=[("STATE", "=", "PENDING_ARGS_AVAIL")], limit=1) assert len(tasks) == 1 return True diff --git a/python/ray/util/state/common.py b/python/ray/util/state/common.py index 9dbb56286c61..898e4caeda7e 100644 --- a/python/ray/util/state/common.py +++ b/python/ray/util/state/common.py @@ -1583,24 +1583,27 @@ def protobuf_to_task_state_dict(message: TaskEvents) -> dict: task_state["end_time_ms"] = None events = [] - for state in TaskStatus.keys(): - key = f"{state.lower()}_ts" - if key in state_updates: - # timestamp is recorded as nanosecond from the backend. - # We need to convert it to the second. - ts_ms = int(state_updates[key]) // 1e6 - events.append( - { - "state": state, - "created_ms": ts_ms, - } - ) - if state == "PENDING_ARGS_AVAIL": - task_state["creation_time_ms"] = ts_ms - if state == "RUNNING": - task_state["start_time_ms"] = ts_ms - if state == "FINISHED" or state == "FAILED": - task_state["end_time_ms"] = ts_ms + if "state_ts" in state_updates: + state_ts = state_updates["state_ts"] + for state_name, state in TaskStatus.items(): + # state_ts is Map[str, str] after protobuf MessageToDict + key = str(state) + if key in state_ts: + # timestamp is recorded as nanosecond from the backend. + # We need to convert it to the second. + ts_ms = int(state_ts[key]) // 1e6 + events.append( + { + "state": state_name, + "created_ms": ts_ms, + } + ) + if state == TaskStatus.PENDING_ARGS_AVAIL: + task_state["creation_time_ms"] = ts_ms + if state == TaskStatus.RUNNING: + task_state["start_time_ms"] = ts_ms + if state == TaskStatus.FINISHED or state == TaskStatus.FAILED: + task_state["end_time_ms"] = ts_ms task_state["events"] = events if len(events) > 0: diff --git a/python/ray/util/state/state_manager.py b/python/ray/util/state/state_manager.py index 6abd7ea51b85..8bb192a1e12d 100644 --- a/python/ray/util/state/state_manager.py +++ b/python/ray/util/state/state_manager.py @@ -307,6 +307,10 @@ async def get_all_task_info( req_filters.job_id = JobID(hex_to_binary(value)).binary() elif key == "task_id": req_filters.task_ids.append(TaskID(hex_to_binary(value)).binary()) + elif key == "name": + req_filters.name = value + elif key == "state": + req_filters.state = value else: continue diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.cc b/src/ray/gcs/gcs_server/gcs_task_manager.cc index 69904f252c3e..5986f0897cd9 100644 --- a/src/ray/gcs/gcs_server/gcs_task_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_task_manager.cc @@ -14,6 +14,7 @@ #include "ray/gcs/gcs_server/gcs_task_manager.h" +#include "absl/strings/match.h" #include "ray/common/ray_config.h" #include "ray/common/status.h" @@ -112,7 +113,7 @@ void GcsTaskManager::GcsTaskManagerStorage::MarkTaskAttemptFailedIfNeeded( // We could mark the task as failed even if might not have state updates yet (i.e. only // profiling events are reported). auto state_updates = task_events.mutable_state_updates(); - state_updates->set_failed_ts(failed_ts); + (*state_updates->mutable_state_ts())[ray::rpc::TaskStatus::FAILED] = failed_ts; state_updates->mutable_error_info()->CopyFrom(error_info); } @@ -419,10 +420,35 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request, return false; } - if (filters.has_name() && task_event.task_info().name() != filters.name()) { + if (filters.has_name() && + !absl::EqualsIgnoreCase(task_event.task_info().name(), filters.name())) { return false; } + if (filters.has_state()) { + const google::protobuf::EnumDescriptor *task_status_descriptor = + ray::rpc::TaskStatus_descriptor(); + + // Figure out the latest state of a task. + ray::rpc::TaskStatus state = ray::rpc::TaskStatus::NIL; + if (task_event.has_state_updates()) { + for (int i = task_status_descriptor->value_count() - 1; i >= 0; --i) { + if (task_event.state_updates().state_ts().contains( + task_status_descriptor->value(i)->number())) { + state = static_cast( + task_status_descriptor->value(i)->number()); + break; + } + } + } + + if (!absl::EqualsIgnoreCase( + filters.state(), + task_status_descriptor->FindValueByNumber(state)->name())) { + return false; + } + } + return true; }; diff --git a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc index 865c3547bd82..fca9d1be4a5e 100644 --- a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc @@ -127,7 +127,8 @@ class GcsTaskManagerTest : public ::testing::Test { int64_t limit = -1, bool exclude_driver = true, const std::string &name = "", - const ActorID &actor_id = ActorID::Nil()) { + const ActorID &actor_id = ActorID::Nil(), + const std::string &state = "") { rpc::GetTaskEventsRequest request; rpc::GetTaskEventsReply reply; std::promise promise; @@ -142,6 +143,10 @@ class GcsTaskManagerTest : public ::testing::Test { request.mutable_filters()->set_name(name); } + if (!state.empty()) { + request.mutable_filters()->set_state(state); + } + if (!actor_id.IsNil()) { request.mutable_filters()->set_actor_id(actor_id.Binary()); } @@ -576,6 +581,24 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) { SyncAddTaskEventData(data); } + // A task event with state transitions. + { + auto task_ids = GenTaskIDs(1); + auto task_info = GenTaskInfo(JobID::FromInt(1), TaskID::Nil(), rpc::NORMAL_TASK); + auto events = + GenTaskEvents(task_ids, + /* attempt_number */ + 0, + /* job_id */ 1, + absl::nullopt, + GenStateUpdate({{rpc::TaskStatus::PENDING_NODE_ASSIGNMENT, 1}, + {rpc::TaskStatus::RUNNING, 5}}, + WorkerID::Nil()), + task_info); + auto data = Mocker::GenTaskEventsData(events); + SyncAddTaskEventData(data); + } + auto reply_name = SyncGetTaskEvents({}, /* job_id */ absl::nullopt, /* limit */ -1, @@ -598,6 +621,33 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) { "task_name", actor_id); EXPECT_EQ(reply_both_and.events_by_task_size(), 0); + + auto reply_state = SyncGetTaskEvents({}, + /* job_id */ absl::nullopt, + /* limit */ -1, + /* exclude_driver */ false, + /* name */ "", + ActorID::Nil(), + "RUnnING"); + EXPECT_EQ(reply_state.events_by_task_size(), 1); + + reply_state = SyncGetTaskEvents({}, + /* job_id */ absl::nullopt, + /* limit */ -1, + /* exclude_driver */ false, + /* name */ "", + ActorID::Nil(), + "NIL"); + EXPECT_EQ(reply_state.events_by_task_size(), 2); + + reply_state = SyncGetTaskEvents({}, + /* job_id */ absl::nullopt, + /* limit */ -1, + /* exclude_driver */ false, + /* name */ "", + ActorID::Nil(), + "PENDING_NODE_ASSIGNMENT"); + EXPECT_EQ(reply_state.events_by_task_size(), 0); } TEST_F(GcsTaskManagerTest, TestMarkTaskAttemptFailedIfNeeded) { @@ -623,22 +673,22 @@ TEST_F(GcsTaskManagerTest, TestMarkTaskAttemptFailedIfNeeded) { { auto reply = SyncGetTaskEvents({tasks_running}); auto task_event = *(reply.events_by_task().begin()); - EXPECT_EQ(task_event.state_updates().failed_ts(), 4); + EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), 4); } // Check task attempt failed event is not overriding failed tasks. { auto reply = SyncGetTaskEvents({tasks_failed}); auto task_event = *(reply.events_by_task().begin()); - EXPECT_EQ(task_event.state_updates().failed_ts(), 3); + EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), 3); } // Check task attempt failed event is not overriding finished tasks. { auto reply = SyncGetTaskEvents({tasks_finished}); auto task_event = *(reply.events_by_task().begin()); - EXPECT_FALSE(task_event.state_updates().has_failed_ts()); - EXPECT_EQ(task_event.state_updates().finished_ts(), 2); + EXPECT_FALSE(task_event.state_updates().state_ts().contains(rpc::TaskStatus::FAILED)); + EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FINISHED), 2); } } @@ -690,7 +740,8 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) { auto reply = SyncGetTaskEvents(tasks); EXPECT_EQ(reply.events_by_task_size(), 10); for (const auto &task_event : reply.events_by_task()) { - EXPECT_EQ(task_event.state_updates().failed_ts(), /* 5 ms to ns */ 5 * 1000 * 1000); + EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), + /* 5 ms to ns */ 5 * 1000 * 1000); EXPECT_TRUE(task_event.state_updates().has_error_info()); EXPECT_TRUE(task_event.state_updates().error_info().error_type() == rpc::ErrorType::WORKER_DIED); @@ -706,8 +757,9 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) { auto reply = SyncGetTaskEvents(tasks); EXPECT_EQ(reply.events_by_task_size(), 10); for (const auto &task_event : reply.events_by_task()) { - EXPECT_EQ(task_event.state_updates().finished_ts(), 2); - EXPECT_FALSE(task_event.state_updates().has_failed_ts()); + EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FINISHED), 2); + EXPECT_FALSE( + task_event.state_updates().state_ts().contains(rpc::TaskStatus::FAILED)); } } @@ -717,7 +769,7 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) { auto reply = SyncGetTaskEvents(tasks); EXPECT_EQ(reply.events_by_task_size(), 10); for (const auto &task_event : reply.events_by_task()) { - EXPECT_EQ(task_event.state_updates().failed_ts(), 3); + EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), 3); } } @@ -728,8 +780,10 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) { auto reply = SyncGetTaskEvents(tasks); EXPECT_EQ(reply.events_by_task_size(), 5); for (const auto &task_event : reply.events_by_task()) { - EXPECT_FALSE(task_event.state_updates().has_failed_ts()); - EXPECT_FALSE(task_event.state_updates().has_finished_ts()); + EXPECT_FALSE( + task_event.state_updates().state_ts().contains(rpc::TaskStatus::FAILED)); + EXPECT_FALSE( + task_event.state_updates().state_ts().contains(rpc::TaskStatus::FINISHED)); } } } diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h index a329dcb9319b..f8c24cf2d48f 100644 --- a/src/ray/gcs/pb_util.h +++ b/src/ray/gcs/pb_util.h @@ -277,7 +277,8 @@ inline bool IsTaskTerminated(const rpc::TaskEvents &task_event) { } const auto &state_updates = task_event.state_updates(); - return state_updates.has_finished_ts() || state_updates.has_failed_ts(); + return state_updates.state_ts().contains(rpc::TaskStatus::FINISHED) || + state_updates.state_ts().contains(rpc::TaskStatus::FAILED); } inline size_t NumProfileEvents(const rpc::TaskEvents &task_event) { @@ -308,7 +309,7 @@ inline bool IsTaskFinished(const rpc::TaskEvents &task_event) { } const auto &state_updates = task_event.state_updates(); - return state_updates.has_finished_ts(); + return state_updates.state_ts().contains(rpc::TaskStatus::FINISHED); } /// Fill the rpc::TaskStateUpdate with the timestamps according to the status change. @@ -319,39 +320,11 @@ inline bool IsTaskFinished(const rpc::TaskEvents &task_event) { inline void FillTaskStatusUpdateTime(const ray::rpc::TaskStatus &task_status, int64_t timestamp, ray::rpc::TaskStateUpdate *state_updates) { - switch (task_status) { - case rpc::TaskStatus::PENDING_ARGS_AVAIL: { - state_updates->set_pending_args_avail_ts(timestamp); - break; - } - case rpc::TaskStatus::SUBMITTED_TO_WORKER: { - state_updates->set_submitted_to_worker_ts(timestamp); - break; - } - case rpc::TaskStatus::PENDING_NODE_ASSIGNMENT: { - state_updates->set_pending_node_assignment_ts(timestamp); - break; - } - case rpc::TaskStatus::FINISHED: { - state_updates->set_finished_ts(timestamp); - break; - } - case rpc::TaskStatus::FAILED: { - state_updates->set_failed_ts(timestamp); - break; - } - case rpc::TaskStatus::RUNNING: { - state_updates->set_running_ts(timestamp); - break; - } - case rpc::TaskStatus::NIL: { + if (task_status == rpc::TaskStatus::NIL) { // Not status change. - break; - } - default: { - UNREACHABLE; - } + return; } + (*state_updates->mutable_state_ts())[task_status] = timestamp; } inline std::string FormatPlacementGroupLabelName(const std::string &pg_id) { diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index ce11e822d26b..3a02bce9642c 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -216,18 +216,6 @@ message TaskLogInfo { message TaskStateUpdate { // Node that runs the task. optional bytes node_id = 1; - // Timestamp when status changes to PENDING_ARGS_AVAIL. - optional int64 pending_args_avail_ts = 2; - // Timestamp when status changes to PENDING_NODE_ASSIGNMENT. - optional int64 pending_node_assignment_ts = 3; - // Timestamp when status changes to SUBMITTED_TO_WORKER. - optional int64 submitted_to_worker_ts = 4; - // Timestamp when status changes to RUNNING. - optional int64 running_ts = 5; - // Timestamp when status changes to FINISHED. - optional int64 finished_ts = 6; - // Timestamp when status changes to FAILED. - optional int64 failed_ts = 7; // Worker that runs the task. optional bytes worker_id = 8; // Task faulure info. @@ -240,6 +228,9 @@ message TaskStateUpdate { optional int32 worker_pid = 12; // Is task paused by debugger. optional bool is_debugger_paused = 13; + // Key is the integer value of TaskStatus enum (protobuf doesn't support Enum as key). + // Value is the timestamp when status changes to the target status indicated by the key. + map state_ts = 14; } // Represents events and state changes from a single task run. diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index dfe625b99ef8..98baef63afa7 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -719,6 +719,8 @@ message GetTaskEventsRequest { optional string name = 4; // True if task events from driver (only profiling events) should be excluded. optional bool exclude_driver = 5; + // Latest state of the task. + optional string state = 6; } // Maximum number of TaskEvents to return.