[xray] Fix bug when counting a task's lineage size #2600

Merged
13 changes: 10 additions & 3 deletions src/ray/raylet/lineage_cache.cc
@@ -211,7 +211,12 @@ void LineageCache::AddReadyTask(const Task &task) {
   }
 }
 
-uint64_t LineageCache::CountUnsubscribedLineage(const TaskID &task_id) const {
+uint64_t LineageCache::CountUnsubscribedLineage(const TaskID &task_id,
+                                                std::unordered_set<TaskID> &seen) const {
Contributor
I think you could default-initialize `seen` here so that you don't have to pass it in on the root call, e.g. `const std::unordered_set<TaskID> &seen = std::unordered_set<TaskID>()` (though that is potentially uglier).

Collaborator
I think I prefer the current way.
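
For context on the trade-off discussed in this thread: a `const` reference parameter would not allow the `seen.insert(task_id)` call, and a non-const lvalue reference cannot bind to a temporary default argument. One alternative that keeps the root call short is a one-argument overload that creates the set itself. The sketch below is only an illustration of that pattern and is not what this PR does; `ToyLineage`, `CountLineage`, and the integer `TaskID` are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>

// Hypothetical stand-in for the real TaskID type.
using TaskID = int;

// Hypothetical toy lineage: task n depends on tasks n-1 and n-2 (when >= 0).
class ToyLineage {
 public:
  // Root-call overload: creates the scratch set so callers need not.
  uint64_t CountLineage(const TaskID &task_id) const {
    std::unordered_set<TaskID> seen;
    return CountLineage(task_id, seen);
  }

  // Recursive overload: `seen` is a mutable reference because every call
  // inserts the visited task into it.
  uint64_t CountLineage(const TaskID &task_id,
                        std::unordered_set<TaskID> &seen) const {
    // Stop at nonexistent tasks and at tasks that were already counted.
    if (task_id < 0 || !seen.insert(task_id).second) {
      return 0;
    }
    uint64_t cnt = 1;
    cnt += CountLineage(task_id - 1, seen);
    cnt += CountLineage(task_id - 2, seen);
    return cnt;
  }
};

// Usage: tasks 0 through 5 are each counted once.
void CheckToyLineage() {
  ToyLineage lineage;
  assert(lineage.CountLineage(5) == 6);
}
```

The explicit two-argument form used in the PR makes the lifetime of the set visible at the call site, which may be why the current way was preferred.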

+  if (seen.count(task_id) == 1) {
+    return 0;
+  }
+  seen.insert(task_id);
   if (subscribed_tasks_.count(task_id) == 1) {
     return 0;
   }
@@ -221,7 +226,7 @@ uint64_t LineageCache::CountUnsubscribedLineage(const TaskID &task_id) const {
   }
   uint64_t cnt = 1;
   for (const auto &parent_id : entry->GetParentTaskIds()) {
-    cnt += CountUnsubscribedLineage(parent_id);
+    cnt += CountUnsubscribedLineage(parent_id, seen);
   }
   return cnt;
 }
@@ -249,7 +254,9 @@ void LineageCache::RemoveWaitingTask(const TaskID &task_id) {
   // NOTE(swang): The number of entries in the uncommitted lineage also
   // includes local tasks that haven't been committed yet, not just remote
   // tasks, so this is an overestimate.
-  if (CountUnsubscribedLineage(task_id) > max_lineage_size_) {
+  std::unordered_set<TaskID> seen;
+  auto count = CountUnsubscribedLineage(task_id, seen);
+  if (count > max_lineage_size_) {
     // Since this task was in state WAITING, check that we were not
     // already subscribed to the task.
     RAY_CHECK(SubscribeTask(task_id));
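
To illustrate why the `seen` set changes the result, here is a minimal, self-contained sketch that is independent of the Ray code base. It assumes the bug being fixed is the repeated counting of ancestors that are reachable through more than one path; the PR itself only states that nodes should not be revisited. `Lineage`, `CountWithoutSeen`, `CountWithSeen`, and the string `TaskID` are hypothetical names used for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins: a TaskID is a string, and the lineage maps each
// task to the IDs of its parent tasks.
using TaskID = std::string;
using Lineage = std::unordered_map<TaskID, std::vector<TaskID>>;

// Counts once per path, so a shared ancestor is counted several times.
uint64_t CountWithoutSeen(const Lineage &lineage, const TaskID &task_id) {
  auto it = lineage.find(task_id);
  if (it == lineage.end()) {
    return 0;
  }
  uint64_t cnt = 1;
  for (const auto &parent_id : it->second) {
    cnt += CountWithoutSeen(lineage, parent_id);
  }
  return cnt;
}

// Counts each task at most once by recording visited IDs in `seen`.
uint64_t CountWithSeen(const Lineage &lineage, const TaskID &task_id,
                       std::unordered_set<TaskID> &seen) {
  if (!seen.insert(task_id).second) {
    return 0;  // Already visited through another path.
  }
  auto it = lineage.find(task_id);
  if (it == lineage.end()) {
    return 0;
  }
  uint64_t cnt = 1;
  for (const auto &parent_id : it->second) {
    cnt += CountWithSeen(lineage, parent_id, seen);
  }
  return cnt;
}

// Diamond-shaped lineage: d depends on b and c, which both depend on a.
Lineage MakeDiamond() {
  Lineage lineage;
  lineage["a"] = std::vector<TaskID>{};
  lineage["b"] = std::vector<TaskID>{"a"};
  lineage["c"] = std::vector<TaskID>{"a"};
  lineage["d"] = std::vector<TaskID>{"b", "c"};
  return lineage;
}

// Usage: the diamond has four tasks, but the naive count reaches `a` twice.
void CheckDiamond() {
  Lineage lineage = MakeDiamond();
  std::unordered_set<TaskID> seen;
  assert(CountWithoutSeen(lineage, "d") == 5);
  assert(CountWithSeen(lineage, "d", seen) == 4);
}
```

Besides inflating the count, the version without `seen` does work proportional to the number of paths rather than the number of tasks, which can grow quickly in graphs with many shared ancestors.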
11 changes: 9 additions & 2 deletions src/ray/raylet/lineage_cache.h
@@ -226,8 +226,15 @@ class LineageCache {
   /// Unsubscribe from notifications for a task. Returns whether the operation
   /// was successful (whether we were subscribed).
   bool UnsubscribeTask(const TaskID &task_id);
-  /// Count the size of unsubscribed and uncommitted lineage
-  uint64_t CountUnsubscribedLineage(const TaskID &task_id) const;
+  /// Count the size of unsubscribed and uncommitted lineage of the given task
+  /// excluding the values that have already been visited.
+  ///
+  /// \param task_id The task whose lineage should be counted.
+  /// \param seen This set contains the keys of lineage entries counted so far,
+  /// so that we don't revisit those nodes.
+  /// \return The number of tasks that were counted.
+  uint64_t CountUnsubscribedLineage(const TaskID &task_id,
+                                    std::unordered_set<TaskID> &seen) const;
 
   /// The client ID, used to request notifications for specific tasks.
   /// TODO(swang): Move the ClientID into the generic Table implementation.
6 changes: 0 additions & 6 deletions test/stress_tests.py
@@ -55,9 +55,6 @@ def f(x):
     assert ray.services.all_processes_alive()
 
 
-@pytest.mark.skipif(
-    os.environ.get("RAY_USE_XRAY") == "1",
-    reason="This test does not work with xray yet.")
 def test_dependencies(ray_start_combination):
     @ray.remote
     def f(x):
@@ -81,9 +78,6 @@ def g(*xs):
     assert ray.services.all_processes_alive()
 
 
-@pytest.mark.skipif(
-    os.environ.get("RAY_USE_XRAY") == "1",
-    reason="This test does not work with xray yet.")
 def test_submitting_many_tasks(ray_start_regular):
     @ray.remote
     def f(x):