Skip to content

Commit

Permalink
Enabling the cancellation of non-actor tasks in a worker's queue 2 (r…
Browse files Browse the repository at this point in the history
…ay-project#13244)

* wrote code to enable cancellation of queued non-actor tasks

* minor changes

* bug fixes

* added comments

* rev1

* linting

* making ActorSchedulingQueue::CancelTaskIfFound raise a fatal error

* bug fix

* added two unit tests

* linting

* iterating through pending_normal_tasks starting from end

* fixup! iterating through pending_normal_tasks starting from end

* fixup! fixup! iterating through pending_normal_tasks starting from end

* post merge fixes

* added debugging instructions, pulled Accept() out of guarded loop

* removed debugging instructions, linting

* first commit

* lint

* lint

* added hack to avoid race condition in test stress

* moved hack

* fix test cancel

* removed hack (hopefully no longer needed)

* Revert "removed hack (hopefully no longer needed)"

This reverts commit 99d0e7c.

* added sleep in mock_worker.cc

* sleep function fixup to work on windows

* sleep in test_fast both for force=true and force=false

* linting

Co-authored-by: Ian <ian.rodney@gmail.com>
  • Loading branch information
Gabriele Oliaro and ijrsvt authored Feb 3, 2021
1 parent 875ea3f commit 7931045
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 19 deletions.
9 changes: 7 additions & 2 deletions python/ray/tests/test_cancel.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ def infinite_sleep(y):
sleep_or_no = [random.randint(0, 1) for _ in range(100)]
tasks = [infinite_sleep.remote(i) for i in sleep_or_no]
cancelled = set()

# Randomly kill queued tasks (infinitely sleeping or not).
for t in tasks:
if random.random() > 0.5:
ray.cancel(t, force=use_force)
Expand All @@ -186,10 +188,13 @@ def infinite_sleep(y):
for done in cancelled:
with pytest.raises(valid_exceptions(use_force)):
ray.get(done, timeout=120)

# Kill all infinitely sleeping tasks (queued or not).
for indx, t in enumerate(tasks):
if sleep_or_no[indx]:
ray.cancel(t, force=use_force)
cancelled.add(t)
for indx, t in enumerate(tasks):
if t in cancelled:
with pytest.raises(valid_exceptions(use_force)):
ray.get(t, timeout=120)
Expand All @@ -213,8 +218,8 @@ def fast(y):
# between a worker receiving a task and the worker executing
# that task (specifically the python execution), Cancellation
# can fail.
if not use_force:
time.sleep(0.1)

time.sleep(0.1)
ray.cancel(x, force=use_force)
ids.append(x)

Expand Down
15 changes: 12 additions & 3 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ void CoreWorker::InternalHeartbeat(const boost::system::error_code &error) {
}

absl::MutexLock lock(&mutex_);

