Skip to content

Commit

Permalink
[Core] ray list tasks filter state and name on gcs side (ray-project#46270)
Browse files Browse the repository at this point in the history

Signed-off-by: Jiajun Yao <jeromeyjj@gmail.com>
  • Loading branch information
jjyao authored Jun 26, 2024
1 parent f21b7f8 commit 29fa5cd
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 87 deletions.
32 changes: 21 additions & 11 deletions python/ray/tests/test_state_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ def generate_task_event(
)
state_updates = TaskStateUpdate(
node_id=node_id,
state_ts={state: 1},
)
setattr(state_updates, TaskStatus.Name(state).lower() + "_ts", 1)
return TaskEvents(
task_id=id,
job_id=job_id,
Expand Down Expand Up @@ -1007,10 +1007,12 @@ async def test_api_manager_list_tasks_events(state_api_manager):
second = int(1e9)
state_updates = TaskStateUpdate(
node_id=node_id.binary(),
pending_args_avail_ts=current,
submitted_to_worker_ts=current + second,
running_ts=current + (2 * second),
finished_ts=current + (3 * second),
state_ts={
TaskStatus.PENDING_ARGS_AVAIL: current,
TaskStatus.SUBMITTED_TO_WORKER: current + second,
TaskStatus.RUNNING: current + (2 * second),
TaskStatus.FINISHED: current + (3 * second),
},
)

"""
Expand Down Expand Up @@ -1056,9 +1058,11 @@ async def test_api_manager_list_tasks_events(state_api_manager):
"""
state_updates = TaskStateUpdate(
node_id=node_id.binary(),
pending_args_avail_ts=current,
submitted_to_worker_ts=current + second,
running_ts=current + (2 * second),
state_ts={
TaskStatus.PENDING_ARGS_AVAIL: current,
TaskStatus.SUBMITTED_TO_WORKER: current + second,
TaskStatus.RUNNING: current + (2 * second),
},
)
events = TaskEvents(
task_id=id,
Expand All @@ -1077,8 +1081,10 @@ async def test_api_manager_list_tasks_events(state_api_manager):
Test None of start & end time is updated.
"""
state_updates = TaskStateUpdate(
pending_args_avail_ts=current,
submitted_to_worker_ts=current + second,
state_ts={
TaskStatus.PENDING_ARGS_AVAIL: current,
TaskStatus.SUBMITTED_TO_WORKER: current + second,
},
)
events = TaskEvents(
task_id=id,
Expand Down Expand Up @@ -2424,7 +2430,11 @@ def verify():
for task in tasks:
assert task["job_id"] == job_id

tasks = list_tasks(filters=[("name", "=", "f_0")])
tasks = list_tasks(filters=[("name", "=", "f_0")], limit=1)
assert len(tasks) == 1

# using limit to make sure state filtering is done on the gcs side
tasks = list_tasks(filters=[("STATE", "=", "PENDING_ARGS_AVAIL")], limit=1)
assert len(tasks) == 1

return True
Expand Down
39 changes: 21 additions & 18 deletions python/ray/util/state/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,24 +1583,27 @@ def protobuf_to_task_state_dict(message: TaskEvents) -> dict:
task_state["end_time_ms"] = None
events = []

for state in TaskStatus.keys():
key = f"{state.lower()}_ts"
if key in state_updates:
# timestamp is recorded as nanosecond from the backend.
# We need to convert it to the second.
ts_ms = int(state_updates[key]) // 1e6
events.append(
{
"state": state,
"created_ms": ts_ms,
}
)
if state == "PENDING_ARGS_AVAIL":
task_state["creation_time_ms"] = ts_ms
if state == "RUNNING":
task_state["start_time_ms"] = ts_ms
if state == "FINISHED" or state == "FAILED":
task_state["end_time_ms"] = ts_ms
if "state_ts" in state_updates:
state_ts = state_updates["state_ts"]
for state_name, state in TaskStatus.items():
# state_ts is Map[str, str] after protobuf MessageToDict
key = str(state)
if key in state_ts:
# timestamp is recorded as nanosecond from the backend.
# We need to convert it to the second.
ts_ms = int(state_ts[key]) // 1e6
events.append(
{
"state": state_name,
"created_ms": ts_ms,
}
)
if state == TaskStatus.PENDING_ARGS_AVAIL:
task_state["creation_time_ms"] = ts_ms
if state == TaskStatus.RUNNING:
task_state["start_time_ms"] = ts_ms
if state == TaskStatus.FINISHED or state == TaskStatus.FAILED:
task_state["end_time_ms"] = ts_ms

task_state["events"] = events
if len(events) > 0:
Expand Down
4 changes: 4 additions & 0 deletions python/ray/util/state/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ async def get_all_task_info(
req_filters.job_id = JobID(hex_to_binary(value)).binary()
elif key == "task_id":
req_filters.task_ids.append(TaskID(hex_to_binary(value)).binary())
elif key == "name":
req_filters.name = value
elif key == "state":
req_filters.state = value
else:
continue

Expand Down
30 changes: 28 additions & 2 deletions src/ray/gcs/gcs_server/gcs_task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "ray/gcs/gcs_server/gcs_task_manager.h"

#include "absl/strings/match.h"
#include "ray/common/ray_config.h"
#include "ray/common/status.h"

Expand Down Expand Up @@ -112,7 +113,7 @@ void GcsTaskManager::GcsTaskManagerStorage::MarkTaskAttemptFailedIfNeeded(
// We could mark the task as failed even if might not have state updates yet (i.e. only
// profiling events are reported).
auto state_updates = task_events.mutable_state_updates();
state_updates->set_failed_ts(failed_ts);
(*state_updates->mutable_state_ts())[ray::rpc::TaskStatus::FAILED] = failed_ts;
state_updates->mutable_error_info()->CopyFrom(error_info);
}

Expand Down Expand Up @@ -419,10 +420,35 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
return false;
}

if (filters.has_name() && task_event.task_info().name() != filters.name()) {
if (filters.has_name() &&
!absl::EqualsIgnoreCase(task_event.task_info().name(), filters.name())) {
return false;
}

if (filters.has_state()) {
const google::protobuf::EnumDescriptor *task_status_descriptor =
ray::rpc::TaskStatus_descriptor();

// Figure out the latest state of a task.
ray::rpc::TaskStatus state = ray::rpc::TaskStatus::NIL;
if (task_event.has_state_updates()) {
for (int i = task_status_descriptor->value_count() - 1; i >= 0; --i) {
if (task_event.state_updates().state_ts().contains(
task_status_descriptor->value(i)->number())) {
state = static_cast<ray::rpc::TaskStatus>(
task_status_descriptor->value(i)->number());
break;
}
}
}

if (!absl::EqualsIgnoreCase(
filters.state(),
task_status_descriptor->FindValueByNumber(state)->name())) {
return false;
}
}

return true;
};

Expand Down
76 changes: 65 additions & 11 deletions src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ class GcsTaskManagerTest : public ::testing::Test {
int64_t limit = -1,
bool exclude_driver = true,
const std::string &name = "",
const ActorID &actor_id = ActorID::Nil()) {
const ActorID &actor_id = ActorID::Nil(),
const std::string &state = "") {
rpc::GetTaskEventsRequest request;
rpc::GetTaskEventsReply reply;
std::promise<bool> promise;
Expand All @@ -142,6 +143,10 @@ class GcsTaskManagerTest : public ::testing::Test {
request.mutable_filters()->set_name(name);
}

if (!state.empty()) {
request.mutable_filters()->set_state(state);
}

if (!actor_id.IsNil()) {
request.mutable_filters()->set_actor_id(actor_id.Binary());
}
Expand Down Expand Up @@ -576,6 +581,24 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) {
SyncAddTaskEventData(data);
}

// A task event with state transitions.
{
auto task_ids = GenTaskIDs(1);
auto task_info = GenTaskInfo(JobID::FromInt(1), TaskID::Nil(), rpc::NORMAL_TASK);
auto events =
GenTaskEvents(task_ids,
/* attempt_number */
0,
/* job_id */ 1,
absl::nullopt,
GenStateUpdate({{rpc::TaskStatus::PENDING_NODE_ASSIGNMENT, 1},
{rpc::TaskStatus::RUNNING, 5}},
WorkerID::Nil()),
task_info);
auto data = Mocker::GenTaskEventsData(events);
SyncAddTaskEventData(data);
}

auto reply_name = SyncGetTaskEvents({},
/* job_id */ absl::nullopt,
/* limit */ -1,
Expand All @@ -598,6 +621,33 @@ TEST_F(GcsTaskManagerTest, TestGetTaskEventsFilters) {
"task_name",
actor_id);
EXPECT_EQ(reply_both_and.events_by_task_size(), 0);

auto reply_state = SyncGetTaskEvents({},
/* job_id */ absl::nullopt,
/* limit */ -1,
/* exclude_driver */ false,
/* name */ "",
ActorID::Nil(),
"RUnnING");
EXPECT_EQ(reply_state.events_by_task_size(), 1);

reply_state = SyncGetTaskEvents({},
/* job_id */ absl::nullopt,
/* limit */ -1,
/* exclude_driver */ false,
/* name */ "",
ActorID::Nil(),
"NIL");
EXPECT_EQ(reply_state.events_by_task_size(), 2);

reply_state = SyncGetTaskEvents({},
/* job_id */ absl::nullopt,
/* limit */ -1,
/* exclude_driver */ false,
/* name */ "",
ActorID::Nil(),
"PENDING_NODE_ASSIGNMENT");
EXPECT_EQ(reply_state.events_by_task_size(), 0);
}

TEST_F(GcsTaskManagerTest, TestMarkTaskAttemptFailedIfNeeded) {
Expand All @@ -623,22 +673,22 @@ TEST_F(GcsTaskManagerTest, TestMarkTaskAttemptFailedIfNeeded) {
{
auto reply = SyncGetTaskEvents({tasks_running});
auto task_event = *(reply.events_by_task().begin());
EXPECT_EQ(task_event.state_updates().failed_ts(), 4);
EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), 4);
}

// Check task attempt failed event is not overriding failed tasks.
{
auto reply = SyncGetTaskEvents({tasks_failed});
auto task_event = *(reply.events_by_task().begin());
EXPECT_EQ(task_event.state_updates().failed_ts(), 3);
EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), 3);
}

// Check task attempt failed event is not overriding finished tasks.
{
auto reply = SyncGetTaskEvents({tasks_finished});
auto task_event = *(reply.events_by_task().begin());
EXPECT_FALSE(task_event.state_updates().has_failed_ts());
EXPECT_EQ(task_event.state_updates().finished_ts(), 2);
EXPECT_FALSE(task_event.state_updates().state_ts().contains(rpc::TaskStatus::FAILED));
EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FINISHED), 2);
}
}

Expand Down Expand Up @@ -690,7 +740,8 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) {
auto reply = SyncGetTaskEvents(tasks);
EXPECT_EQ(reply.events_by_task_size(), 10);
for (const auto &task_event : reply.events_by_task()) {
EXPECT_EQ(task_event.state_updates().failed_ts(), /* 5 ms to ns */ 5 * 1000 * 1000);
EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED),
/* 5 ms to ns */ 5 * 1000 * 1000);
EXPECT_TRUE(task_event.state_updates().has_error_info());
EXPECT_TRUE(task_event.state_updates().error_info().error_type() ==
rpc::ErrorType::WORKER_DIED);
Expand All @@ -706,8 +757,9 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) {
auto reply = SyncGetTaskEvents(tasks);
EXPECT_EQ(reply.events_by_task_size(), 10);
for (const auto &task_event : reply.events_by_task()) {
EXPECT_EQ(task_event.state_updates().finished_ts(), 2);
EXPECT_FALSE(task_event.state_updates().has_failed_ts());
EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FINISHED), 2);
EXPECT_FALSE(
task_event.state_updates().state_ts().contains(rpc::TaskStatus::FAILED));
}
}

Expand All @@ -717,7 +769,7 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) {
auto reply = SyncGetTaskEvents(tasks);
EXPECT_EQ(reply.events_by_task_size(), 10);
for (const auto &task_event : reply.events_by_task()) {
EXPECT_EQ(task_event.state_updates().failed_ts(), 3);
EXPECT_EQ(task_event.state_updates().state_ts().at(rpc::TaskStatus::FAILED), 3);
}
}

Expand All @@ -728,8 +780,10 @@ TEST_F(GcsTaskManagerTest, TestJobFinishesFailAllRunningTasks) {
auto reply = SyncGetTaskEvents(tasks);
EXPECT_EQ(reply.events_by_task_size(), 5);
for (const auto &task_event : reply.events_by_task()) {
EXPECT_FALSE(task_event.state_updates().has_failed_ts());
EXPECT_FALSE(task_event.state_updates().has_finished_ts());
EXPECT_FALSE(
task_event.state_updates().state_ts().contains(rpc::TaskStatus::FAILED));
EXPECT_FALSE(
task_event.state_updates().state_ts().contains(rpc::TaskStatus::FINISHED));
}
}
}
Expand Down
39 changes: 6 additions & 33 deletions src/ray/gcs/pb_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ inline bool IsTaskTerminated(const rpc::TaskEvents &task_event) {
}

const auto &state_updates = task_event.state_updates();
return state_updates.has_finished_ts() || state_updates.has_failed_ts();
return state_updates.state_ts().contains(rpc::TaskStatus::FINISHED) ||
state_updates.state_ts().contains(rpc::TaskStatus::FAILED);
}

inline size_t NumProfileEvents(const rpc::TaskEvents &task_event) {
Expand Down Expand Up @@ -308,7 +309,7 @@ inline bool IsTaskFinished(const rpc::TaskEvents &task_event) {
}

const auto &state_updates = task_event.state_updates();
return state_updates.has_finished_ts();
return state_updates.state_ts().contains(rpc::TaskStatus::FINISHED);
}

/// Fill the rpc::TaskStateUpdate with the timestamps according to the status change.
Expand All @@ -319,39 +320,11 @@ inline bool IsTaskFinished(const rpc::TaskEvents &task_event) {
/// Record the timestamp of a task status change in the TaskStateUpdate proto.
///
/// The per-status `*_ts` fields were replaced by a single `state_ts` map keyed
/// by the TaskStatus enum value, so all statuses are handled uniformly instead
/// of via a per-status switch.
///
/// \param task_status The new task status. NIL means "no status change" and is
///   ignored.
/// \param timestamp The timestamp of the status change (nanoseconds, per the
///   backend's convention — TODO confirm against the reporting path).
/// \param state_updates The rpc::TaskStateUpdate to mutate. Must be non-null.
inline void FillTaskStatusUpdateTime(const ray::rpc::TaskStatus &task_status,
                                     int64_t timestamp,
                                     ray::rpc::TaskStateUpdate *state_updates) {
  if (task_status == rpc::TaskStatus::NIL) {
    // Not a status change; nothing to record.
    return;
  }
  // state_ts maps the TaskStatus enum value to the timestamp of that
  // transition; an existing entry for the same status is overwritten.
  (*state_updates->mutable_state_ts())[task_status] = timestamp;
}

inline std::string FormatPlacementGroupLabelName(const std::string &pg_id) {
Expand Down
Loading

0 comments on commit 29fa5cd

Please sign in to comment.