From f0f52fa56f139be39d1f33c1598f0f5e06f9ab35 Mon Sep 17 00:00:00 2001 From: Jiajun Yao Date: Sat, 15 Jun 2024 23:47:18 -0700 Subject: [PATCH] [Core] Cancel lease requests before returning a PG bundle (#45919) Signed-off-by: Jiajun Yao --- python/ray/tests/test_gcs_fault_tolerance.py | 84 ++++++++++++ python/ray/tests/test_placement_group_5.py | 7 +- src/ray/core_worker/core_worker_process.cc | 1 + .../gcs_placement_group_scheduler.cc | 39 ++---- src/ray/gcs/gcs_server/gcs_server_main.cc | 1 + src/ray/raylet/local_task_manager.cc | 124 +++++++++-------- src/ray/raylet/local_task_manager.h | 30 ++-- src/ray/raylet/main.cc | 1 + src/ray/raylet/node_manager.cc | 26 +++- .../placement_group_resource_manager.cc | 2 +- .../raylet/scheduling/cluster_task_manager.cc | 129 +++++++----------- .../raylet/scheduling/cluster_task_manager.h | 10 ++ .../cluster_task_manager_interface.h | 11 ++ .../scheduling/cluster_task_manager_test.cc | 17 ++- .../scheduling/local_task_manager_interface.h | 34 ++--- src/ray/util/container_util.h | 40 ++++-- src/ray/util/tests/container_util_test.cc | 47 ++++--- 17 files changed, 366 insertions(+), 237 deletions(-) diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index afdf52311296..ca34239fd7ca 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -1,4 +1,5 @@ import sys +import asyncio import os import threading from time import sleep @@ -22,6 +23,8 @@ ) from ray.job_submission import JobSubmissionClient, JobStatus from ray._raylet import GcsClient +from ray._private.runtime_env.plugin import RuntimeEnvPlugin +from ray.util.state import list_placement_groups import psutil @@ -1213,6 +1216,87 @@ def spawn(self, name, namespace): raise ValueError(f"Unknown case: {case}") +MyPlugin = "MyPlugin" +MY_PLUGIN_CLASS_PATH = "ray.tests.test_gcs_fault_tolerance.HangPlugin" + + +class HangPlugin(RuntimeEnvPlugin): + name = MyPlugin + + async def create( + self, + uri, + runtime_env, + ctx, + logger, # noqa: F821 + ) -> float: + while True: + await asyncio.sleep(1) + + @staticmethod + def validate(runtime_env_dict: dict) -> str: + return 1 + + +@pytest.mark.parametrize( + "ray_start_regular_with_external_redis", + [ + generate_system_config_map( + gcs_rpc_server_reconnect_timeout_s=60, + testing_asio_delay_us="NodeManagerService.grpc_server.CancelResourceReserve=500000000:500000000", # noqa: E501 + ), + ], + indirect=True, +) +@pytest.mark.parametrize( + "set_runtime_env_plugins", + [ + '[{"class":"' + MY_PLUGIN_CLASS_PATH + '"}]', + ], + indirect=True, +) +def test_placement_group_removal_after_gcs_restarts( + set_runtime_env_plugins, ray_start_regular_with_external_redis +): + @ray.remote + def task(): + pass + + pg = ray.util.placement_group(bundles=[{"CPU": 1}]) + _ = task.options( + max_retries=0, + num_cpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + ), + runtime_env={ + MyPlugin: {"name": "f2"}, + "config": {"setup_timeout_seconds": -1}, + }, + ).remote() + + # The task should be popping worker + # TODO(jjyao) Use a more determinstic way to + # decide whether the task is popping worker + sleep(5) + + ray.util.remove_placement_group(pg) + # The PG is marked as REMOVED in redis but not removed yet from raylet + # due to the injected delay of CancelResourceReserve rpc + wait_for_condition(lambda: list_placement_groups()[0].state == "REMOVED") + + ray._private.worker._global_node.kill_gcs_server() + # After GCS restarts, it 
will try to remove the PG resources + # again via ReleaseUnusedBundles rpc + ray._private.worker._global_node.start_gcs_server() + + def verify_pg_resources_cleaned(): + r_keys = ray.available_resources().keys() + return all("group" not in k for k in r_keys) + + wait_for_condition(verify_pg_resources_cleaned, timeout=30) + + if __name__ == "__main__": import pytest diff --git a/python/ray/tests/test_placement_group_5.py b/python/ray/tests/test_placement_group_5.py index 9837b2ffce89..4af2894a99f3 100644 --- a/python/ray/tests/test_placement_group_5.py +++ b/python/ray/tests/test_placement_group_5.py @@ -470,10 +470,9 @@ async def create( ) -> float: await asyncio.sleep(PLUGIN_TIMEOUT) - -@staticmethod -def validate(runtime_env_dict: dict) -> str: - return 1 + @staticmethod + def validate(runtime_env_dict: dict) -> str: + return 1 @pytest.mark.parametrize( diff --git a/src/ray/core_worker/core_worker_process.cc b/src/ray/core_worker/core_worker_process.cc index 5991ab6ab0e1..95ab400f1689 100644 --- a/src/ray/core_worker/core_worker_process.cc +++ b/src/ray/core_worker/core_worker_process.cc @@ -233,6 +233,7 @@ void CoreWorkerProcessImpl::InitializeSystemConfig() { thread.join(); RayConfig::instance().initialize(promise.get_future().get()); + ray::asio::testing::init(); } void CoreWorkerProcessImpl::RunWorkerTaskExecutionLoop() { diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index 0e3425d909e8..4a40e27c8319 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -243,9 +243,9 @@ void GcsPlacementGroupScheduler::CancelResourceReserve( auto node_id = NodeID::FromBinary(node.value()->node_id()); if (max_retry == current_retry_cnt) { - RAY_LOG(INFO) << "Failed to cancel resource reserved for bundle because the max " - "retry count is reached. " - << bundle_spec->DebugString() << " at node " << node_id; + RAY_LOG(ERROR) << "Failed to cancel resource reserved for bundle because the max " + "retry count is reached. " + << bundle_spec->DebugString() << " at node " << node_id; return; } @@ -261,11 +261,10 @@ void GcsPlacementGroupScheduler::CancelResourceReserve( RAY_LOG(INFO) << "Finished cancelling the resource reserved for bundle: " << bundle_spec->DebugString() << " at node " << node_id; } else { - // We couldn't delete the pg resources either becuase it is in use - // or network issue. Retry. - RAY_LOG(INFO) << "Failed to cancel the resource reserved for bundle: " - << bundle_spec->DebugString() << " at node " << node_id - << ". Status: " << status; + // We couldn't delete the pg resources because of network issue. Retry. + RAY_LOG(WARNING) << "Failed to cancel the resource reserved for bundle: " + << bundle_spec->DebugString() << " at node " << node_id + << ". Status: " << status; execute_after( io_context_, [this, bundle_spec, node, max_retry, current_retry_cnt] { @@ -568,14 +567,10 @@ void GcsPlacementGroupScheduler::DestroyPlacementGroupPreparedBundleResources( for (const auto &iter : *(leasing_bundle_locations)) { auto &bundle_spec = iter.second.second; auto &node_id = iter.second.first; - CancelResourceReserve( - bundle_spec, - gcs_node_manager_.GetAliveNode(node_id), - // Retry 10 * worker registeration timeout to avoid race condition. - // See https://github.com/ray-project/ray/pull/42942 - // for more details. 
- /*max_retry*/ RayConfig::instance().worker_register_timeout_seconds() * 10, - /*num_retry*/ 0); + CancelResourceReserve(bundle_spec, + gcs_node_manager_.GetAliveNode(node_id), + /*max_retry*/ 5, + /*num_retry*/ 0); } } } @@ -594,14 +589,10 @@ void GcsPlacementGroupScheduler::DestroyPlacementGroupCommittedBundleResources( for (const auto &iter : *(committed_bundle_locations)) { auto &bundle_spec = iter.second.second; auto &node_id = iter.second.first; - CancelResourceReserve( - bundle_spec, - gcs_node_manager_.GetAliveNode(node_id), - // Retry 10 * worker registeration timeout to avoid race condition. - // See https://github.com/ray-project/ray/pull/42942 - // for more details. - /*max_retry*/ RayConfig::instance().worker_register_timeout_seconds() * 10, - /*num_retry*/ 0); + CancelResourceReserve(bundle_spec, + gcs_node_manager_.GetAliveNode(node_id), + /*max_retry*/ 5, + /*num_retry*/ 0); } committed_bundle_location_index_.Erase(placement_group_id); cluster_resource_scheduler_.GetClusterResourceManager() diff --git a/src/ray/gcs/gcs_server/gcs_server_main.cc b/src/ray/gcs/gcs_server/gcs_server_main.cc index 5bd85f900dc2..c58fbfbd8477 100644 --- a/src/ray/gcs/gcs_server/gcs_server_main.cc +++ b/src/ray/gcs/gcs_server/gcs_server_main.cc @@ -62,6 +62,7 @@ int main(int argc, char *argv[]) { gflags::ShutDownCommandLineFlags(); RayConfig::instance().initialize(config_list); + ray::asio::testing::init(); // IO Service for main loop. instrumented_io_context main_service; diff --git a/src/ray/raylet/local_task_manager.cc b/src/ray/raylet/local_task_manager.cc index f2161dc5c003..d499abacb205 100644 --- a/src/ray/raylet/local_task_manager.cc +++ b/src/ray/raylet/local_task_manager.cc @@ -546,21 +546,15 @@ bool LocalTaskManager::PoppedWorkerHandler( not_detached_with_owner_failed = true; } - const auto &required_resource = - task.GetTaskSpecification().GetRequiredResources().GetResourceMap(); - for (auto &entry : required_resource) { - if (!cluster_resource_scheduler_->GetLocalResourceManager().ResourcesExist( - scheduling::ResourceID(entry.first))) { - RAY_CHECK(task.GetTaskSpecification().PlacementGroupBundleId().first != - PlacementGroupID::Nil()); - RAY_LOG(DEBUG) << "The placement group: " - << task.GetTaskSpecification().PlacementGroupBundleId().first - << " was removed when poping workers for task: " << task_id - << ", will cancel the task."; - CancelTask( - task_id, - rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_PLACEMENT_GROUP_REMOVED); - canceled = true; + if (!canceled) { + const auto &required_resource = + task.GetTaskSpecification().GetRequiredResources().GetResourceMap(); + for (auto &entry : required_resource) { + // This is to make sure PG resource is not deleted during popping worker + // unless the lease request is cancelled. 
+ RAY_CHECK(cluster_resource_scheduler_->GetLocalResourceManager().ResourcesExist( + scheduling::ResourceID(entry.first))) + << entry.first; } } @@ -855,7 +849,7 @@ void LocalTaskManager::ReleaseTaskArgs(const TaskID &task_id) { } namespace { -void ReplyCancelled(std::shared_ptr &work, +void ReplyCancelled(const std::shared_ptr &work, rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, const std::string &scheduling_failure_message) { auto reply = work->reply; @@ -867,55 +861,67 @@ void ReplyCancelled(std::shared_ptr &work, } } // namespace -bool LocalTaskManager::CancelTask( - const TaskID &task_id, +bool LocalTaskManager::CancelTasks( + std::function &)> predicate, rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, const std::string &scheduling_failure_message) { - for (auto shapes_it = tasks_to_dispatch_.begin(); shapes_it != tasks_to_dispatch_.end(); - shapes_it++) { - auto &work_queue = shapes_it->second; - for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { - const auto &task = (*work_it)->task; - if (task.GetTaskSpecification().TaskId() == task_id) { - RAY_LOG(DEBUG) << "Canceling task " << task_id << " from dispatch queue."; - ReplyCancelled(*work_it, failure_type, scheduling_failure_message); - if ((*work_it)->GetState() == internal::WorkStatus::WAITING_FOR_WORKER) { - // We've already acquired resources so we need to release them. - cluster_resource_scheduler_->GetLocalResourceManager().ReleaseWorkerResources( - (*work_it)->allocated_instances); - // Release pinned task args. - ReleaseTaskArgs(task_id); - } - if (!task.GetTaskSpecification().GetDependencies().empty()) { - task_dependency_manager_.RemoveTaskDependencies( - task.GetTaskSpecification().TaskId()); + bool tasks_cancelled = false; + + ray::erase_if>( + tasks_to_dispatch_, [&](const std::shared_ptr &work) { + if (predicate(work)) { + const TaskID task_id = work->task.GetTaskSpecification().TaskId(); + RAY_LOG(DEBUG) << "Canceling task " << task_id << " from dispatch queue."; + ReplyCancelled(work, failure_type, scheduling_failure_message); + if (work->GetState() == internal::WorkStatus::WAITING_FOR_WORKER) { + // We've already acquired resources so we need to release them. + cluster_resource_scheduler_->GetLocalResourceManager().ReleaseWorkerResources( + work->allocated_instances); + // Release pinned task args. 
+ ReleaseTaskArgs(task_id); + } + if (!work->task.GetTaskSpecification().GetDependencies().empty()) { + task_dependency_manager_.RemoveTaskDependencies( + work->task.GetTaskSpecification().TaskId()); + } + RemoveFromRunningTasksIfExists(work->task); + work->SetStateCancelled(); + tasks_cancelled = true; + return true; + } else { + return false; } - RemoveFromRunningTasksIfExists(task); - (*work_it)->SetStateCancelled(); - work_queue.erase(work_it); - if (work_queue.empty()) { - tasks_to_dispatch_.erase(shapes_it); + }); + + ray::erase_if>( + waiting_task_queue_, [&](const std::shared_ptr &work) { + if (predicate(work)) { + ReplyCancelled(work, failure_type, scheduling_failure_message); + if (!work->task.GetTaskSpecification().GetDependencies().empty()) { + task_dependency_manager_.RemoveTaskDependencies( + work->task.GetTaskSpecification().TaskId()); + } + waiting_tasks_index_.erase(work->task.GetTaskSpecification().TaskId()); + tasks_cancelled = true; + return true; + } else { + return false; } - return true; - } - } - } + }); - auto iter = waiting_tasks_index_.find(task_id); - if (iter != waiting_tasks_index_.end()) { - const auto &task = (*iter->second)->task; - ReplyCancelled(*iter->second, failure_type, scheduling_failure_message); - if (!task.GetTaskSpecification().GetDependencies().empty()) { - task_dependency_manager_.RemoveTaskDependencies( - task.GetTaskSpecification().TaskId()); - } - waiting_task_queue_.erase(iter->second); - waiting_tasks_index_.erase(iter); - - return true; - } + return tasks_cancelled; +} - return false; +bool LocalTaskManager::CancelTask( + const TaskID &task_id, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) { + return CancelTasks( + [task_id](const std::shared_ptr &work) { + return work->task.GetTaskSpecification().TaskId() == task_id; + }, + failure_type, + scheduling_failure_message); } bool LocalTaskManager::AnyPendingTasksForResourceAcquisition( diff --git a/src/ray/raylet/local_task_manager.h b/src/ray/raylet/local_task_manager.h index b72861ce95ed..77468548cd12 100644 --- a/src/ray/raylet/local_task_manager.h +++ b/src/ray/raylet/local_task_manager.h @@ -111,17 +111,15 @@ class LocalTaskManager : public ILocalTaskManager { /// \param task: Output parameter. void TaskFinished(std::shared_ptr worker, RayTask *task); - /// Attempt to cancel an already queued task. + /// Attempt to cancel all queued tasks that match the predicate. /// - /// \param task_id: The id of the task to remove. - /// \param failure_type: The failure type. - /// - /// \return True if task was successfully removed. This function will return - /// false if the task is already running. - bool CancelTask(const TaskID &task_id, - rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type = - rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, - const std::string &scheduling_failure_message = "") override; + /// \param predicate: A function that returns true if a task needs to be cancelled. + /// \param failure_type: The reason for cancellation. + /// \param scheduling_failure_message: The reason message for cancellation. + /// \return True if any task was successfully cancelled. + bool CancelTasks(std::function &)> predicate, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) override; /// Return if any tasks are pending resource acquisition. 
/// @@ -203,6 +201,18 @@ class LocalTaskManager : public ILocalTaskManager { const rpc::Address &owner_address, const std::string &runtime_env_setup_error_message); + /// Attempt to cancel an already queued task. + /// + /// \param task_id: The id of the task to remove. + /// \param failure_type: The failure type. + /// + /// \return True if task was successfully removed. This function will return + /// false if the task is already running. + bool CancelTask(const TaskID &task_id, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type = + rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, + const std::string &scheduling_failure_message = ""); + /// Attempts to dispatch all tasks which are ready to run. A task /// will be dispatched if it is on `tasks_to_dispatch_` and there are still /// available resources on the node. diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 19d90124892a..4c77133c29d1 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -272,6 +272,7 @@ int main(int argc, char *argv[]) { RAY_CHECK_OK(status); RAY_CHECK(stored_raylet_config.has_value()); RayConfig::instance().initialize(stored_raylet_config.get()); + ray::asio::testing::init(); // Core worker tries to kill child processes when it exits. But they can't do // it perfectly: if the core worker is killed by SIGKILL, the child processes diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 64a4aa97b1c4..cbcf198313b8 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -683,6 +683,16 @@ void NodeManager::HandleReleaseUnusedBundles(rpc::ReleaseUnusedBundlesRequest re -1); } + // Cancel lease requests related to unused bundles + cluster_task_manager_->CancelTasks( + [&](const std::shared_ptr &work) { + const auto bundle_id = work->task.GetTaskSpecification().PlacementGroupBundleId(); + return !bundle_id.first.IsNil() && 0 == in_use_bundles.count(bundle_id); + }, + rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, + "The task is cancelled because it uses placement group bundles that are not " + "registered to GCS. It can happen upon GCS restart."); + // Kill all workers that are currently associated with the unused bundles. // NOTE: We can't traverse directly with `leased_workers_`, because `DestroyWorker` will // delete the element of `leased_workers_`. So we need to filter out @@ -1889,6 +1899,15 @@ void NodeManager::HandleCancelResourceReserve( RAY_LOG(DEBUG) << "Request to cancel reserved resource is received, " << bundle_spec.DebugString(); + // Cancel lease requests related to the placement group to be removed. + cluster_task_manager_->CancelTasks( + [&](const std::shared_ptr &work) { + const auto bundle_id = work->task.GetTaskSpecification().PlacementGroupBundleId(); + return bundle_id.first == bundle_spec.PlacementGroupId(); + }, + rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_PLACEMENT_GROUP_REMOVED, + ""); + // Kill all workers that are currently associated with the placement group. // NOTE: We can't traverse directly with `leased_workers_`, because `DestroyWorker` will // delete the element of `leased_workers_`. So we need to filter out @@ -1914,12 +1933,9 @@ void NodeManager::HandleCancelResourceReserve( DestroyWorker(worker, rpc::WorkerExitType::INTENDED_SYSTEM_EXIT, message); } - // Return bundle resources. If it fails to return a bundle, - // it will return none-ok status. They are transient state, - // and GCS should retry. 
- auto status = placement_group_resource_manager_->ReturnBundle(bundle_spec); + RAY_CHECK_OK(placement_group_resource_manager_->ReturnBundle(bundle_spec)); cluster_task_manager_->ScheduleAndDispatchTasks(); - send_reply_callback(status, nullptr, nullptr); + send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::HandleReturnWorker(rpc::ReturnWorkerRequest request, diff --git a/src/ray/raylet/placement_group_resource_manager.cc b/src/ray/raylet/placement_group_resource_manager.cc index 0f16cf766535..e0906c23885f 100644 --- a/src/ray/raylet/placement_group_resource_manager.cc +++ b/src/ray/raylet/placement_group_resource_manager.cc @@ -26,7 +26,7 @@ void PlacementGroupResourceManager::ReturnUnusedBundle( const std::unordered_set &in_use_bundles) { for (auto iter = bundle_spec_map_.begin(); iter != bundle_spec_map_.end();) { if (0 == in_use_bundles.count(iter->first)) { - RAY_CHECK(ReturnBundle(*iter->second).ok()); + RAY_CHECK_OK(ReturnBundle(*iter->second)); bundle_spec_map_.erase(iter++); } else { iter++; diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index c4e6cff7a08a..99b998dc14fe 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -77,54 +77,60 @@ void ReplyCancelled(const internal::Work &work, } } // namespace +bool ClusterTaskManager::CancelTasks( + std::function &)> predicate, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) { + bool tasks_cancelled = false; + + ray::erase_if>( + tasks_to_schedule_, [&](const std::shared_ptr &work) { + if (predicate(work)) { + RAY_LOG(DEBUG) << "Canceling task " + << work->task.GetTaskSpecification().TaskId() + << " from schedule queue."; + ReplyCancelled(*work, failure_type, scheduling_failure_message); + tasks_cancelled = true; + return true; + } else { + return false; + } + }); + + ray::erase_if>( + infeasible_tasks_, [&](const std::shared_ptr &work) { + if (predicate(work)) { + RAY_LOG(DEBUG) << "Canceling task " + << work->task.GetTaskSpecification().TaskId() + << " from infeasible queue."; + ReplyCancelled(*work, failure_type, scheduling_failure_message); + tasks_cancelled = true; + return true; + } else { + return false; + } + }); + + if (local_task_manager_->CancelTasks( + predicate, failure_type, scheduling_failure_message)) { + tasks_cancelled = true; + } + + return tasks_cancelled; +} + bool ClusterTaskManager::CancelAllTaskOwnedBy( const WorkerID &worker_id, rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, const std::string &scheduling_failure_message) { // Only tasks and regular actors are canceled because their lifetime is // the same as the owner. 
- auto shapes_it = tasks_to_schedule_.begin(); - while (shapes_it != tasks_to_schedule_.end()) { - auto &work_queue = shapes_it->second; - auto work_it = work_queue.begin(); - while (work_it != work_queue.end()) { - const auto &task = (*work_it)->task; - const auto &spec = task.GetTaskSpecification(); - if (!spec.IsDetachedActor() && spec.CallerWorkerId() == worker_id) { - ReplyCancelled(*(*work_it), failure_type, scheduling_failure_message); - work_it = work_queue.erase(work_it); - } else { - ++work_it; - } - } - if (work_queue.empty()) { - tasks_to_schedule_.erase(shapes_it++); - } else { - ++shapes_it; - } - } + auto predicate = [worker_id](const std::shared_ptr &work) { + return !work->task.GetTaskSpecification().IsDetachedActor() && + work->task.GetTaskSpecification().CallerWorkerId() == worker_id; + }; - shapes_it = infeasible_tasks_.begin(); - while (shapes_it != infeasible_tasks_.end()) { - auto &work_queue = shapes_it->second; - auto work_it = work_queue.begin(); - while (work_it != work_queue.end()) { - const auto &task = (*work_it)->task; - const auto &spec = task.GetTaskSpecification(); - if (!spec.IsDetachedActor() && spec.CallerWorkerId() == worker_id) { - ReplyCancelled(*(*work_it), failure_type, scheduling_failure_message); - work_it = work_queue.erase(work_it); - } else { - ++work_it; - } - } - if (work_queue.empty()) { - infeasible_tasks_.erase(shapes_it++); - } else { - ++shapes_it; - } - } - return true; + return CancelTasks(predicate, failure_type, scheduling_failure_message); } void ClusterTaskManager::ScheduleAndDispatchTasks() { @@ -268,44 +274,11 @@ bool ClusterTaskManager::CancelTask( const TaskID &task_id, rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, const std::string &scheduling_failure_message) { - // TODO(sang): There are lots of repetitive code around task backlogs. We should - // refactor them. 
- for (auto shapes_it = tasks_to_schedule_.begin(); shapes_it != tasks_to_schedule_.end(); - shapes_it++) { - auto &work_queue = shapes_it->second; - for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { - const auto &task = (*work_it)->task; - if (task.GetTaskSpecification().TaskId() == task_id) { - RAY_LOG(DEBUG) << "Canceling task " << task_id << " from schedule queue."; - ReplyCancelled(*(*work_it), failure_type, scheduling_failure_message); - work_queue.erase(work_it); - if (work_queue.empty()) { - tasks_to_schedule_.erase(shapes_it); - } - return true; - } - } - } - - for (auto shapes_it = infeasible_tasks_.begin(); shapes_it != infeasible_tasks_.end(); - shapes_it++) { - auto &work_queue = shapes_it->second; - for (auto work_it = work_queue.begin(); work_it != work_queue.end(); work_it++) { - const auto &task = (*work_it)->task; - if (task.GetTaskSpecification().TaskId() == task_id) { - RAY_LOG(DEBUG) << "Canceling task " << task_id << " from infeasible queue."; - ReplyCancelled(*(*work_it), failure_type, scheduling_failure_message); - work_queue.erase(work_it); - if (work_queue.empty()) { - infeasible_tasks_.erase(shapes_it); - } - return true; - } - } - } + auto predicate = [task_id](const std::shared_ptr &work) { + return work->task.GetTaskSpecification().TaskId() == task_id; + }; - return local_task_manager_->CancelTask( - task_id, failure_type, scheduling_failure_message); + return CancelTasks(predicate, failure_type, scheduling_failure_message); } void ClusterTaskManager::FillResourceUsage(rpc::ResourcesData &data) { diff --git a/src/ray/raylet/scheduling/cluster_task_manager.h b/src/ray/raylet/scheduling/cluster_task_manager.h index a3363365bb10..058c40f97fcf 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.h +++ b/src/ray/raylet/scheduling/cluster_task_manager.h @@ -93,6 +93,16 @@ class ClusterTaskManager : public ClusterTaskManagerInterface { rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, const std::string &scheduling_failure_message = "") override; + /// Attempt to cancel all queued tasks that match the predicate. + /// + /// \param predicate: A function that returns true if a task needs to be cancelled. + /// \param failure_type: The reason for cancellation. + /// \param scheduling_failure_message: The reason message for cancellation. + /// \return True if any task was successfully cancelled. + bool CancelTasks(std::function &)> predicate, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) override; + /// Populate the relevant parts of the heartbeat table. This is intended for /// sending resource usage of raylet to gcs. In particular, this should fill in /// resource_load and resource_load_by_shape. diff --git a/src/ray/raylet/scheduling/cluster_task_manager_interface.h b/src/ray/raylet/scheduling/cluster_task_manager_interface.h index 8ae664479924..0e2bdbe08bb6 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_interface.h +++ b/src/ray/raylet/scheduling/cluster_task_manager_interface.h @@ -54,6 +54,17 @@ class ClusterTaskManagerInterface { rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, const std::string &scheduling_failure_message = "") = 0; + /// Attempt to cancel all queued tasks that match the predicate. + /// + /// \param predicate: A function that returns true if a task needs to be cancelled. + /// \param failure_type: The reason for cancellation. + /// \param scheduling_failure_message: The reason message for cancellation. 
+ /// \return True if any task was successfully cancelled. + virtual bool CancelTasks( + std::function &)> predicate, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) = 0; + /// Queue task and schedule. This hanppens when processing the worker lease request. /// /// \param task: The incoming task to be queued and scheduled. diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index 2fe7eec7452a..6a03a6036f61 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -1218,7 +1218,6 @@ TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { callback_called = false; reply.Clear(); ASSERT_FALSE(task_manager_.CancelTask(task2.GetTaskSpecification().TaskId())); - // Task2 will not execute. ASSERT_FALSE(reply.canceled()); ASSERT_FALSE(callback_called); ASSERT_EQ(pool_.workers.size(), 0); @@ -1229,6 +1228,22 @@ TEST_F(ClusterTaskManagerTest, TaskCancellationTest) { ASSERT_EQ(finished_task.GetTaskSpecification().TaskId(), task2.GetTaskSpecification().TaskId()); + RayTask task3 = CreateTask({{ray::kCPU_ResourceLabel, 2}}); + rpc::RequestWorkerLeaseReply reply3; + RayTask task4 = CreateTask({{ray::kCPU_ResourceLabel, 200}}); + rpc::RequestWorkerLeaseReply reply4; + // Task 3 should be popping worker + task_manager_.QueueAndScheduleTask(task3, false, false, &reply3, callback); + // Task 4 is infeasible + task_manager_.QueueAndScheduleTask(task4, false, false, &reply4, callback); + pool_.TriggerCallbacks(); + ASSERT_TRUE(task_manager_.CancelTasks( + [](const std::shared_ptr &work) { return true; }, + rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, + "")); + ASSERT_TRUE(reply3.canceled()); + ASSERT_TRUE(reply4.canceled()); + AssertNoLeaks(); } diff --git a/src/ray/raylet/scheduling/local_task_manager_interface.h b/src/ray/raylet/scheduling/local_task_manager_interface.h index 03f3a8b15a60..8bdce254a418 100644 --- a/src/ray/raylet/scheduling/local_task_manager_interface.h +++ b/src/ray/raylet/scheduling/local_task_manager_interface.h @@ -37,18 +37,16 @@ class ILocalTaskManager { // Schedule and dispatch tasks. virtual void ScheduleAndDispatchTasks() = 0; - /// Attempt to cancel an already queued task. + /// Attempt to cancel all queued tasks that match the predicate. /// - /// \param task_id: The id of the task to remove. - /// \param failure_type: The failure type. - /// - /// \return True if task was successfully removed. This function will return - /// false if the task is already running. - virtual bool CancelTask( - const TaskID &task_id, - rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type = - rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, - const std::string &scheduling_failure_message = "") = 0; + /// \param predicate: A function that returns true if a task needs to be cancelled. + /// \param failure_type: The reason for cancellation. + /// \param scheduling_failure_message: The reason message for cancellation. + /// \return True if any task was successfully cancelled. + virtual bool CancelTasks( + std::function &)> predicate, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) = 0; virtual const absl::flat_hash_map>> @@ -88,17 +86,9 @@ class NoopLocalTaskManager : public ILocalTaskManager { // Schedule and dispatch tasks. 
void ScheduleAndDispatchTasks() override {} - /// Attempt to cancel an already queued task. - /// - /// \param task_id: The id of the task to remove. - /// \param failure_type: The failure type. - /// - /// \return True if task was successfully removed. This function will return - /// false if the task is already running. - bool CancelTask(const TaskID &task_id, - rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type = - rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_INTENDED, - const std::string &scheduling_failure_message = "") override { + bool CancelTasks(std::function &)> predicate, + rpc::RequestWorkerLeaseReply::SchedulingFailureType failure_type, + const std::string &scheduling_failure_message) override { return false; } diff --git a/src/ray/util/container_util.h b/src/ray/util/container_util.h index 6a6bc671e529..6a363dc09d0f 100644 --- a/src/ray/util/container_util.h +++ b/src/ray/util/container_util.h @@ -97,19 +97,35 @@ typename C::mapped_type &map_find_or_die(C &c, const typename C::key_type &k) { map_find_or_die(const_cast(c), k)); } -/// Remove elements whole matcher returns true against the element. -/// -/// @param matcher the matcher function to be applied to each elements -/// @param container the container of the elements -template -void remove_elements(std::function matcher, std::deque &container) { - auto itr = container.begin(); - while (itr != container.end()) { - if (matcher(*itr)) { - itr = container.erase(itr); +// This is guaranteed that predicate is applied to each element exactly once, +// so it can have side effect. +template +void erase_if(absl::flat_hash_map> &map, + std::function predicate) { + for (auto map_it = map.begin(); map_it != map.end();) { + auto &queue = map_it->second; + for (auto queue_it = queue.begin(); queue_it != queue.end();) { + if (predicate(*queue_it)) { + queue_it = queue.erase(queue_it); + } else { + ++queue_it; + } + } + if (queue.empty()) { + map.erase(map_it++); + } else { + ++map_it; } - if (itr != container.end()) { - itr++; + } +} + +template +void erase_if(std::list &list, std::function predicate) { + for (auto list_it = list.begin(); list_it != list.end();) { + if (predicate(*list_it)) { + list_it = list.erase(list_it); + } else { + ++list_it; } } } diff --git a/src/ray/util/tests/container_util_test.cc b/src/ray/util/tests/container_util_test.cc index d5ba8a7aa7e4..0e404efd2f8b 100644 --- a/src/ray/util/tests/container_util_test.cc +++ b/src/ray/util/tests/container_util_test.cc @@ -36,31 +36,36 @@ TEST(ContainerUtilTest, TestMapFindOrDie) { } } -TEST(ContainerUtilTest, RemoveElementsLastElement) { - std::deque queue{1, 2, 3, 4}; - std::function even = [](int value) { return value % 2 == 0; }; - remove_elements(even, queue); - - std::deque expected{1, 3}; - ASSERT_EQ(queue, expected); -} +TEST(ContainerUtilTest, TestEraseIf) { + { + std::list list{1, 2, 3, 4}; + ray::erase_if(list, [](const int &value) { return value % 2 == 0; }); + ASSERT_EQ(list, (std::list{1, 3})); + } -TEST(ContainerUtilTest, RemoveElementsExcludeLastElement) { - std::deque queue{1, 2, 3}; - std::function even = [](int value) { return value % 2 == 0; }; - remove_elements(even, queue); + { + std::list list{1, 2, 3}; + ray::erase_if(list, [](const int &value) { return value % 2 == 0; }); + ASSERT_EQ(list, (std::list{1, 3})); + } - std::deque expected{1, 3}; - ASSERT_EQ(queue, expected); -} + { + std::list list{}; + ray::erase_if(list, [](const int &value) { return value % 2 == 0; }); + ASSERT_EQ(list, (std::list{})); + } 
-TEST(ContainerUtilTest, RemoveElementsEmptyContainer) { - std::deque queue{}; - std::function even = [](int value) { return value % 2 == 0; }; - remove_elements(even, queue); + { + absl::flat_hash_map> map; + map[1] = std::deque{1, 3}; + map[2] = std::deque{2, 4}; + map[3] = std::deque{5, 6}; + ray::erase_if(map, [](const int &value) { return value % 2 == 0; }); - std::deque expected{}; - ASSERT_EQ(queue, expected); + ASSERT_EQ(map.size(), 2); + ASSERT_EQ(map[1], (std::deque{1, 3})); + ASSERT_EQ(map[3], (std::deque{5})); + } } } // namespace ray