while (!to_resubmit_.empty() && current_time_ms() > to_resubmit_.front().first) {
auto &spec = to_resubmit_.front().second;
if (spec.IsActorTask()) {
Expand Down Expand Up @@ -2266,12 +2267,17 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request,
rpc::SendReplyCallback send_reply_callback) {
absl::MutexLock lock(&mutex_);
TaskID task_id = TaskID::FromBinary(request.intended_task_id());
bool success = main_thread_task_id_ == task_id;
bool requested_task_running = main_thread_task_id_ == task_id;
bool success = requested_task_running;

// Try non-force kill
if (success && !request.force_kill()) {
if (requested_task_running && !request.force_kill()) {
RAY_LOG(INFO) << "Interrupting a running task " << main_thread_task_id_;
success = options_.kill_main();
} else if (!requested_task_running) {
// If the task is not currently running, check if it is in the worker's queue of
// normal tasks, and remove it if found.
success = direct_task_receiver_->CancelQueuedNormalTask(task_id);
}
if (request.recursive()) {
auto recursive_cancel = CancelChildren(task_id, request.force_kill());
Expand All @@ -2280,11 +2286,14 @@ void CoreWorker::HandleCancelTask(const rpc::CancelTaskRequest &request,
}
}

// TODO: fix race condition to avoid using this hack
requested_task_running = main_thread_task_id_ == task_id;

reply->set_attempt_succeeded(success);
send_reply_callback(Status::OK(), nullptr, nullptr);

// Do force kill after reply callback sent
if (success && request.force_kill()) {
if (requested_task_running && request.force_kill()) {
RAY_LOG(INFO) << "Force killing a worker running " << main_thread_task_id_;
Disconnect();
if (options_.enable_logging) {
Expand Down
42 changes: 42 additions & 0 deletions src/ray/core_worker/test/core_worker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,48 @@ TEST_F(SingleNodeTest, TestNormalTaskLocal) {
TestNormalTask(resources);
}

TEST_F(SingleNodeTest, TestCancelTasks) {
  auto &core_worker = CoreWorkerProcess::GetCoreWorker();

  // Two separate task functions, both resolving to the mock worker's "WhileTrueLoop"
  // handler, i.e. each one spins in a while(true) loop once it starts executing.
  RayFunction running_func(
      ray::Language::PYTHON,
      ray::FunctionDescriptorBuilder::BuildPython("WhileTrueLoop", "", "", ""));
  RayFunction queued_func(
      ray::Language::PYTHON,
      ray::FunctionDescriptorBuilder::BuildPython("WhileTrueLoop", "", "", ""));

  // Neither task takes arguments or needs non-default options.
  std::vector<std::unique_ptr<TaskArg>> no_args;
  TaskOptions default_options;

  // Return object IDs produced by each submission; they identify the tasks to cancel.
  std::vector<ObjectID> running_return_ids;
  std::vector<ObjectID> queued_return_ids;

  // First submission: expected to start executing and loop forever.
  core_worker.SubmitTask(running_func, no_args, default_options, &running_return_ids,
                         /*max_retries=*/0, std::make_pair(PlacementGroupID::Nil(), -1),
                         true,
                         /*debugger_breakpoint=*/"");
  ASSERT_EQ(running_return_ids.size(), 1);

  // Second submission: expected to sit in the worker's queue indefinitely, because the
  // first task never finishes.
  core_worker.SubmitTask(queued_func, no_args, default_options, &queued_return_ids,
                         /*max_retries=*/0, std::make_pair(PlacementGroupID::Nil(), -1),
                         true,
                         /*debugger_breakpoint=*/"");
  ASSERT_EQ(queued_return_ids.size(), 1);

  // Cancel the queued task first; this should remove it from the worker's queue.
  RAY_CHECK_OK(core_worker.CancelTask(queued_return_ids[0], true, false));

  // Then cancel the task that is currently running.
  RAY_CHECK_OK(core_worker.CancelTask(running_return_ids[0], true, false));

  // TestNormalTask would hang unless both looping tasks are gone. If it completes, the
  // queued task must have been removed from the worker's queue.
  std::unordered_map<std::string, double> no_resources;
  TestNormalTask(no_resources);
}

TEST_F(TwoNodeTest, TestNormalTaskCrossNodes) {
std::unordered_map<std::string, double> resources;
resources.emplace("resource1", 1);
Expand Down
11 changes: 11 additions & 0 deletions src/ray/core_worker/test/mock_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class MockWorker {
} else if ("MergeInputArgsAsOutput" == typed_descriptor->ModuleName()) {
// Merge input args and write the merged content to each of return ids
return MergeInputArgsAsOutput(args, return_ids, results);
} else if ("WhileTrueLoop" == typed_descriptor->ModuleName()) {
return WhileTrueLoop(args, return_ids, results);
} else {
return Status::TypeError("Unknown function descriptor: " +
typed_descriptor->ModuleName());
Expand Down Expand Up @@ -128,6 +130,15 @@ class MockWorker {
return Status::OK();
}

/// Task body that never completes: it sleeps in 100 ms slices forever.
/// None of the parameters are read or written.
///
/// \param args Unused.
/// \param return_ids Unused.
/// \param results Unused.
/// \return Nominally Status::OK(), but control never reaches the return statement.
Status WhileTrueLoop(const std::vector<std::shared_ptr<RayObject>> &args,
                     const std::vector<ObjectID> &return_ids,
                     std::vector<std::shared_ptr<RayObject>> *results) {
  const auto sleep_slice = std::chrono::milliseconds(100);
  for (;;) {
    // Sleep instead of busy-spinning so the loop does not burn CPU.
    std::this_thread::sleep_for(sleep_slice);
  }
  // Unreachable; present only because the function has a non-void return type.
  return Status::OK();
}

int64_t prev_seq_no_ = 0;
};

Expand Down
27 changes: 23 additions & 4 deletions src/ray/core_worker/test/scheduling_queue_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ TEST(SchedulingQueueTest, TestWaitForObjects) {
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(1, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj1}));
queue.Add(2, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj2}));
queue.Add(3, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj3}));
queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
queue.Add(2, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj2}));
queue.Add(3, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj3}));
ASSERT_EQ(n_ok, 1);

