Skip to content

Commit

Permalink
[xray] Implements ray.wait (ray-project#2162)
Browse files Browse the repository at this point in the history
Implements ray.wait for xray. Fixes ray-project#1128.
  • Loading branch information
elibol authored Jun 6, 2018
1 parent c8c0349 commit 7246ff8
Show file tree
Hide file tree
Showing 13 changed files with 713 additions and 100 deletions.
42 changes: 28 additions & 14 deletions python/ray/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2529,6 +2529,11 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
correspond to objects that are stored in the object store. The second list
corresponds to the rest of the object IDs (which may or may not be ready).
Ordering of the input list of object IDs is preserved: if A precedes B in
the input list, and both are in the ready list, then A will precede B in
the ready list. This also holds true if A and B are both in the remaining
list.
Args:
object_ids (List[ObjectID]): List of object IDs for objects that may or
may not be ready. Note that these IDs must be unique.
Expand All @@ -2540,9 +2545,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
A list of object IDs that are ready and a list of the remaining object
IDs.
"""
if worker.use_raylet:
print("plasma_client.wait has not been implemented yet")
return

if isinstance(object_ids, ray.ObjectID):
raise TypeError(
Expand Down Expand Up @@ -2574,18 +2576,30 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if len(object_ids) == 0:
return [], []

object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
if len(object_ids) != len(set(object_ids)):
raise Exception("Wait requires a list of unique object IDs.")
if num_returns <= 0:
raise Exception(
"Invalid number of objects to return %d." % num_returns)
if num_returns > len(object_ids):
raise Exception("num_returns cannot be greater than the number "
"of objects provided to ray.wait.")
timeout = timeout if timeout is not None else 2**30
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
if worker.use_raylet:
ready_ids, remaining_ids = worker.local_scheduler_client.wait(
object_ids, num_returns, timeout, False)
else:
object_id_strs = [
plasma.ObjectID(object_id.id()) for object_id in object_ids
]
ready_ids, remaining_ids = worker.plasma_client.wait(
object_id_strs, timeout, num_returns)
ready_ids = [
ray.ObjectID(object_id.binary()) for object_id in ready_ids
]
remaining_ids = [
ray.ObjectID(object_id.binary()) for object_id in remaining_ids
]
return ready_ids, remaining_ids


Expand Down
54 changes: 54 additions & 0 deletions src/local_scheduler/lib/python/local_scheduler_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,58 @@ static PyObject *PyLocalSchedulerClient_set_actor_frontier(PyObject *self,
Py_RETURN_NONE;
}

static PyObject *PyLocalSchedulerClient_wait(PyObject *self, PyObject *args) {
PyObject *py_object_ids;
int num_returns;
int64_t timeout_ms;
PyObject *py_wait_local;

if (!PyArg_ParseTuple(args, "OilO", &py_object_ids, &num_returns, &timeout_ms,
&py_wait_local)) {
return NULL;
}

bool wait_local = PyObject_IsTrue(py_wait_local);

// Convert object ids.
PyObject *iter = PyObject_GetIter(py_object_ids);
if (!iter) {
return NULL;
}
std::vector<ObjectID> object_ids;
while (true) {
PyObject *next = PyIter_Next(iter);
ObjectID object_id;
if (!next) {
break;
}
if (!PyObjectToUniqueID(next, &object_id)) {
// Error parsing object id.
return NULL;
}
object_ids.push_back(object_id);
}

// Invoke wait.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result =
local_scheduler_wait(reinterpret_cast<PyLocalSchedulerClient *>(self)
->local_scheduler_connection,
object_ids, num_returns, timeout_ms,
static_cast<bool>(wait_local));

// Convert result to py object.
PyObject *py_found = PyList_New(static_cast<Py_ssize_t>(result.first.size()));
for (uint i = 0; i < result.first.size(); ++i) {
PyList_SetItem(py_found, i, PyObjectID_make(result.first[i]));
}
PyObject *py_remaining =
PyList_New(static_cast<Py_ssize_t>(result.second.size()));
for (uint i = 0; i < result.second.size(); ++i) {
PyList_SetItem(py_remaining, i, PyObjectID_make(result.second[i]));
}
return Py_BuildValue("(OO)", py_found, py_remaining);
}

static PyMethodDef PyLocalSchedulerClient_methods[] = {
{"disconnect", (PyCFunction) PyLocalSchedulerClient_disconnect, METH_NOARGS,
"Notify the local scheduler that this client is exiting gracefully."},
Expand All @@ -201,6 +253,8 @@ static PyMethodDef PyLocalSchedulerClient_methods[] = {
(PyCFunction) PyLocalSchedulerClient_get_actor_frontier, METH_VARARGS, ""},
{"set_actor_frontier",
(PyCFunction) PyLocalSchedulerClient_set_actor_frontier, METH_VARARGS, ""},
{"wait", (PyCFunction) PyLocalSchedulerClient_wait, METH_VARARGS,
"Wait for a list of objects to be created."},
{NULL} /* Sentinel */
};

Expand Down
39 changes: 39 additions & 0 deletions src/local_scheduler/local_scheduler_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "common_protocol.h"
#include "format/local_scheduler_generated.h"
#include "ray/raylet/format/node_manager_generated.h"

#include "common/io.h"
#include "common/task.h"
Expand Down Expand Up @@ -207,3 +208,41 @@ void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
ray::local_scheduler::protocol::MessageType_SetActorFrontier,
frontier.size(), const_cast<uint8_t *>(frontier.data()));
}

std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
LocalSchedulerConnection *conn,
const std::vector<ObjectID> &object_ids,
int num_returns,
int64_t timeout_milliseconds,
bool wait_local) {
// Write request.
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateWaitRequest(
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds,
wait_local);
fbb.Finish(message);
write_message(conn->conn, ray::protocol::MessageType_WaitRequest,
fbb.GetSize(), fbb.GetBufferPointer());
// Read result.
int64_t type;
int64_t reply_size;
uint8_t *reply;
read_message(conn->conn, &type, &reply_size, &reply);
RAY_CHECK(type == ray::protocol::MessageType_WaitReply);
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply);
// Convert result.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> result;
auto found = reply_message->found();
for (uint i = 0; i < found->size(); i++) {
ObjectID object_id = ObjectID::from_binary(found->Get(i)->str());
result.first.push_back(object_id);
}
auto remaining = reply_message->remaining();
for (uint i = 0; i < remaining->size(); i++) {
ObjectID object_id = ObjectID::from_binary(remaining->Get(i)->str());
result.second.push_back(object_id);
}
/* Free the original message from the local scheduler. */
free(reply);
return result;
}
18 changes: 18 additions & 0 deletions src/local_scheduler/local_scheduler_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,22 @@ const std::vector<uint8_t> local_scheduler_get_actor_frontier(
void local_scheduler_set_actor_frontier(LocalSchedulerConnection *conn,
const std::vector<uint8_t> &frontier);

/// Wait for the given objects until timeout expires or num_return objects are
/// found.
///
/// \param conn The connection information.
/// \param object_ids The objects to wait for.
/// \param num_returns The number of objects to wait for.
/// \param timeout_milliseconds Duration, in milliseconds, to wait before
/// returning.
/// \param wait_local Whether to wait for objects to appear on this node.
/// \return A pair with the first element containing the object ids that were
/// found, and the second element the objects that were not found.
std::pair<std::vector<ObjectID>, std::vector<ObjectID>> local_scheduler_wait(
LocalSchedulerConnection *conn,
const std::vector<ObjectID> &object_ids,
int num_returns,
int64_t timeout_milliseconds,
bool wait_local);

#endif
109 changes: 80 additions & 29 deletions src/ray/object_manager/object_directory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,49 @@ ObjectDirectory::ObjectDirectory(std::shared_ptr<gcs::AsyncGcsClient> &gcs_clien
gcs_client_ = gcs_client;
}

std::vector<ClientID> UpdateObjectLocations(
std::unordered_set<ClientID> &client_ids,
const std::vector<ObjectTableDataT> &location_history) {
// location_history contains the history of locations of the object (it is a log),
// which might look like the following:
// client1.is_eviction = false
// client1.is_eviction = true
// client2.is_eviction = false
// In such a scenario, we want to indicate client2 is the only client that contains
// the object, which the following code achieves.
for (const auto &object_table_data : location_history) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
client_ids.insert(client_id);
} else {
client_ids.erase(client_id);
}
}
return std::vector<ClientID>(client_ids.begin(), client_ids.end());
}

void ObjectDirectory::RegisterBackend() {
auto object_notification_callback = [this](gcs::AsyncGcsClient *client,
const ObjectID &object_id,
const std::vector<ObjectTableDataT> &data) {
auto object_notification_callback = [this](
gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
// Objects are added to this map in SubscribeObjectLocations.
auto entry = listeners_.find(object_id);
auto object_id_listener_pair = listeners_.find(object_id);
// Do nothing for objects we are not listening for.
if (entry == listeners_.end()) {
if (object_id_listener_pair == listeners_.end()) {
return;
}
// Update entries for this object.
auto client_id_set = entry->second.client_ids;
for (auto &object_table_data : data) {
ClientID client_id = ClientID::from_binary(object_table_data.manager);
if (!object_table_data.is_eviction) {
client_id_set.insert(client_id);
} else {
client_id_set.erase(client_id);
std::vector<ClientID> client_id_vec = UpdateObjectLocations(
object_id_listener_pair->second.current_object_locations, location_history);
if (!client_id_vec.empty()) {
// Copy the callbacks so that the callbacks can unsubscribe without interrupting
// looping over the callbacks.
auto callbacks = object_id_listener_pair->second.callbacks;
// Call all callbacks associated with the object id locations we have received.
for (const auto &callback_pair : callbacks) {
callback_pair.second(client_id_vec, object_id);
}
}
if (!client_id_set.empty()) {
// Only call the callback if we have object locations.
std::vector<ClientID> client_id_vec(client_id_set.begin(), client_id_set.end());
auto callback = entry->second.locations_found_callback;
callback(client_id_vec, object_id);
}
};
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
UniqueID::nil(), gcs_client_->client_table().GetLocalClientId(),
Expand Down Expand Up @@ -86,25 +103,59 @@ ray::Status ObjectDirectory::GetInformation(const ClientID &client_id,
return ray::Status::OK();
}

ray::Status ObjectDirectory::SubscribeObjectLocations(const ObjectID &object_id,
ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id,
const OnLocationsFound &callback) {
if (listeners_.find(object_id) != listeners_.end()) {
RAY_LOG(ERROR) << "Duplicate calls to SubscribeObjectLocations for " << object_id;
ray::Status status = ray::Status::OK();
if (listeners_.find(object_id) == listeners_.end()) {
listeners_.emplace(object_id, LocationListenerState());
status = gcs_client_->object_table().RequestNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
}
auto &listener_state = listeners_.find(object_id)->second;
// TODO(hme): Make this fatal after implementing Pull suppression.
if (listener_state.callbacks.count(callback_id) > 0) {
return ray::Status::OK();
}
listeners_.emplace(object_id, LocationListenerState(callback));
return gcs_client_->object_table().RequestNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listener_state.callbacks.emplace(callback_id, callback);
// Immediately notify of found object locations.
if (!listener_state.current_object_locations.empty()) {
std::vector<ClientID> client_id_vec(listener_state.current_object_locations.begin(),
listener_state.current_object_locations.end());
callback(client_id_vec, object_id);
}
return status;
}

ray::Status ObjectDirectory::UnsubscribeObjectLocations(const ObjectID &object_id) {
ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback_id,
const ObjectID &object_id) {
ray::Status status = ray::Status::OK();
auto entry = listeners_.find(object_id);
if (entry == listeners_.end()) {
return ray::Status::OK();
return status;
}
ray::Status status = gcs_client_->object_table().CancelNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listeners_.erase(entry);
entry->second.callbacks.erase(callback_id);
if (entry->second.callbacks.empty()) {
status = gcs_client_->object_table().CancelNotifications(
JobID::nil(), object_id, gcs_client_->client_table().GetLocalClientId());
listeners_.erase(entry);
}
return status;
}

ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
const OnLocationsFound &callback) {
JobID job_id = JobID::nil();
ray::Status status = gcs_client_->object_table().Lookup(
job_id, object_id,
[this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
const std::vector<ObjectTableDataT> &location_history) {
// Build the set of current locations based on the entries in the log.
std::unordered_set<ClientID> client_ids;
std::vector<ClientID> locations_vector =
UpdateObjectLocations(client_ids, location_history);
callback(locations_vector, object_id);
});
return status;
}

Expand Down
Loading

0 comments on commit 7246ff8

Please sign in to comment.