Skip to content

Commit

Permalink
[ID Refactor] Shorten the length of JobID to 4 bytes (#5110)
Browse files Browse the repository at this point in the history
* WIP

* Fix

* Add jobid test

* Fix

* Add python part

* Fix

* Fix tes

* Remove TODOs

* Fix C++ tests

* Lint

* Fix

* Fix exporting functions in multiple ray.init

* Fix java test

* Fix lint

* Fix linting

* Address comments.

* FIx

* Address and fix linting

* Refine and fix

* Fix

* address

* Address comments.

* Fix linting

* Fix

* Address

* Address comments.

* Address

* Address

* Fix

* Fix

* Fix

* Fix lint

* Fix

* Fix linting

* Address comments.

* Fix linting

* Address comments.

* Fix linting

* address comments.

* Fix
  • Loading branch information
jovany-wang authored Jul 11, 2019
1 parent 88365d4 commit f229324
Show file tree
Hide file tree
Showing 37 changed files with 386 additions and 133 deletions.
62 changes: 62 additions & 0 deletions java/api/src/main/java/org/ray/api/id/JobId.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.ray.api.id;

import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;

/**
* Represents the id of a Ray job.
*/
public class JobId extends BaseId implements Serializable {

// Note that the max value of a job id is NIL which value is (2^32 - 1).
public static final Long MAX_VALUE = (long) Math.pow(2, 32) - 1;

public static final int LENGTH = 4;

public static final JobId NIL = genNil();

/**
* Create a JobID instance according to the given bytes.
*/
private JobId(byte[] id) {
super(id);
}

/**
* Create a JobId from a given hex string.
*/
public static JobId fromHexString(String hex) {
return new JobId(hexString2Bytes(hex));
}

/**
* Creates a JobId from the given ByteBuffer.
*/
public static JobId fromByteBuffer(ByteBuffer bb) {
return new JobId(byteBuffer2Bytes(bb));
}

public static JobId fromInt(int value) {
byte[] bytes = new byte[JobId.LENGTH];
ByteBuffer wbb = ByteBuffer.wrap(bytes);
wbb.order(ByteOrder.LITTLE_ENDIAN);
wbb.putInt(value);
return new JobId(bytes);
}

/**
* Generate a nil JobId.
*/
private static JobId genNil() {
byte[] b = new byte[LENGTH];
Arrays.fill(b, (byte) 0xFF);
return new JobId(b);
}

@Override
public int size() {
return LENGTH;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.ray.api.runtimecontext;

import java.util.List;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;

/**
Expand All @@ -11,7 +12,7 @@ public interface RuntimeContext {
/**
* Get the current Job ID.
*/
UniqueId getCurrentJobId();
JobId getCurrentJobId();

/**
* Get the current actor ID.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.jobResourcePath);
worker = new Worker(this);
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.jobId, rayConfig.runMode);
runtimeContext = new RuntimeContextImpl(this);
}

Expand Down
13 changes: 13 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.ray.runtime;

import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.objectstore.ObjectStoreProxy;
Expand All @@ -13,9 +15,16 @@ public RayDevRuntime(RayConfig rayConfig) {

private MockObjectStore store;

private AtomicInteger jobCounter = new AtomicInteger(0);

@Override
public void start() {
store = new MockObjectStore(this);
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
objectStoreProxy = new ObjectStoreProxy(this, null);
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
}
Expand All @@ -33,4 +42,8 @@ public MockObjectStore getObjectStore() {
public Worker getWorker() {
return ((MockRayletClient) rayletClient).getCurrentWorker();
}

private JobId nextJobId() {
return JobId.fromInt(jobCounter.getAndIncrement());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.GcsClient;
Expand Down Expand Up @@ -94,6 +95,12 @@ public void start() {

gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);

if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}

workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.getJobId(), rayConfig.runMode);
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.common.base.Preconditions;
import java.util.List;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.api.runtimecontext.RuntimeContext;
Expand All @@ -17,7 +18,7 @@ public RuntimeContextImpl(AbstractRayRuntime runtime) {
}

@Override
public UniqueId getCurrentJobId() {
public JobId getCurrentJobId() {
return runtime.getWorkerContext().getCurrentJobId();
}

Expand Down
14 changes: 7 additions & 7 deletions java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package org.ray.runtime;

import com.google.common.base.Preconditions;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.IdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -29,7 +31,7 @@ public class WorkerContext {

private ThreadLocal<TaskSpec> currentTask;

private UniqueId currentJobId;
private JobId currentJobId;

private ClassLoader currentClassLoader;

Expand All @@ -43,7 +45,7 @@ public class WorkerContext {
*/
private RunMode runMode;

public WorkerContext(WorkerMode workerMode, UniqueId jobId, RunMode runMode) {
public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode runMode) {
mainThreadId = Thread.currentThread().getId();
taskIndex = ThreadLocal.withInitial(() -> 0);
putIndex = ThreadLocal.withInitial(() -> 0);
Expand All @@ -52,15 +54,13 @@ public WorkerContext(WorkerMode workerMode, UniqueId jobId, RunMode runMode) {
currentTask = ThreadLocal.withInitial(() -> null);
currentClassLoader = null;
if (workerMode == WorkerMode.DRIVER) {
// TODO(qwang): Assign the driver id to worker id
// once we treat driver id as a special worker id.
workerId = jobId;
workerId = IdUtil.computeDriverId(jobId);
currentTaskId.set(TaskId.randomId());
currentJobId = jobId;
} else {
workerId = UniqueId.randomId();
this.currentTaskId.set(TaskId.NIL);
this.currentJobId = UniqueId.NIL;
this.currentJobId = JobId.NIL;
}
}

Expand Down Expand Up @@ -119,7 +119,7 @@ public UniqueId getCurrentWorkerId() {
/**
* The ID of the current job.
*/
public UniqueId getCurrentJobId() {
public JobId getCurrentJobId() {
return currentJobId;
}

Expand Down
16 changes: 12 additions & 4 deletions java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ray.api.id.UniqueId;
import org.ray.api.id.JobId;
import org.ray.runtime.util.NetworkUtil;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.StringUtil;
Expand All @@ -32,7 +32,7 @@ public class RayConfig {
public final WorkerMode workerMode;
public final RunMode runMode;
public final Map<String, Double> resources;
public final UniqueId jobId;
private JobId jobId;
public final String logDir;
public final boolean redirectOutput;
public final List<String> libraryPath;
Expand Down Expand Up @@ -108,9 +108,9 @@ public RayConfig(Config config) {
// Job id.
String jobId = config.getString("ray.job.id");
if (!jobId.isEmpty()) {
this.jobId = UniqueId.fromHexString(jobId);
this.jobId = JobId.fromHexString(jobId);
} else {
this.jobId = UniqueId.randomId();
this.jobId = JobId.NIL;
}
// Log dir.
logDir = removeTrailingSlash(config.getString("ray.log-dir"));
Expand Down Expand Up @@ -198,6 +198,14 @@ public Integer getRedisPort() {
return redisPort;
}

public void setJobId(JobId jobId) {
this.jobId = jobId;
}

public JobId getJobId() {
return this.jobId;
}

@Override
public String toString() {
return "RayConfig{"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.Type;
import org.ray.api.function.RayFunc;
import org.ray.api.id.UniqueId;
import org.ray.api.id.JobId;
import org.ray.runtime.util.LambdaUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -48,7 +48,7 @@ public class FunctionManager {
/**
* Mapping from the job id to the functions that belong to this job.
*/
private Map<UniqueId, JobFunctionTable> jobFunctionTables = new HashMap<>();
private Map<JobId, JobFunctionTable> jobFunctionTables = new HashMap<>();

/**
* The resource path which we can load the job's jar resources.
Expand All @@ -72,7 +72,7 @@ public FunctionManager(String jobResourcePath) {
* @param func The lambda.
* @return A RayFunction object.
*/
public RayFunction getFunction(UniqueId jobId, RayFunc func) {
public RayFunction getFunction(JobId jobId, RayFunc func) {
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
if (functionDescriptor == null) {
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
Expand All @@ -92,7 +92,7 @@ public RayFunction getFunction(UniqueId jobId, RayFunc func) {
* @param functionDescriptor The function descriptor.
* @return A RayFunction object.
*/
public RayFunction getFunction(UniqueId jobId, JavaFunctionDescriptor functionDescriptor) {
public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) {
JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId);
if (jobFunctionTable == null) {
ClassLoader classLoader;
Expand Down
6 changes: 6 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.id.BaseId;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.runtimecontext.NodeInfo;
Expand Down Expand Up @@ -164,6 +165,11 @@ public List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
return checkpoints;
}

public JobId nextJobId() {
int jobCounter = (int) primary.incr("JobCounter".getBytes());
return JobId.fromInt(jobCounter);
}

private RedisClient getShardClient(BaseId key) {
return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key),
shards.size()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,9 @@ public boolean exists(byte[] key) {
}
}

public long incr(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.incr(key).intValue();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.JobId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
Expand Down Expand Up @@ -164,7 +165,7 @@ public void notifyUnblocked(TaskId currentTaskId) {
}

@Override
public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) {
public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) {
return TaskId.randomId();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.List;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.JobId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
Expand All @@ -21,7 +22,7 @@ public interface RayletClient {

void notifyUnblocked(TaskId currentTaskId);

TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex);
TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex);

<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, TaskId currentTaskId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.id.JobId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
Expand All @@ -39,7 +40,7 @@ public class RayletClientImpl implements RayletClient {

// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
boolean isWorker, UniqueId jobId) {
boolean isWorker, JobId jobId) {
client = nativeInit(schedulerSockName, clientId.getBytes(),
isWorker, jobId.getBytes());
}
Expand Down Expand Up @@ -107,7 +108,7 @@ public void fetchOrReconstruct(List<ObjectId> objectIds, boolean fetchOnly,
}

@Override
public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) {
public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) {
byte[] bytes = nativeGenerateTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex);
return new TaskId(bytes);
}
Expand Down Expand Up @@ -146,7 +147,7 @@ private static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
}

// Parse common fields.
UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
JobId jobId = JobId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer());
TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer());
int parentCounter = (int) taskSpec.getParentCounter();
Expand Down
Loading

0 comments on commit f229324

Please sign in to comment.