waiter.Complete(0);
Expand All @@ -92,7 +92,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) {
auto fn_ok = [&n_ok]() { n_ok++; };
auto fn_rej = [&n_rej]() { n_rej++; };
queue.Add(0, -1, fn_ok, fn_rej);
queue.Add(1, -1, fn_ok, fn_rej, ObjectIdsToRefs({obj1}));
queue.Add(1, -1, fn_ok, fn_rej, TaskID::Nil(), ObjectIdsToRefs({obj1}));
ASSERT_EQ(n_ok, 1);
io_service.run();
ASSERT_EQ(n_rej, 0);
Expand Down Expand Up @@ -158,6 +158,25 @@ TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) {
ASSERT_EQ(n_rej, 2);
}

// Verifies that cancelling a queued normal task removes exactly one entry from the
// queue and that the removed entry's callbacks are never invoked.
TEST(SchedulingQueueTest, TestCancelQueuedTask) {
  // Stack-allocate the queue so it is destroyed when the test returns; a raw `new`
  // without a matching `delete` would leak it.
  NormalSchedulingQueue queue;
  ASSERT_TRUE(queue.TaskQueueEmpty());

  int n_ok = 0;
  int n_rej = 0;
  auto fn_ok = [&n_ok]() { n_ok++; };
  auto fn_rej = [&n_rej]() { n_rej++; };

  // Enqueue five normal tasks (seq_no must be -1 for normal tasks). The task id is
  // left at Add()'s default, TaskID::Nil(), so all five entries share the same id.
  constexpr int kNumTasks = 5;
  for (int i = 0; i < kNumTasks; ++i) {
    queue.Add(-1, -1, fn_ok, fn_rej);
  }

  // Cancelling by TaskID::Nil() matches the queued entries and removes one of them.
  ASSERT_TRUE(queue.CancelTaskIfFound(TaskID::Nil()));
  ASSERT_FALSE(queue.TaskQueueEmpty());

  // The four remaining tasks are accepted; cancellation does not call the reject
  // callback, so n_rej stays at zero.
  queue.ScheduleRequests();
  ASSERT_EQ(n_ok, 4);
  ASSERT_EQ(n_rej, 0);
}

} // namespace ray

int main(int argc, char **argv) {
Expand Down
10 changes: 8 additions & 2 deletions src/ray/core_worker/transport/direct_actor_transport.cc
Original file line number Diff line number Diff line change
Expand Up @@ -482,12 +482,12 @@ void CoreWorkerDirectTaskReceiver::HandleTask(
// TODO(swang): Remove this with legacy raylet code.
dependencies.pop_back();
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
accept_callback, reject_callback, dependencies);
accept_callback, reject_callback, task_spec.TaskId(), dependencies);
} else {
// Add the normal task's callbacks to the non-actor scheduling queue.
normal_scheduling_queue_->Add(request.sequence_number(),
request.client_processed_up_to(), accept_callback,
reject_callback, dependencies);
reject_callback, task_spec.TaskId(), dependencies);
}
}

Expand All @@ -501,4 +501,10 @@ void CoreWorkerDirectTaskReceiver::RunNormalTasksFromQueue() {
normal_scheduling_queue_->ScheduleRequests();
}

bool CoreWorkerDirectTaskReceiver::CancelQueuedNormalTask(TaskID task_id) {
  // Delegate to the queue of normal (non-actor) tasks: it searches for the task and
  // removes it when present. The result is true only if a queued entry was removed.
  const bool removed_from_queue = normal_scheduling_queue_->CancelTaskIfFound(task_id);
  return removed_from_queue;
}

} // namespace ray
56 changes: 48 additions & 8 deletions src/ray/core_worker/transport/direct_actor_transport.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,19 +254,23 @@ class InboundRequest {
public:
InboundRequest(){};
InboundRequest(std::function<void()> accept_callback,
std::function<void()> reject_callback, bool has_dependencies)
std::function<void()> reject_callback, TaskID task_id,
bool has_dependencies)
: accept_callback_(accept_callback),
reject_callback_(reject_callback),
task_id(task_id),
has_pending_dependencies_(has_dependencies) {}

void Accept() { accept_callback_(); }
void Cancel() { reject_callback_(); }
bool CanExecute() const { return !has_pending_dependencies_; }
ray::TaskID TaskID() const { return task_id; }
void MarkDependenciesSatisfied() { has_pending_dependencies_ = false; }

private:
std::function<void()> accept_callback_;
std::function<void()> reject_callback_;
ray::TaskID task_id;
bool has_pending_dependencies_;
};

