Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] Fix out-dated signatures of JNI methods #2756

Merged
merged 11 commits into from
Aug 30, 2018
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public interface LocalSchedulerLink {

void submitTask(TaskSpec task);

TaskSpec getTaskTodo();
TaskSpec getTask();

void markTaskPutDependency(UniqueID taskId, UniqueID objectId);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private void doSubmit(RayInvocation invocation, UniqueID taskId, UniqueID[] retu
}

public TaskSpec getTask() {
TaskSpec ts = scheduler.getTaskTodo();
TaskSpec ts = scheduler.getTask();
RayLog.core.info("Task " + ts.taskId.toString() + " received");
return ts;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private UniqueID isTaskReady(TaskSpec spec) {
}

@Override
public TaskSpec getTaskTodo() {
public TaskSpec getTask() {
throw new RuntimeException("invalid execution flow here");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,8 @@ public void start(RayParameters params) throws Exception {
LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
params.local_scheduler_name,
WorkerContext.currentWorkerId(),
UniqueID.NIL,
isWorker,
WorkerContext.currentTask().taskId,
0,
false
);

Expand All @@ -133,10 +131,8 @@ public void start(RayParameters params) throws Exception {
LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
params.raylet_socket_name,
WorkerContext.currentWorkerId(),
UniqueID.NIL,
isWorker,
WorkerContext.currentTask().taskId,
0,
true
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,17 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
boolean useRaylet = false;

public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId,
UniqueID actorId, boolean isWorker, UniqueID driverId,
long numGpus, boolean useRaylet) {
client = _init(schedulerSockName, clientId.getBytes(), actorId.getBytes(), isWorker,
driverId.getBytes(), numGpus, useRaylet);
boolean isWorker, UniqueID driverId, boolean useRaylet) {
client = nativeInit(schedulerSockName, clientId.getBytes(),
isWorker, driverId.getBytes(), useRaylet);
this.useRaylet = useRaylet;
}

private static native long _init(String localSchedulerSocket, byte[] workerId,
byte[] actorId, boolean isWorker, byte[] driverTaskId,
long numGpus, boolean useRaylet);

private static native byte[] _computePutId(long client, byte[] taskId, int putIndex);

private static native byte[] _generateTaskId(byte[] driverId, byte[] parentTaskId, int taskIndex);

private static native void _task_done(long client);

private static native boolean[] _waitObject(long conn, byte[][] objectIds,
int numReturns, int timeout, boolean waitLocal);

@Override
public List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns) {
assert (useRaylet == true);

boolean[] readys = _waitObject(client, objectIds, numReturns, timeoutMs, false);
boolean[] readys = nativeWaitObject(client, objectIds, numReturns, timeoutMs, false);
assert (readys.length == objectIds.length);

List<byte[]> ret = new ArrayList<>();
Expand Down Expand Up @@ -91,27 +77,27 @@ public void submitTask(TaskSpec task) {
a = task.cursorId.getBytes();
}

_submitTask(client, a, info, info.position(), info.remaining(), useRaylet);
nativeSubmitTask(client, a, info, info.position(), info.remaining(), useRaylet);
}

@Override
public TaskSpec getTaskTodo() {
byte[] bytes = _getTaskTodo(client, useRaylet);
public TaskSpec getTask() {
byte[] bytes = nativeGetTask(client, useRaylet);
assert (null != bytes);
ByteBuffer bb = ByteBuffer.wrap(bytes);
return taskInfo2Spec(bb);
}

@Override
public void markTaskPutDependency(UniqueID taskId, UniqueID objectId) {
_put_object(client, taskId.getBytes(), objectId.getBytes());
nativePutObject(client, taskId.getBytes(), objectId.getBytes());
}

@Override
public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
List<UniqueID> objects = new ArrayList<>();
objects.add(objectId);
_reconstruct_objects(client, getIdBytes(objects), fetchOnly);
nativeReconstructObjects(client, getIdBytes(objects), fetchOnly);
}

@Override
Expand All @@ -120,30 +106,20 @@ public void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly) {
RayLog.core.info("Reconstructing objects for task {}, object IDs are {}",
UniqueIdHelper.computeTaskId(objectIds.get(0)), objectIds);
}
_reconstruct_objects(client, getIdBytes(objectIds), fetchOnly);
nativeReconstructObjects(client, getIdBytes(objectIds), fetchOnly);
}

@Override
public UniqueID generateTaskId(UniqueID driverId, UniqueID parentTaskId, int taskIndex) {
byte[] bytes = _generateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex);
byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex);
return new UniqueID(bytes);
}

@Override
public void notifyUnblocked() {
_notify_unblocked(client);
nativeNotifyUnblocked(client);
}

private static native void _notify_unblocked(long client);

private static native void _reconstruct_objects(long client, byte[][] objectIds,
boolean fetchOnly);

private static native void _put_object(long client, byte[] taskId, byte[] objectId);

// return TaskInfo (in FlatBuffer)
private static native byte[] _getTaskTodo(long client, boolean useRaylet);

public static TaskSpec taskInfo2Spec(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
Expand Down Expand Up @@ -282,10 +258,6 @@ public static ByteBuffer taskSpec2Info(TaskSpec task) {
return buffer;
}

// task -> TaskInfo (with FlatBuffer)
protected static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task,
int pos, int sz, boolean useRaylet);

private static byte[][] getIdBytes(List<UniqueID> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
Expand All @@ -296,8 +268,45 @@ private static byte[][] getIdBytes(List<UniqueID> objectIds) {
}

public void destroy() {
_destroy(client);
nativeDestroy(client);
}

private static native void _destroy(long client);

/// Native method declarations.
///
/// If you change the signature of any native methods, please re-generate
/// the C++ header file and update the C++ implementation accordingly:
///
/// Suppose that $Dir is your ray root directory.
/// 1) pushd $Dir/java/runtime-native/target/classes
/// 2) javah -classpath .:$Dir/java/runtime-common/target/classes/:$Dir/java/api/target/classes/
/// org.ray.spi.impl.DefaultLocalSchedulerClient
/// 3) cp org_ray_spi_impl_DefaultLocalSchedulerClient.h $Dir/src/local_scheduler/lib/java/
/// 4) vim $Dir/src/local_scheduler/lib/java/org_ray_spi_impl_DefaultLocalSchedulerClient.cc
/// 5) popd

private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
boolean isWorker, byte[] driverTaskId, boolean useRaylet);

private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
int pos, int taskSize, boolean useRaylet);

// return TaskInfo (in FlatBuffer)
private static native byte[] nativeGetTask(long client, boolean useRaylet);

private static native void nativeDestroy(long client);

private static native void nativeReconstructObjects(long client, byte[][] objectIds,
boolean fetchOnly);

private static native void nativeNotifyUnblocked(long client);

private static native void nativePutObject(long client, byte[] taskId, byte[] objectId);

private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
int numReturns, int timeout, boolean waitLocal);

private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId,
int taskIndex);

}
Loading