Skip to content

Commit

Permalink
Properly import remote functions and reusable variables on workers th…
Browse files Browse the repository at this point in the history
…at register late (ray-project#290)
  • Loading branch information
robertnishihara authored and pcmoritz committed Jul 25, 2016
1 parent 5591aa4 commit 8e9f98c
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 76 deletions.
28 changes: 12 additions & 16 deletions lib/python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,31 +686,27 @@ def process_task(task): # wrapping these lines in a function should cause the lo
# above so that changes made to their state do not affect other tasks.
ray.reusables._reinitialize()
while True:
(task, function, reusable_variable) = ray.lib.wait_for_next_message(worker.handle)
command, command_args = ray.lib.wait_for_next_message(worker.handle)
try:
# Only one of task, function, and reusable_variable should be not None.
assert sum([obj is not None for obj in [task, function, reusable_variable]]) <= 1
if task is None and function is None and reusable_variable is None:
# We use this as a mechanism to allow the scheduler to kill workers. When
# the scheduler wants to kill a worker, it gives the worker a null task,
# causing the worker program to exit the main loop here.
if command == "die":
# We use this as a mechanism to allow the scheduler to kill workers.
break
if function is not None:
(function, arg_types, return_types) = pickling.loads(function)
elif command == "function":
(function, arg_types, return_types) = pickling.loads(command_args)
if function.__module__ is None: function.__module__ = "__main__"
worker.register_function(remote(arg_types, return_types, worker)(function))
if reusable_variable is not None:
name, initializer_str, reinitializer_str = reusable_variable
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
initializer = pickling.loads(initializer_str)
reinitializer = pickling.loads(reinitializer_str)
reusables.__setattr__(name, Reusable(initializer, reinitializer))
if task is not None:
process_task(task)
elif command == "task":
process_task(command_args)
else:
assert False, "This code should be unreachable."
finally:
# Allow releasing the variables BEFORE we wait for the next message or exit the block
del task
del function
del reusable_variable
del command_args

def _submit_task(func_name, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.
Expand Down
18 changes: 12 additions & 6 deletions protos/ray.proto
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,7 @@ message ExportFunctionReply {
}

message ExportReusableVariableRequest {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
ReusableVar reusable_variable = 1; // The reusable variable to export.
}

// These messages are for getting information about the object store state
Expand Down Expand Up @@ -280,13 +278,21 @@ message ImportFunctionReply {
}

message ImportReusableVariableRequest {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
ReusableVar reusable_variable = 1; // The reusable variable to export.
}

message DieRequest {
}

message DieReply {
}

// This message is used by the worker service to send messages to the worker
// that are processed by the worker's main loop.
message WorkerMessage {
oneof worker_item {
Task task = 1; // A task for the worker to execute.
Function function = 2; // A remote function to import on the worker.
ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
}
}
6 changes: 6 additions & 0 deletions protos/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ message Function {
bytes implementation = 1;
}

message ReusableVar {
string name = 1; // The name of the reusable variable.
Function initializer = 2; // A serialized version of the function that initializes the reusable variable.
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
}

// Union of possible object types
message Obj {
String string_data = 1;
Expand Down
51 changes: 30 additions & 21 deletions src/raylib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,13 +607,13 @@ static PyObject* serialize_task(PyObject* self, PyObject* args) {
return PyCapsule_New(static_cast<void*>(task), "task", &TaskCapsule_Destructor);
}

static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
static PyObject* deserialize_task(PyObject* worker_capsule, const Task& task) {
std::vector<ObjRef> objrefs; // This is a vector of all the objrefs that were serialized in this task, including objrefs that are contained in Python objects that are passed by value.
PyObject* string = PyString_FromStringAndSize(task->name().c_str(), task->name().size());
int argsize = task->arg_size();
PyObject* string = PyString_FromStringAndSize(task.name().c_str(), task.name().size());
int argsize = task.arg_size();
PyObject* arglist = PyList_New(argsize);
for (int i = 0; i < argsize; ++i) {
const Value& val = task->arg(i);
const Value& val = task.arg(i);
if (!val.has_obj()) {
PyList_SetItem(arglist, i, make_pyobjref(worker_capsule, val.ref()));
objrefs.push_back(val.ref());
Expand All @@ -624,12 +624,12 @@ static PyObject* deserialize_task(PyObject* worker_capsule, Task* task) {
Worker* worker;
PyObjectToWorker(worker_capsule, &worker);
worker->decrement_reference_count(objrefs);
int resultsize = task->result_size();
int resultsize = task.result_size();
std::vector<ObjRef> result_objrefs;
PyObject* resultlist = PyList_New(resultsize);
for (int i = 0; i < resultsize; ++i) {
PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task->result(i)));
result_objrefs.push_back(task->result(i));
PyList_SetItem(resultlist, i, make_pyobjref(worker_capsule, task.result(i)));
result_objrefs.push_back(task.result(i));
}
worker->decrement_reference_count(result_objrefs); // The corresponding increment is done in SubmitTask in the scheduler.
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
Expand Down Expand Up @@ -685,23 +685,32 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
Worker* worker;
PyObjectToWorker(worker_capsule, &worker);
if (std::unique_ptr<WorkerMessage> message = worker->receive_next_message()) {
PyObject* variable_info;
if (!message->reusable_variable.variable_name.empty()) {
variable_info = PyTuple_New(3);
PyTuple_SetItem(variable_info, 0, PyString_FromStringAndSize(message->reusable_variable.variable_name.data(), static_cast<ssize_t>(message->reusable_variable.variable_name.size())));
PyTuple_SetItem(variable_info, 1, PyString_FromStringAndSize(message->reusable_variable.initializer.data(), static_cast<ssize_t>(message->reusable_variable.initializer.size())));
PyTuple_SetItem(variable_info, 2, PyString_FromStringAndSize(message->reusable_variable.reinitializer.data(), static_cast<ssize_t>(message->reusable_variable.reinitializer.size())));
bool task_present = !message->task().name().empty();
bool function_present = !message->function().implementation().empty();
bool reusable_variable_present = !message->reusable_variable().name().empty();
RAY_CHECK(task_present + function_present + reusable_variable_present <= 1, "The worker message should contain at most one item.");
PyObject* t = PyTuple_New(2);
if (task_present) {
PyTuple_SetItem(t, 0, PyString_FromString("task"));
PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task()));
} else if (function_present) {
PyTuple_SetItem(t, 0, PyString_FromString("function"));
PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
} else if (reusable_variable_present) {
PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable"));
PyObject* reusable_variable = PyTuple_New(3);
PyTuple_SetItem(reusable_variable, 0, PyString_FromStringAndSize(message->reusable_variable().name().data(), static_cast<ssize_t>(message->reusable_variable().name().size())));
PyTuple_SetItem(reusable_variable, 1, PyString_FromStringAndSize(message->reusable_variable().initializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().initializer().implementation().size())));
PyTuple_SetItem(reusable_variable, 2, PyString_FromStringAndSize(message->reusable_variable().reinitializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().reinitializer().implementation().size())));
PyTuple_SetItem(t, 1, reusable_variable);
} else {
PyTuple_SetItem(t, 0, PyString_FromString("die"));
Py_INCREF(Py_None);
PyTuple_SetItem(t, 1, Py_None);
}
// The tuple constructed below will take ownership of some None objects.
// When the tuple goes out of scope, the reference count for None will be
// decremented. Therefore, we need to increment the reference count for None
// every time we put a None in the tuple.
PyObject* t = PyTuple_New(3); // We set the items of the tuple using PyTuple_SetItem, because that transfers ownership to the tuple.
PyTuple_SetItem(t, 0, message->task.name().empty() ? Py_INCREF(Py_None), Py_None : deserialize_task(worker_capsule, &message->task));
PyTuple_SetItem(t, 1, message->function.empty() ? Py_INCREF(Py_None), Py_None : PyString_FromStringAndSize(message->function.data(), static_cast<ssize_t>(message->function.size())));
PyTuple_SetItem(t, 2, message->reusable_variable.variable_name.empty() ? Py_INCREF(Py_None), Py_None : variable_info);
return t;
}
RAY_CHECK(false, "This code should be unreachable.");
Py_RETURN_NONE;
}

Expand Down
70 changes: 57 additions & 13 deletions src/scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN
OperationId operationid = (*workers_.get())[workerid].current_task;
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
{
// Check if the worker has been initialized yet, and if not, then give it
// all of the exported functions and all of the exported reusable variables.
auto workers = workers_.get();
if (!(*workers)[workerid].initialized) {
// This should only happen once.
// Import all remote functions on the worker.
export_all_functions_to_worker(workerid, workers, exported_functions_.get());
// Import all reusable variables on the worker.
export_all_reusable_variables_to_worker(workerid, workers, exported_reusable_variables_.get());
// Mark the worker as initialized.
(*workers)[workerid].initialized = true;
}
}
if (request->has_previous_task_info()) {
RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION.");
std::string task_name;
Expand Down Expand Up @@ -293,29 +307,25 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe

Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
auto workers = workers_.get();
auto exported_functions = exported_functions_.get();
// TODO(rkn): Does this do a deep copy?
exported_functions->push_back(std::unique_ptr<Function>(new Function(request->function())));
for (size_t i = 0; i < workers->size(); ++i) {
ClientContext import_context;
ImportFunctionRequest import_request;
import_request.mutable_function()->set_implementation(request->function().implementation());
if ((*workers)[i].current_task != ROOT_OPERATION) {
ImportFunctionReply import_reply;
(*workers)[i].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
export_function_to_worker(i, exported_functions->size() - 1, workers, exported_functions);
}
}
return Status::OK;
}

Status SchedulerService::ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) {
auto workers = workers_.get();
auto exported_reusable_variables = exported_reusable_variables_.get();
// TODO(rkn): Does this do a deep copy?
exported_reusable_variables->push_back(std::unique_ptr<ReusableVar>(new ReusableVar(request->reusable_variable())));
for (size_t i = 0; i < workers->size(); ++i) {
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.set_name(request->name());
import_request.mutable_initializer()->set_implementation(request->initializer().implementation());
import_request.mutable_reinitializer()->set_implementation(request->reinitializer().implementation());
if ((*workers)[i].current_task != ROOT_OPERATION) {
AckReply import_reply;
(*workers)[i].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
export_reusable_variable_to_worker(i, exported_reusable_variables->size() - 1, workers, exported_reusable_variables);
}
}
return Status::OK;
Expand Down Expand Up @@ -451,6 +461,7 @@ std::pair<WorkerId, ObjStoreId> SchedulerService::register_worker(const std::str
(*workers)[workerid].objstoreid = objstoreid;
(*workers)[workerid].worker_stub = WorkerService::NewStub(channel);
(*workers)[workerid].worker_address = worker_address;
(*workers)[workerid].initialized = false;
if (is_driver) {
(*workers)[workerid].current_task = ROOT_OPERATION; // We use this field to identify which workers are drivers.
} else {
Expand Down Expand Up @@ -830,6 +841,37 @@ void SchedulerService::get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef>
upstream_objrefs(downstream_objref, equivalent_objrefs, reverse_target_objrefs_.get());
}


void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid);
ClientContext import_context;
ImportFunctionRequest import_request;
import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
ImportFunctionReply import_reply;
(*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
}

void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
RAY_LOG(RAY_INFO, "exporting reusable variable with index " << reusable_variable_index << " to worker " << workerid);
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get());
AckReply import_reply;
(*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply);
}

