Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove task context from python worker #5987

Merged
merged 14 commits into from
Oct 25, 2019
Prev Previous commit
Next Next commit
Use core worker job ID
  • Loading branch information
edoakes committed Oct 23, 2019
commit d656bae5b6eccae2dfd1c0c4f5ac5c1561a8f8e7
29 changes: 0 additions & 29 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -458,28 +458,6 @@ cdef deserialize_args(

return ray.signature.recover_args(args)

# TODO(edoakes): do these checks in core worker!!
"""
cdef _check_worker_state(worker, CTaskType task_type, JobID job_id):
assert worker.current_task_id.is_nil()
assert worker.task_context.task_index == 0
assert worker.task_context.put_index == 1

# If this worker is not an actor, check that `current_job_id`
# was reset when the worker finished the previous task.
if <int>task_type in [<int>TASK_TYPE_NORMAL_TASK,
<int>TASK_TYPE_ACTOR_CREATION_TASK]:
assert worker.current_job_id.is_nil()
# Set the driver ID of the current running task. This is
# needed so that if the task throws an exception, we propagate
# the error message to the correct driver.
else:
# If this worker is an actor, current_job_id wasn't reset.
# Check that current task's driver ID equals the previous
# one.
assert worker.current_job_id == job_id
"""


cdef _store_task_outputs(worker, return_ids, outputs):
for i in range(len(return_ids)):
Expand Down Expand Up @@ -513,7 +491,6 @@ cdef execute_task(
actor_id = ActorID(c_actor_id.Binary())
job_id = JobID(c_job_id.Binary())
task_id = worker.core_worker.get_current_task_id()
worker.current_job_id = job_id

# Automatically restrict the GPUs available to this task.
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids())
Expand Down Expand Up @@ -610,16 +587,10 @@ cdef execute_task(
# Send signal with the error.
ray_signal.send(ray_signal.ErrorSignal(str(failure_object)))

# Reset the state fields so the next task can run.
worker.core_worker.set_current_task_id(TaskID.nil())

# Don't need to reset `current_job_id` if the worker is an
# actor. Because the following tasks should all have the
# same driver id.
if <int>task_type == <int>TASK_TYPE_NORMAL_TASK:
worker.current_job_id = JobID.nil()
worker.core_worker.set_current_job_id(JobID.nil())

# Reset signal counters so that the next task can get
# all past signals.
ray_signal.reset()
Expand Down
15 changes: 7 additions & 8 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,6 @@ def __init__(self):
# TODO: clean up the SerializationContext once the job finished.
self.serialization_context_map = {}
self.function_actor_manager = FunctionActorManager(self)
# Identity of the job that this worker is processing.
# It is a JobID.
self.current_job_id = JobID.nil()
# This event is checked regularly by all of the threads so that they
# know when to exit.
self.threads_stopped = threading.Event()
Expand Down Expand Up @@ -172,6 +169,12 @@ def use_pickle(self):
self.check_connected()
return self.node.use_pickle

@property
def current_job_id(self):
    """ID of the job this worker is currently processing.

    The value is owned by the core worker; before the core worker is
    attached to this worker, the nil job ID is returned.
    """
    if not hasattr(self, "core_worker"):
        return JobID.nil()
    return self.core_worker.get_current_job_id()

@property
def current_task_id(self):
    """ID of the task this worker is currently executing, as tracked by the core worker."""
    core_worker = self.core_worker
    return core_worker.get_current_task_id()
Expand Down Expand Up @@ -1371,7 +1374,6 @@ def connect(node,

if not isinstance(job_id, JobID):
raise TypeError("The type of given job id must be JobID.")
worker.current_job_id = job_id

# All workers start out as non-actors. A worker can be turned into an actor
# after it is created.
Expand Down Expand Up @@ -1470,7 +1472,7 @@ def connect(node,
(mode == SCRIPT_MODE),
node.plasma_store_socket_name,
node.raylet_socket_name,
worker.current_job_id,
job_id,
gcs_options,
node.get_logs_dir_path(),
node.node_ip_address,
Expand Down Expand Up @@ -1578,9 +1580,6 @@ def disconnect(exiting_interpreter=False):
worker.serialization_context_map.clear()

if not exiting_interpreter:
if hasattr(worker, "raylet_client"):
del worker.raylet_client

if hasattr(worker, "core_worker"):
del worker.core_worker

Expand Down
34 changes: 25 additions & 9 deletions src/ray/core_worker/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@ struct WorkerThreadContext {
return current_task_;
}

void SetCurrentTaskId(const TaskID &task_id) {
current_task_id_ = task_id;
task_index_ = 0;
put_index_ = 0;
}
/// Record the ID of the task this thread is about to execute.
void SetCurrentTaskId(const TaskID &task_id) {
  current_task_id_ = task_id;
}

/// Install \p task_spec as the task this thread is executing, keeping
/// both the cached spec and the current task ID in sync.
void SetCurrentTask(const TaskSpecification &task_spec) {
  current_task_ = std::make_shared<const TaskSpecification>(task_spec);
  SetCurrentTaskId(task_spec.TaskId());
}

/// Clear per-thread task state so the thread is ready for the next task.
///
/// \param task_spec Spec of the task that just finished. Currently unused
///        here; kept so the signature mirrors SetCurrentTask (callers pass
///        the finished task's spec).
void ResetCurrentTask(const TaskSpecification &task_spec) {
  // Explicitly discard the parameter to avoid -Wunused-parameter warnings
  // while preserving the symmetric interface.
  (void)task_spec;
  SetCurrentTaskId(TaskID::Nil());
  // Reset the per-task counters used to deterministically generate
  // object IDs for the next task.
  task_index_ = 0;
  put_index_ = 0;
}

private:
/// The task ID for current task.
TaskID current_task_id_;
Expand Down Expand Up @@ -83,17 +85,31 @@ void WorkerContext::SetCurrentTaskId(const TaskID &task_id) {
}

void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
SetCurrentJobId(task_spec.JobId());
GetThreadContext().SetCurrentTask(task_spec);
if (task_spec.IsActorCreationTask()) {
if (task_spec.IsNormalTask()) {
RAY_CHECK(current_job_id_.IsNil());
SetCurrentJobId(task_spec.JobId());
} else if (task_spec.IsActorCreationTask()) {
RAY_CHECK(current_job_id_.IsNil());
SetCurrentJobId(task_spec.JobId());
RAY_CHECK(current_actor_id_.IsNil());
current_actor_id_ = task_spec.ActorCreationId();
current_actor_use_direct_call_ = task_spec.IsDirectCall();
}
if (task_spec.IsActorTask()) {
} else if (task_spec.IsActorTask()) {
RAY_CHECK(current_job_id_ == task_spec.JobId());
RAY_CHECK(current_actor_id_ == task_spec.ActorId());
} else {
RAY_CHECK(false);
}
}

// Reset the worker-level and thread-level state for the task that just
// finished executing.
void WorkerContext::ResetCurrentTask(const TaskSpecification &task_spec) {
// Clear the thread-local task ID and per-task counters first.
GetThreadContext().ResetCurrentTask(task_spec);
// A normal (non-actor) worker can be reused by a different job, so its
// job ID is cleared between tasks; actor workers keep the same job ID
// for their whole lifetime, so it is left untouched here.
if (task_spec.IsNormalTask()) {
SetCurrentJobId(JobID::Nil());
}
}

// Return the spec of the task currently executing on this thread; the
// per-thread context owns the cached TaskSpecification.
std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
return GetThreadContext().GetCurrentTask();
}
Expand Down
2 changes: 2 additions & 0 deletions src/ray/core_worker/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class WorkerContext {

void SetCurrentTask(const TaskSpecification &task_spec);

void ResetCurrentTask(const TaskSpecification &task_spec);

std::shared_ptr<const TaskSpecification> GetCurrentTask() const;

const ActorID &GetCurrentActorID() const;
Expand Down
2 changes: 2 additions & 0 deletions src/ray/core_worker/task_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ Status CoreWorkerTaskExecutionInterface::ExecuteTask(
task_spec.GetRequiredResources().GetResourceMap(),
args, arg_reference_ids, return_ids, results);

worker_context_.ResetCurrentTask(task_spec);

// TODO(zhijunfu):
// 1. Check and handle failure.
// 2. Save or load checkpoint.
Expand Down