61 changes: 34 additions & 27 deletions src/ray/raylet/node_manager.cc
@@ -105,7 +105,8 @@ namespace raylet {

// A helper function to print the leased workers.
std::string LeasedWorkersSring(
const std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers) {
const std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>>
&leased_workers) {
std::stringstream buffer;
buffer << " @leased_workers: (";
for (const auto &pair : leased_workers) {
@@ -117,7 +118,8 @@ std::string LeasedWorkersSring(
}

// A helper function to print the workers in worker_pool_.
std::string WorkerPoolString(const std::vector<std::shared_ptr<Worker>> &worker_pool) {
std::string WorkerPoolString(
const std::vector<std::shared_ptr<WorkerInterface>> &worker_pool) {
std::stringstream buffer;
buffer << " @worker_pool: (";
for (const auto &worker : worker_pool) {
@@ -128,7 +130,7 @@ std::string WorkerPoolString(const std::vector<std::shared_ptr<Worker>> &worker_
}

// Helper function to print the worker's owner worker and and node owner.
std::string WorkerOwnerString(std::shared_ptr<Worker> &worker) {
std::string WorkerOwnerString(std::shared_ptr<WorkerInterface> &worker) {
std::stringstream buffer;
const auto owner_worker_id =
WorkerID::FromBinary(worker->GetOwnerAddress().worker_id());
@@ -320,7 +322,7 @@ ray::Status NodeManager::RegisterGcs() {
return ray::Status::OK();
}

void NodeManager::KillWorker(std::shared_ptr<Worker> worker) {
void NodeManager::KillWorker(std::shared_ptr<WorkerInterface> worker) {
#ifdef _WIN32
// TODO(mehrdadn): implement graceful process termination mechanism
#else
@@ -1072,7 +1074,7 @@ void NodeManager::DispatchTasks(

// Try to get an idle worker to execute this task. If nullptr, there
// aren't any available workers so we can't assign the task.
std::shared_ptr<Worker> worker =
std::shared_ptr<WorkerInterface> worker =
worker_pool_.PopWorker(task.GetTaskSpecification());
if (worker != nullptr) {
AssignTask(worker, task, &post_assign_callbacks);
@@ -1145,11 +1147,11 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
ProcessFetchOrReconstructMessage(client, message_data);
} break;
case protocol::MessageType::NotifyDirectCallTaskBlocked: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
HandleDirectCallTaskBlocked(worker);
} break;
case protocol::MessageType::NotifyDirectCallTaskUnblocked: {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
HandleDirectCallTaskUnblocked(worker);
} break;
case protocol::MessageType::NotifyUnblocked: {
@@ -1214,8 +1216,8 @@ void NodeManager::ProcessRegisterClientRequestMessage(
WorkerID worker_id = from_flatbuf<WorkerID>(*message->worker_id());
pid_t pid = message->worker_pid();
std::string worker_ip_address = string_from_flatbuf(*message->ip_address());
auto worker = std::make_shared<Worker>(worker_id, language, worker_ip_address, client,
client_call_manager_);
auto worker = std::dynamic_pointer_cast<WorkerInterface>(std::make_shared<Worker>(
worker_id, language, worker_ip_address, client, client_call_manager_));

int assigned_port;
if (message->is_worker()) {
@@ -1269,7 +1271,7 @@ void NodeManager::ProcessRegisterClientRequestMessage(
void NodeManager::ProcessAnnounceWorkerPortMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
bool is_worker = true;
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (worker == nullptr) {
is_worker = false;
worker = worker_pool_.GetRegisteredDriver(client);
@@ -1345,11 +1347,11 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca
}

void NodeManager::HandleWorkerAvailable(const std::shared_ptr<ClientConnection> &client) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
HandleWorkerAvailable(worker);
}

void NodeManager::HandleWorkerAvailable(const std::shared_ptr<Worker> &worker) {
void NodeManager::HandleWorkerAvailable(const std::shared_ptr<WorkerInterface> &worker) {
RAY_CHECK(worker);
bool worker_idle = true;

@@ -1376,7 +1378,7 @@ void NodeManager::HandleWorkerAvailable(const std::shared_ptr<Worker> &worker) {

void NodeManager::ProcessDisconnectClientMessage(
const std::shared_ptr<ClientConnection> &client, bool intentional_disconnect) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
bool is_worker = false, is_driver = false;
if (worker) {
// The client is a worker.
@@ -1617,7 +1619,8 @@ void NodeManager::ProcessWaitForDirectActorCallArgsRequestMessage(
object_ids, -1, object_ids.size(), false,
[this, client, tag](std::vector<ObjectID> found, std::vector<ObjectID> remaining) {
RAY_CHECK(remaining.empty());
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker =
worker_pool_.GetRegisteredWorker(client);
if (!worker) {
RAY_LOG(ERROR) << "Lost worker for wait request " << client;
} else {
@@ -1647,7 +1650,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
const auto &actor_entry = actor_registry_.find(actor_id);
RAY_CHECK(actor_entry != actor_registry_.end());

std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
RAY_CHECK(worker && worker->GetActorId() == actor_id);

std::shared_ptr<ActorCheckpointData> checkpoint_data =
@@ -1822,7 +1825,7 @@ void NodeManager::HandleReturnWorker(const rpc::ReturnWorkerRequest &request,
rpc::SendReplyCallback send_reply_callback) {
// Read the resource spec submitted by the client.
auto worker_id = WorkerID::FromBinary(request.worker_id());
std::shared_ptr<Worker> worker = leased_workers_[worker_id];
std::shared_ptr<WorkerInterface> worker = leased_workers_[worker_id];

Status status;
leased_workers_.erase(worker_id);
@@ -2320,7 +2323,8 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
}
}

void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker) {
void NodeManager::HandleDirectCallTaskBlocked(
const std::shared_ptr<WorkerInterface> &worker) {
if (new_scheduler_enabled_) {
if (!worker) {
return;
@@ -2349,7 +2353,8 @@ void NodeManager::HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &wor
DispatchTasks(local_queues_.GetReadyTasksByClass());
}

void NodeManager::HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker) {
void NodeManager::HandleDirectCallTaskUnblocked(
const std::shared_ptr<WorkerInterface> &worker) {
if (new_scheduler_enabled_) {
if (!worker) {
return;
@@ -2406,7 +2411,7 @@ void NodeManager::AsyncResolveObjects(
const std::shared_ptr<ClientConnection> &client,
const std::vector<rpc::ObjectReference> &required_object_refs,
const TaskID &current_task_id, bool ray_get, bool mark_worker_blocked) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);
if (worker) {
// The client is a worker. If the worker is not already blocked and the
// blocked task matches the one assigned to the worker, then mark the
@@ -2460,7 +2465,7 @@ void NodeManager::AsyncResolveObjectsFinish(
void NodeManager::AsyncResolveObjectsFinish(
const std::shared_ptr<ClientConnection> &client, const TaskID &current_task_id,
bool was_blocked) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> worker = worker_pool_.GetRegisteredWorker(client);

// TODO(swang): Because the object dependencies are tracked in the task
// dependency manager, we could actually remove this message entirely and
@@ -2540,7 +2545,8 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) {
task_dependency_manager_.TaskPending(task);
}

void NodeManager::AssignTask(const std::shared_ptr<Worker> &worker, const Task &task,
void NodeManager::AssignTask(const std::shared_ptr<WorkerInterface> &worker,
const Task &task,
std::vector<std::function<void()>> *post_assign_callbacks) {
const TaskSpecification &spec = task.GetTaskSpecification();
RAY_CHECK(post_assign_callbacks);
@@ -2626,7 +2632,7 @@ void NodeManager::AssignTask(const std::shared_ptr<Worker> &worker, const Task &
}
}

bool NodeManager::FinishAssignedTask(Worker &worker) {
bool NodeManager::FinishAssignedTask(WorkerInterface &worker) {
TaskID task_id = worker.GetAssignedTaskId();
RAY_LOG(DEBUG) << "Finished task " << task_id;

@@ -2735,7 +2741,7 @@ std::shared_ptr<ActorTableData> NodeManager::CreateActorTableDataFromCreationTas
return actor_info_ptr;
}

void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
void NodeManager::FinishAssignedActorTask(WorkerInterface &worker, const Task &task) {
RAY_LOG(DEBUG) << "Finishing assigned actor task";
ActorID actor_id;
TaskID caller_id;
@@ -3303,7 +3309,7 @@ void NodeManager::ForwardTask(
});
}

void NodeManager::FinishAssignTask(const std::shared_ptr<Worker> &worker,
void NodeManager::FinishAssignTask(const std::shared_ptr<WorkerInterface> &worker,
const TaskID &task_id, bool success) {
RAY_LOG(DEBUG) << "FinishAssignTask: " << task_id;
// Remove the ASSIGNED task from the READY queue.
@@ -3348,7 +3354,8 @@ void NodeManager::FinishAssignTask(const std::shared_ptr<Worker> &worker,

void NodeManager::ProcessSubscribePlasmaReady(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
std::shared_ptr<Worker> associated_worker = worker_pool_.GetRegisteredWorker(client);
std::shared_ptr<WorkerInterface> associated_worker =
worker_pool_.GetRegisteredWorker(client);
if (associated_worker == nullptr) {
associated_worker = worker_pool_.GetRegisteredDriver(client);
}
@@ -3361,7 +3368,7 @@ void NodeManager::ProcessSubscribePlasmaReady(
absl::MutexLock guard(&plasma_object_notification_lock_);
if (!async_plasma_objects_notification_.contains(id)) {
async_plasma_objects_notification_.emplace(
id, absl::flat_hash_set<std::shared_ptr<Worker>>());
id, absl::flat_hash_set<std::shared_ptr<WorkerInterface>>());
}

// Only insert a worker once
@@ -3375,7 +3382,7 @@ ray::Status NodeManager::SetupPlasmaSubscription() {
return object_manager_.SubscribeObjAdded(
[this](const object_manager::protocol::ObjectInfoT &object_info) {
ObjectID object_id = ObjectID::FromBinary(object_info.object_id);
auto waiting_workers = absl::flat_hash_set<std::shared_ptr<Worker>>();
auto waiting_workers = absl::flat_hash_set<std::shared_ptr<WorkerInterface>>();
{
absl::MutexLock guard(&plasma_object_notification_lock_);
auto waiting = this->async_plasma_objects_notification_.extract(object_id);
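Every change in node_manager.cc above follows the same pattern: signatures and containers that named the concrete `Worker` class now name `WorkerInterface`, and the one place that constructs a `Worker` (ProcessRegisterClientRequestMessage) upcasts the result. The real interface is declared in src/ray/raylet/worker.h, which is not part of this diff, so the sketch below is an assumption that keeps only accessors visible in these hunks (`IpAddress`, `GetAssignedTaskId`) with simplified return types:

```cpp
#include <memory>
#include <string>

// Sketch only: the real WorkerInterface (src/ray/raylet/worker.h) declares many
// more members; the std::string return types here are simplifications.
class WorkerInterface {
 public:
  virtual ~WorkerInterface() = default;
  virtual std::string IpAddress() const = 0;
  virtual std::string GetAssignedTaskId() const = 0;
};

// The concrete Worker implements the interface, so existing call sites keep
// working once their parameters are retyped to WorkerInterface.
class Worker : public WorkerInterface {
 public:
  std::string IpAddress() const override { return "127.0.0.1"; }
  std::string GetAssignedTaskId() const override { return "task-0"; }
};

int main() {
  // Mirrors ProcessRegisterClientRequestMessage: construct the concrete type,
  // then hand it around as a shared_ptr<WorkerInterface>. A plain implicit
  // upcast would also compile; the PR spells it with std::dynamic_pointer_cast.
  std::shared_ptr<WorkerInterface> worker =
      std::dynamic_pointer_cast<WorkerInterface>(std::make_shared<Worker>());
  return worker != nullptr ? 0 : 1;
}
```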
22 changes: 11 additions & 11 deletions src/ray/raylet/node_manager.h
@@ -256,15 +256,15 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param[in] task The task in question.
/// \param[out] post_assign_callbacks Vector of callbacks that will be appended
/// to with any logic that should run after the DispatchTasks loop runs.
void AssignTask(const std::shared_ptr<Worker> &worker, const Task &task,
void AssignTask(const std::shared_ptr<WorkerInterface> &worker, const Task &task,
std::vector<std::function<void()>> *post_assign_callbacks);
/// Handle a worker finishing its assigned task.
///
/// \param worker The worker that finished the task.
/// \return Whether the worker should be returned to the idle pool. This is
/// only false for direct actor creation calls, which should never be
/// returned to idle.
bool FinishAssignedTask(Worker &worker);
bool FinishAssignedTask(WorkerInterface &worker);
/// Helper function to produce actor table data for a newly created actor.
///
/// \param task_spec Task specification of the actor creation task that created the
@@ -276,7 +276,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param worker The worker that finished the task.
/// \param task The actor task or actor creation task.
/// \return Void.
void FinishAssignedActorTask(Worker &worker, const Task &task);
void FinishAssignedActorTask(WorkerInterface &worker, const Task &task);
/// Helper function for handling worker to finish its assigned actor task
/// or actor creation task. Gets invoked when tasks's parent actor is known.
///
@@ -395,20 +395,20 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// arrive after the worker lease has been returned to the node manager.
///
/// \param worker Shared ptr to the worker, or nullptr if lost.
void HandleDirectCallTaskBlocked(const std::shared_ptr<Worker> &worker);
void HandleDirectCallTaskBlocked(const std::shared_ptr<WorkerInterface> &worker);

/// Handle a direct call task that is unblocked. Note that this callback may
/// arrive after the worker lease has been returned to the node manager.
/// However, it is guaranteed to arrive after DirectCallTaskBlocked.
///
/// \param worker Shared ptr to the worker, or nullptr if lost.
void HandleDirectCallTaskUnblocked(const std::shared_ptr<Worker> &worker);
void HandleDirectCallTaskUnblocked(const std::shared_ptr<WorkerInterface> &worker);

/// Kill a worker.
///
/// \param worker The worker to kill.
/// \return Void.
void KillWorker(std::shared_ptr<Worker> worker);
void KillWorker(std::shared_ptr<WorkerInterface> worker);

/// The callback for handling an actor state transition (e.g., from ALIVE to
/// DEAD), whether as a notification from the actor table or as a handler for
@@ -495,7 +495,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
///
/// \param worker The pointer to the worker
/// \return Void.
void HandleWorkerAvailable(const std::shared_ptr<Worker> &worker);
void HandleWorkerAvailable(const std::shared_ptr<WorkerInterface> &worker);

/// Handle a client that has disconnected. This can be called multiple times
/// on the same client because this is triggered both when a client
@@ -582,8 +582,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \param task_id Id of the task.
/// \param success Whether or not assigning the task was successful.
/// \return void.
void FinishAssignTask(const std::shared_ptr<Worker> &worker, const TaskID &task_id,
bool success);
void FinishAssignTask(const std::shared_ptr<WorkerInterface> &worker,
const TaskID &task_id, bool success);

/// Process worker subscribing to plasma.
///
@@ -762,7 +762,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
remote_node_manager_clients_;

/// Map of workers leased out to direct call clients.
std::unordered_map<WorkerID, std::shared_ptr<Worker>> leased_workers_;
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> leased_workers_;

/// Map from owner worker ID to a list of worker IDs that the owner has a
/// lease on.
@@ -805,7 +805,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
mutable absl::Mutex plasma_object_notification_lock_;

/// Keeps track of workers waiting for objects
absl::flat_hash_map<ObjectID, absl::flat_hash_set<std::shared_ptr<Worker>>>
absl::flat_hash_map<ObjectID, absl::flat_hash_set<std::shared_ptr<WorkerInterface>>>
async_plasma_objects_notification_ GUARDED_BY(plasma_object_notification_lock_);

/// Objects that are out of scope in the application and that should be freed
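In the header, the same substitution reaches the member containers (`leased_workers_`, `async_plasma_objects_notification_`). A likely motivation, though this diff does not state it, is that interface-typed members can hold test doubles; the `FakeWorker` name below is hypothetical and only illustrates that a stub satisfying `WorkerInterface` can be stored in a map shaped like `leased_workers_`:

```cpp
#include <memory>
#include <string>
#include <unordered_map>

// Stand-in: the real key type is ray's WorkerID, not a string.
using WorkerID = std::string;

class WorkerInterface {
 public:
  virtual ~WorkerInterface() = default;
  virtual std::string IpAddress() const = 0;
};

// Hypothetical test double -- not part of this PR.
class FakeWorker : public WorkerInterface {
 public:
  std::string IpAddress() const override { return "10.0.0.1"; }
};

int main() {
  // Same shape as NodeManager::leased_workers_ after this change.
  std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> leased_workers;
  leased_workers["worker-1"] = std::make_shared<FakeWorker>();
  return leased_workers.at("worker-1")->IpAddress() == "10.0.0.1" ? 0 : 1;
}
```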
10 changes: 5 additions & 5 deletions src/ray/raylet/scheduling/cluster_task_manager.cc
@@ -81,8 +81,8 @@ bool ClusterTaskManager::WaitForTaskArgsRequests(Work work) {
}

void ClusterTaskManager::DispatchScheduledTasksToWorkers(
WorkerPool &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers) {
WorkerPoolInterface &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers) {
// Check every task in task_to_dispatch queue to see
// whether it can be dispatched and ran. This avoids head-of-line
// blocking where a task which cannot be dispatched because
@@ -94,7 +94,7 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers(
auto spec = task.GetTaskSpecification();
tasks_to_dispatch_.pop_front();

std::shared_ptr<Worker> worker = worker_pool.PopWorker(spec);
std::shared_ptr<WorkerInterface> worker = worker_pool.PopWorker(spec);
if (!worker) {
// No worker available to schedule this task.
// Put the task back in the dispatch queue.
@@ -148,8 +148,8 @@ void ClusterTaskManager::TasksUnblocked(const std::vector<TaskID> ready_ids) {
}

void ClusterTaskManager::Dispatch(
std::shared_ptr<Worker> worker,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers_,
std::shared_ptr<WorkerInterface> worker,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers_,
const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback) {
reply->mutable_worker_address()->set_ip_address(worker->IpAddress());
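cluster_task_manager.cc makes the parallel change on the pool side: `DispatchScheduledTasksToWorkers` now takes a `WorkerPoolInterface &` rather than the concrete `WorkerPool`. Only `PopWorker` is visible in this diff, so the interface and the `FakeWorkerPool` below are assumptions meant to show how a test could drive the dispatch loop without real worker processes:

```cpp
#include <deque>
#include <memory>

class WorkerInterface {
 public:
  virtual ~WorkerInterface() = default;
};
class Worker : public WorkerInterface {};

struct TaskSpecification {};  // stand-in for ray::TaskSpecification

// Assumed shape: the real WorkerPoolInterface presumably declares more than
// the single PopWorker call used in this hunk.
class WorkerPoolInterface {
 public:
  virtual ~WorkerPoolInterface() = default;
  virtual std::shared_ptr<WorkerInterface> PopWorker(
      const TaskSpecification &spec) = 0;
};

// Hypothetical in-memory pool a unit test could pass to
// ClusterTaskManager::DispatchScheduledTasksToWorkers.
class FakeWorkerPool : public WorkerPoolInterface {
 public:
  std::shared_ptr<WorkerInterface> PopWorker(const TaskSpecification &) override {
    if (idle_.empty()) {
      return nullptr;  // same "no idle worker" contract the dispatch loop checks
    }
    auto worker = idle_.front();
    idle_.pop_front();
    return worker;
  }

 private:
  std::deque<std::shared_ptr<WorkerInterface>> idle_{std::make_shared<Worker>()};
};

int main() {
  FakeWorkerPool pool;
  TaskSpecification spec;
  auto first = pool.PopWorker(spec);   // pops the single idle worker
  auto second = pool.PopWorker(spec);  // nullptr: pool exhausted
  return (first != nullptr && second == nullptr) ? 0 : 1;
}
```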
15 changes: 8 additions & 7 deletions src/ray/raylet/scheduling/cluster_task_manager.h
@@ -61,14 +61,14 @@ class ClusterTaskManager {
/// `worker_pool` state will be modified (idle workers will be popped) during
/// dispatching.
void DispatchScheduledTasksToWorkers(
WorkerPool &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers);
WorkerPoolInterface &worker_pool,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers);

/// (Step 1) Queue tasks for scheduling.
/// \param fn: The function used during dispatching.
/// \param task: The incoming task to schedule.
void QueueTask(const Task &task, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback);
rpc::SendReplyCallback);

/// Move tasks from waiting to ready for dispatch. Called when a task's
/// dependencies are resolved.
@@ -96,10 +96,11 @@ class ClusterTaskManager {
/// \return True if the work can be immediately dispatched.
bool WaitForTaskArgsRequests(Work work);

void Dispatch(std::shared_ptr<Worker> worker,
std::unordered_map<WorkerID, std::shared_ptr<Worker>> &leased_workers_,
const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback);
void Dispatch(
std::shared_ptr<WorkerInterface> worker,
std::unordered_map<WorkerID, std::shared_ptr<WorkerInterface>> &leased_workers_,
const TaskSpecification &task_spec, rpc::RequestWorkerLeaseReply *reply,
rpc::SendReplyCallback send_reply_callback);

void Spillback(ClientID spillback_to, std::string address, int port,
rpc::RequestWorkerLeaseReply *reply,