[xray] Track ray.get calls as task dependencies #2362

Merged
WorkerContext.java
@@ -2,6 +2,7 @@
 
 import org.ray.api.UniqueID;
 import org.ray.core.model.RayParameters;
+import org.ray.core.model.WorkerMode;
 import org.ray.spi.model.TaskSpec;
 
 public class WorkerContext {
@@ -35,7 +36,11 @@ public static WorkerContext init(RayParameters params) {
 
     TaskSpec dummy = new TaskSpec();
     dummy.parentTaskId = UniqueID.nil;
-    dummy.taskId = UniqueID.nil;
+    if (params.worker_mode == WorkerMode.DRIVER) {
+      dummy.taskId = UniqueID.randomId();
+    } else {
+      dummy.taskId = UniqueID.nil;
+    }
     dummy.actorId = UniqueID.nil;
     dummy.driverId = params.driver_id;
     prepare(dummy, null);
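
The Java change above mints a fresh task ID for the driver at startup instead of leaving it nil, so that objects the driver requests via ray.get can be attributed to a real task; plain workers still start from the nil ID and receive a task ID per assigned task. A minimal Python sketch of the same rule, mirroring the worker.py hunks further down (the standalone initial_task_id helper is illustrative; random_task_id's import path is assumed, and the NIL_ID value shown is assumed from worker.py):

    import ray
    from ray.local_scheduler import random_task_id  # import path assumed

    NIL_ID = 20 * b"\xff"  # nil ID constant as in worker.py (value assumed)

    def initial_task_id(is_driver):
        # Drivers get a freshly generated task ID so their ray.get calls
        # can be tracked as dependencies of an actual task.
        if is_driver:
            return random_task_id()
        # Workers begin without an assigned task.
        return ray.ObjectID(NIL_ID)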
@@ -109,6 +109,7 @@ public void start(RayParameters params) throws Exception {
         WorkerContext.currentWorkerId(),
         UniqueID.nil,
         isWorker,
+        WorkerContext.currentTask().taskId,
         0
     );
 
@@ -237,4 +238,4 @@ public Object localCreateActorInActor(byte[] actorId, String className) {
       throw new TaskExecutionException(log, e);
     }
   }
-}
\ No newline at end of file
+}

DefaultLocalSchedulerClient.java
@@ -26,13 +26,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
   private long client = 0;
 
   public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId, UniqueID actorId,
-      boolean isWorker, long numGpus) {
+      boolean isWorker, UniqueID driverId, long numGpus) {
     client = _init(schedulerSockName, clientId.getBytes(), actorId.getBytes(), isWorker,
-        numGpus);
+        driverId.getBytes(), numGpus);
   }
 
   private static native long _init(String localSchedulerSocket, byte[] workerId, byte[] actorId,
-      boolean isWorker, long numGpus);
+      boolean isWorker, byte[] driverTaskId, long numGpus);
 
   private static native byte[] _computePutId(long client, byte[] taskId, int putIndex);

3 changes: 2 additions & 1 deletion python/ray/global_scheduler/test/test.py
@@ -96,7 +96,8 @@ def setUp(self):
                 static_resources={"CPU": 10})
             # Connect to the scheduler.
             local_scheduler_client = local_scheduler.LocalSchedulerClient(
-                local_scheduler_name, NIL_WORKER_ID, False, False)
+                local_scheduler_name, NIL_WORKER_ID, False, random_task_id(),
+                False)
             self.local_scheduler_clients.append(local_scheduler_client)
             self.local_scheduler_pids.append(p4)

2 changes: 1 addition & 1 deletion python/ray/local_scheduler/test/test.py
@@ -46,7 +46,7 @@ def setUp(self):
             plasma_store_name, use_valgrind=USE_VALGRIND)
         # Connect to the scheduler.
         self.local_scheduler_client = local_scheduler.LocalSchedulerClient(
-            scheduler_name, NIL_WORKER_ID, False, False)
+            scheduler_name, NIL_WORKER_ID, False, random_task_id(), False)
 
     def tearDown(self):
         # Check that the processes are still alive.

31 changes: 19 additions & 12 deletions python/ray/worker.py
@@ -503,15 +503,6 @@ def get_object(self, object_ids):
             # get them until at least get_timeout_milliseconds
             # milliseconds passes, then repeat.
             while len(unready_ids) > 0:
-                for unready_id in unready_ids:
-                    if not self.use_raylet:
-                        self.local_scheduler_client.reconstruct_objects(
-                            [ray.ObjectID(unready_id)], False)
-                # Do another fetch for objects that aren't available
-                # locally yet, in case they were evicted since the last
-                # fetch. We divide the fetch into smaller fetches so as
-                # to not block the manager for a prolonged period of time
-                # in a single call.
                 object_ids_to_fetch = [
                     plasma.ObjectID(unready_id)
                     for unready_id in unready_ids.keys()
@@ -525,6 +516,18 @@ def get_object(self, object_ids):
                 for i in range(0, len(object_ids_to_fetch),
                                fetch_request_size):
                     if not self.use_raylet:
+                        for unready_id in ray_object_ids_to_fetch[i:(
+                                i + fetch_request_size)]:
+                            (self.local_scheduler_client.
+                             reconstruct_objects([unready_id], False))
+                        # Do another fetch for objects that aren't
+                        # available locally yet, in case they were evicted
+                        # since the last fetch. We divide the fetch into
+                        # smaller fetches so as to not block the manager
+                        # for a prolonged period of time in a single call.
+                        # This is only necessary for legacy ray since
+                        # reconstruction and fetch are implemented by
+                        # different processes.
                         self.plasma_client.fetch(object_ids_to_fetch[i:(
                             i + fetch_request_size)])
                     else:
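
Restructured this way, the legacy-path reconstruct requests now run inside the same chunked loop as the plasma fetches, so both are throttled by fetch_request_size instead of re-requesting every unready object once per pass. A condensed sketch of the resulting control flow (ray_object_ids_to_fetch is referenced but not defined in the visible hunk; it is assumed here to be built alongside object_ids_to_fetch):

    while len(unready_ids) > 0:
        object_ids_to_fetch = [
            plasma.ObjectID(oid) for oid in unready_ids.keys()]
        ray_object_ids_to_fetch = [
            ray.ObjectID(oid) for oid in unready_ids.keys()]  # assumed
        for i in range(0, len(object_ids_to_fetch), fetch_request_size):
            chunk = slice(i, i + fetch_request_size)
            if not self.use_raylet:
                # Legacy ray: reconstruction and fetch are served by separate
                # processes, so first ask the local scheduler to reconstruct
                # each still-unready object in this chunk...
                for unready_id in ray_object_ids_to_fetch[chunk]:
                    self.local_scheduler_client.reconstruct_objects(
                        [unready_id], False)
                # ...then fetch again in case objects were evicted since the
                # last fetch, chunked so the plasma manager is never blocked
                # for long in a single call.
                self.plasma_client.fetch(object_ids_to_fetch[chunk])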
@@ -2162,9 +2165,6 @@ def connect(info,
     else:
         local_scheduler_socket = info["raylet_socket_name"]
 
-    worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(
-        local_scheduler_socket, worker.worker_id, is_worker, worker.use_raylet)
-
     # If this is a driver, set the current task ID, the task driver ID, and set
     # the task index to 0.
     if mode in [SCRIPT_MODE, SILENT_MODE]:
@@ -2219,6 +2219,13 @@ def connect(info,
         # Set the driver's current task ID to the task ID assigned to the
         # driver task.
         worker.current_task_id = driver_task.task_id()
+    else:
+        # A non-driver worker begins without an assigned task.
+        worker.current_task_id = ray.ObjectID(NIL_ID)
+
+    worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(
+        local_scheduler_socket, worker.worker_id, is_worker,
+        worker.current_task_id, worker.use_raylet)
 
     # Start the import thread
     import_thread.ImportThread(worker, mode).start()
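
Taken together, the two connect() hunks reorder initialization: worker.current_task_id is set first (the driver task's ID for drivers, nil for workers), and only afterwards is the LocalSchedulerClient constructed, so the registration message can carry that task ID to the local scheduler. Condensed from the diff above (NIL_ID, SCRIPT_MODE, and SILENT_MODE as defined in worker.py):

    if mode in [SCRIPT_MODE, SILENT_MODE]:
        # Driver: use the task ID assigned to the driver task.
        worker.current_task_id = driver_task.task_id()
    else:
        # A non-driver worker begins without an assigned task.
        worker.current_task_id = ray.ObjectID(NIL_ID)

    # Construct the client only after current_task_id is known, so the
    # local scheduler learns which task this client is currently running.
    worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(
        local_scheduler_socket, worker.worker_id, is_worker,
        worker.current_task_id, worker.use_raylet)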
2 changes: 2 additions & 0 deletions src/local_scheduler/format/local_scheduler.fbs
@@ -79,6 +79,8 @@ table RegisterClientRequest {
   client_id: string;
   // The process ID of this worker.
   worker_pid: long;
+  // The driver ID. This is non-nil if the client is a driver.
+  driver_id: string;
 }
 
 table DisconnectClient {
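
With the new field, a registration message carries the client's identity plus its driver ID. A schematic Python rendering of the RegisterClientRequest table, for orientation only (field order inferred from the CreateRegisterClientRequest call in local_scheduler_client.cc below; the real message is a FlatBuffer, not a Python class):

    from dataclasses import dataclass

    @dataclass
    class RegisterClientRequest:
        is_worker: bool    # True for workers, False for drivers
        client_id: bytes   # this worker's unique ID
        worker_pid: int    # the process ID of this worker
        driver_id: bytes   # new: non-nil if and only if the client is a driver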
@@ -42,14 +42,16 @@ Java_org_ray_spi_impl_DefaultLocalSchedulerClient__1init(JNIEnv *env,
                                                          jbyteArray wid,
                                                          jbyteArray actorId,
                                                          jboolean isWorker,
+                                                         jbyteArray driverId,
                                                          jlong numGpus) {
   // native private static long _init(String localSchedulerSocket,
   // byte[] workerId, byte[] actorId, boolean isWorker, long numGpus);
   UniqueIdFromJByteArray worker_id(env, wid);
+  UniqueIdFromJByteArray driver_id(env, driverId);
   const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
   bool use_raylet = false;
-  auto client = LocalSchedulerConnection_init(nativeString, *worker_id.PID,
-                                              isWorker, use_raylet);
+  auto client = LocalSchedulerConnection_init(
+      nativeString, *worker_id.PID, isWorker, *driver_id.PID, use_raylet);
   env->ReleaseStringUTFChars(sockName, nativeString);
   return reinterpret_cast<jlong>(client);
 }

8 changes: 5 additions & 3 deletions src/local_scheduler/lib/python/local_scheduler_extension.cc
@@ -20,16 +20,18 @@ static int PyLocalSchedulerClient_init(PyLocalSchedulerClient *self,
   char *socket_name;
   UniqueID client_id;
   PyObject *is_worker;
+  JobID driver_id;
   PyObject *use_raylet;
-  if (!PyArg_ParseTuple(args, "sO&OO", &socket_name, PyStringToUniqueID,
-                        &client_id, &is_worker, &use_raylet)) {
+  if (!PyArg_ParseTuple(args, "sO&OO&O", &socket_name, PyStringToUniqueID,
+                        &client_id, &is_worker, &PyObjectToUniqueID, &driver_id,
+                        &use_raylet)) {
     self->local_scheduler_connection = NULL;
     return -1;
   }
   /* Connect to the local scheduler. */
   self->local_scheduler_connection = LocalSchedulerConnection_init(
       socket_name, client_id, static_cast<bool>(PyObject_IsTrue(is_worker)),
-      static_cast<bool>(PyObject_IsTrue(use_raylet)));
+      driver_id, static_cast<bool>(PyObject_IsTrue(use_raylet)));
   return 0;
 }
 
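
The new "sO&OO&O" format string above means the Python-side constructor now takes five arguments: socket name, client ID, is_worker, the client's driver/task ID, and use_raylet. Usage matching the updated tests (the socket path is illustrative; NIL_WORKER_ID and random_task_id are imported as in those tests):

    import ray.local_scheduler as local_scheduler

    client = local_scheduler.LocalSchedulerClient(
        "/tmp/scheduler_socket",  # socket_name (illustrative)
        NIL_WORKER_ID,            # client_id
        False,                    # is_worker: connecting as a driver
        random_task_id(),         # the driver's current task ID
        False)                    # use_raylet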
6 changes: 4 additions & 2 deletions src/local_scheduler/local_scheduler_client.cc
@@ -14,8 +14,9 @@ using MessageType = ray::local_scheduler::protocol::MessageType;
 
 LocalSchedulerConnection *LocalSchedulerConnection_init(
     const char *local_scheduler_socket,
-    UniqueID client_id,
+    const UniqueID &client_id,
     bool is_worker,
+    const JobID &driver_id,
     bool use_raylet) {
   LocalSchedulerConnection *result = new LocalSchedulerConnection();
   result->use_raylet = use_raylet;
@@ -26,7 +27,8 @@ LocalSchedulerConnection *LocalSchedulerConnection_init(
    * worker, we will get killed. */
   flatbuffers::FlatBufferBuilder fbb;
   auto message = ray::local_scheduler::protocol::CreateRegisterClientRequest(
-      fbb, is_worker, to_flatbuf(fbb, client_id), getpid());
+      fbb, is_worker, to_flatbuf(fbb, client_id), getpid(),
+      to_flatbuf(fbb, driver_id));
   fbb.Finish(message);
   /* Register the process ID with the local scheduler. */
   int success = write_message(

6 changes: 5 additions & 1 deletion src/local_scheduler/local_scheduler_client.h
@@ -32,16 +32,20 @@ struct LocalSchedulerConnection {
  *
  * @param local_scheduler_socket The name of the socket to use to connect to the
  *        local scheduler.
+ * @param worker_id A unique ID to represent the worker.
  * @param is_worker Whether this client is a worker. If it is a worker, an
  *        additional message will be sent to register as one.
+ * @param driver_id The ID of the driver. This is non-nil if the client is a
+ *        driver.
  * @param use_raylet True if we should use the raylet code path and false
  *        otherwise.
  * @return The connection information.
  */
 LocalSchedulerConnection *LocalSchedulerConnection_init(
     const char *local_scheduler_socket,
-    UniqueID worker_id,
+    const UniqueID &worker_id,
     bool is_worker,
+    const JobID &driver_id,
     bool use_raylet);
 
 /**

3 changes: 2 additions & 1 deletion src/local_scheduler/test/local_scheduler_tests.cc
@@ -125,7 +125,8 @@ LocalSchedulerMock *LocalSchedulerMock_init(int num_workers,
 
   for (int i = 0; i < num_mock_workers; ++i) {
     mock->conns[i] = LocalSchedulerConnection_init(
-        local_scheduler_socket_name.c_str(), WorkerID::nil(), true, false);
+        local_scheduler_socket_name.c_str(), WorkerID::nil(), true,
+        JobID::nil(), false);
   }
 
   background_thread.join();

2 changes: 2 additions & 0 deletions src/ray/raylet/format/node_manager.fbs
@@ -119,6 +119,8 @@ table RegisterClientRequest {
   client_id: string;
   // The process ID of this worker.
   worker_pid: long;
+  // The driver ID. This is non-nil if the client is a driver.
+  driver_id: string;
 }
 
 table RegisterClientReply {