Expand Down Expand Up @@ -346,10 +350,11 @@ class SchedulingQueue {
public:
virtual void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request,
std::function<void()> reject_request,
std::function<void()> reject_request, TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) = 0;
virtual void ScheduleRequests() = 0;
virtual bool TaskQueueEmpty() const = 0;
virtual bool CancelTaskIfFound(TaskID task_id) = 0;
virtual ~SchedulingQueue(){};
};

Expand All @@ -371,6 +376,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
/// Add a new actor task's callbacks to the worker queue.
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request, std::function<void()> reject_request,
TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) {
// A seq_no of -1 means no ordering constraint. Actor tasks must be executed in order.
RAY_CHECK(seq_no != -1);
Expand All @@ -383,7 +389,7 @@ class ActorSchedulingQueue : public SchedulingQueue {
}
RAY_LOG(DEBUG) << "Enqueue " << seq_no << " cur seqno " << next_seq_no_;
pending_actor_tasks_[seq_no] =
InboundRequest(accept_request, reject_request, dependencies.size() > 0);
InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0);
if (dependencies.size() > 0) {
waiter_.Wait(dependencies, [seq_no, this]() {
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
Expand All @@ -397,6 +403,15 @@ class ActorSchedulingQueue : public SchedulingQueue {
ScheduleRequests();
}

/// Implementation of SchedulingQueue::CancelTaskIfFound for actor tasks.
///
/// Cancelling queued actor tasks is not supported, so calling this method is treated
/// as a programming error: the RAY_CHECK below always fails, which is fatal.
///
/// \param task_id ID of the task the caller wanted to cancel (not inspected).
/// \return Never returns normally; the `return false` only satisfies the signature.
bool CancelTaskIfFound(TaskID task_id) {
// Unconditional check failure: reaching this point at all is the error.
RAY_CHECK(false) << "Cannot cancel actor tasks";
// The return instruction will never be executed, but we need to include it
// nonetheless because this is a non-void function.
return false;
}

/// Schedules as many requests as possible in sequence.
void ScheduleRequests() {
// Only call SetMaxActorConcurrency to configure threadpool size when the
Expand Down Expand Up @@ -520,22 +535,45 @@ class NormalSchedulingQueue : public SchedulingQueue {
/// Add a new task's callbacks to the worker queue.
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request, std::function<void()> reject_request,
TaskID task_id = TaskID::Nil(),
const std::vector<rpc::ObjectReference> &dependencies = {}) {
absl::MutexLock lock(&mu_);
// Normal tasks should not have ordering constraints.
RAY_CHECK(seq_no == -1);
// Create a InboundRequest object for the new task, and add it to the queue.
pending_normal_tasks_.push_back(
InboundRequest(accept_request, reject_request, dependencies.size() > 0));
InboundRequest(accept_request, reject_request, task_id, dependencies.size() > 0));
}

/// Removes a queued normal task so that it is never handed to ScheduleRequests().
///
/// The queue is scanned from the back (most recently added) towards the front, and
/// the first entry whose task id matches is erased. Neither the accept nor the reject
/// callback of the erased entry is invoked here.
///
/// \param task_id ID of the task to cancel.
/// \return true if a matching entry was found and erased, false otherwise.
bool CancelTaskIfFound(TaskID task_id) {
  absl::MutexLock lock(&mu_);
  auto candidate = pending_normal_tasks_.rbegin();
  while (candidate != pending_normal_tasks_.rend()) {
    if (candidate->TaskID() == task_id) {
      // A reverse iterator's base() points one past its element, so the base of the
      // *next* reverse iterator is the forward iterator to the matching entry.
      pending_normal_tasks_.erase(std::next(candidate).base());
      return true;
    }
    ++candidate;
  }
  return false;
}

/// Schedules as many requests as possible in sequence.
void ScheduleRequests() {
absl::MutexLock lock(&mu_);
while (!pending_normal_tasks_.empty()) {
auto &head = pending_normal_tasks_.front();
while (true) {
InboundRequest head;
{
absl::MutexLock lock(&mu_);
if (!pending_normal_tasks_.empty()) {
head = pending_normal_tasks_.front();
pending_normal_tasks_.pop_front();
} else {
return;
}
}
head.Accept();
pending_normal_tasks_.pop_front();
}
}

Expand Down Expand Up @@ -583,6 +621,8 @@ class CoreWorkerDirectTaskReceiver {
/// Pop tasks from the queue and execute them sequentially
void RunNormalTasksFromQueue();

bool CancelQueuedNormalTask(TaskID task_id);

private:
// Worker context.
WorkerContext &worker_context_;
Expand Down

0 comments on commit 7931045

Please sign in to comment.