void SchedulerService::export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
for (int i = 0; i < exported_functions->size(); ++i) {
export_function_to_worker(workerid, i, workers, exported_functions);
}
}

void SchedulerService::export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
for (int i = 0; i < exported_reusable_variables->size(); ++i) {
export_reusable_variable_to_worker(workerid, i, workers, exported_reusable_variables);
}
}

// This method defines the order in which locks should be acquired.
void SchedulerService::do_on_locks(bool lock) {
std::mutex *mutexes[] = {
Expand All @@ -847,7 +889,9 @@ void SchedulerService::do_on_locks(bool lock) {
&objtable_.mutex(),
&objstores_.mutex(),
&target_objrefs_.mutex(),
&reverse_target_objrefs_.mutex()
&reverse_target_objrefs_.mutex(),
&exported_functions_.mutex(),
&exported_reusable_variables_.mutex(),
};
size_t n = sizeof(mutexes) / sizeof(*mutexes);
for (size_t i = 0; i != n; ++i) {
Expand Down
22 changes: 22 additions & 0 deletions src/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ struct WorkerHandle {
std::unique_ptr<WorkerService::Stub> worker_stub; // If null, the worker has died
ObjStoreId objstoreid;
std::string worker_address;
// This field is initialized to false, and it is set to true after all of the
// exported functions and exported reusable variables have been shipped to
// this worker.
bool initialized;
OperationId current_task;
};

Expand Down Expand Up @@ -129,6 +133,20 @@ class SchedulerService : public Scheduler::Service {
void upstream_objrefs(ObjRef objref, std::vector<ObjRef> &objrefs, const SynchronizedPtr<std::vector<std::vector<ObjRef> > > &reverse_target_objrefs);
// Find all of the object references that refer to the same object as objref (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known.
void get_equivalent_objrefs(ObjRef objref, std::vector<ObjRef> &equivalent_objrefs);
// Export a remote function to a worker.
void export_function_to_worker(WorkerId workerid, int function_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export a reusable variable to a worker
void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// Export all reusable variables to a worker. This is used when a new worker
// registers and is protected by the workers lock (which is passed in) to
// ensure that no other reusable variables are exported to the worker while
// this method is being called.
void export_all_functions_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export all remote functions to a worker. This is used when a new worker
// registers and is protected by the workers lock (which is passed in) to
// ensure that no other remote functions are exported to the worker while this
// method is being called.
void export_all_reusable_variables_to_worker(WorkerId workerid, SynchronizedPtr<std::vector<WorkerHandle> > &workers, const SynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// acquires all locks, this should only be used by get_info and for fault tolerance
void acquire_all_locks();
// release all locks, this should only be used by get_info and for fault tolerance
Expand Down Expand Up @@ -187,6 +205,10 @@ class SchedulerService : public Scheduler::Service {
Synchronized<std::vector<RefCount> > reference_counts_;
// contained_objrefs_[objref] is a vector of all of the objrefs contained inside the object referred to by objref
Synchronized<std::vector<std::vector<ObjRef> > > contained_objrefs_;
// All of the remote functions that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<Function> > > exported_functions_;
// All of the reusable variables that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<ReusableVar> > > exported_reusable_variables_;
// the scheduling algorithm that will be used
SchedulingAlgorithmType scheduling_algorithm_;
};
Expand Down
Loading

0 comments on commit 8e9f98c

Please sign in to comment.