Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
dd587e4
Implement Actor checkpointing
raulchen Jan 24, 2019
4926385
docs
raulchen Jan 26, 2019
61d85f4
fix
raulchen Jan 26, 2019
0de627a
fix
raulchen Jan 26, 2019
c61ee1d
fix
raulchen Jan 29, 2019
1a0e818
move restore-from-checkpoint to HandleActorStateTransition
raulchen Jan 29, 2019
fb8e9dc
Revert "move restore-from-checkpoint to HandleActorStateTransition"
raulchen Feb 1, 2019
2f446dc
resubmit waiting tasks when actor frontier restored
raulchen Feb 1, 2019
b31f374
add doc about num_actor_checkpoints_to_keep=1
raulchen Feb 1, 2019
8700b07
add num_actor_checkpoints_to_keep to Cython
raulchen Feb 1, 2019
17ebcba
add checkpoint_expired api
raulchen Feb 1, 2019
f4d7bb3
check if actor class is abstract
raulchen Feb 1, 2019
b0ae7dd
change checkpoint_ids to long string
raulchen Feb 1, 2019
6c5c130
implement java
raulchen Feb 1, 2019
e5e216f
Refactor to delay actor creation publish until checkpoint is resumed
stephanie-wang Feb 1, 2019
3e5630f
debug, lint
stephanie-wang Feb 1, 2019
dd1f405
Erase from checkpoints to restore if task fails
stephanie-wang Feb 1, 2019
683002c
fix lint
raulchen Feb 2, 2019
4b11a3e
update comments
raulchen Feb 2, 2019
40c8c1d
avoid duplicated actor notification log
raulchen Feb 2, 2019
4c450f9
fix unintended change
raulchen Feb 2, 2019
a3c4397
add actor_id to checkpoint_expired
raulchen Feb 2, 2019
f236a47
small java updates
raulchen Feb 2, 2019
f3955e7
make checkpoint info per actor
raulchen Feb 2, 2019
b7bae09
lint
stephanie-wang Feb 4, 2019
33f07ab
Remove logging
stephanie-wang Feb 4, 2019
e16a4bb
Remove old actor checkpointing Python code, move new checkpointing co…
stephanie-wang Feb 5, 2019
c80889a
Replace old actor checkpointing tests
stephanie-wang Feb 5, 2019
8bf4b61
Fix test and lint
stephanie-wang Feb 6, 2019
1d64d55
address comments
raulchen Feb 8, 2019
10999c5
consolidate kill_actor
raulchen Feb 8, 2019
ee58192
Merge branch 'master' into actor_checkpoint
stephanie-wang Feb 11, 2019
9c7da6d
Remove __ray_checkpoint__
stephanie-wang Feb 12, 2019
821fb8c
Merge branch 'master' into actor_checkpoint
raulchen Feb 12, 2019
e938845
fix non-ascii char
raulchen Feb 12, 2019
6e3985f
Loosen test checks
stephanie-wang Feb 12, 2019
c70a499
fix java
raulchen Feb 13, 2019
428d8b5
fix sphinx-build
raulchen Feb 13, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"tensorflow.python.client",
"tensorflow.python.util",
"ray.core.generated",
"ray.core.generated.ActorCheckpointIdData",
"ray.core.generated.ClientTableData",
"ray.core.generated.GcsTableEntry",
"ray.core.generated.HeartbeatTableData",
Expand Down
99 changes: 99 additions & 0 deletions java/api/src/main/java/org/ray/api/Checkpointable.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package org.ray.api;

import java.util.List;
import org.ray.api.id.UniqueId;

public interface Checkpointable {

class CheckpointContext {

/**
* Actor's ID.
*/
public final UniqueId actorId;
/**
* Number of tasks executed since last checkpoint.
*/
public final int numTasksSinceLastCheckpoint;
/**
* Time elapsed since last checkpoint, in milliseconds.
*/
public final long timeElapsedMsSinceLastCheckpoint;

public CheckpointContext(UniqueId actorId, int numTasksSinceLastCheckpoint,
long timeElapsedMsSinceLastCheckpoint) {
this.actorId = actorId;
this.numTasksSinceLastCheckpoint = numTasksSinceLastCheckpoint;
this.timeElapsedMsSinceLastCheckpoint = timeElapsedMsSinceLastCheckpoint;
}
}

class Checkpoint {

/**
* Checkpoint's ID.
*/
public final UniqueId checkpointId;
/**
* Checkpoint's timestamp.
*/
public final long timestamp;

public Checkpoint(UniqueId checkpointId, long timestamp) {
this.checkpointId = checkpointId;
this.timestamp = timestamp;
}
}

/**
* Whether this actor needs to be checkpointed.
*
* This method will be called after every task. You should implement this callback to decide
* whether this actor needs to be checkpointed at this time, based on the checkpoint context, or
* any other factors.
*
* @param checkpointContext An object that contains info about last checkpoint.
* @return A boolean value that indicates whether this actor needs to be checkpointed.
*/
boolean shouldCheckpoint(CheckpointContext checkpointContext);

/**
* Save a checkpoint to persistent storage.
*
* If `shouldCheckpoint` returns true, this method will be called. You should implement this
* callback to save actor's checkpoint and the given checkpoint id to persistent storage.
*
* @param actorId Actor's ID.
* @param checkpointId An ID that represents this actor's current state in GCS. You should
* save this checkpoint ID together with actor's checkpoint data.
*/
void saveCheckpoint(UniqueId actorId, UniqueId checkpointId);

/**
* Load actor's previous checkpoint, and restore actor's state.
*
* This method will be called when an actor is reconstructed, after actor's constructor. If the
* actor needs to restore from previous checkpoint, this function should restore actor's state and
* return the checkpoint ID. Otherwise, it should do nothing and return null.
*
* @param actorId Actor's ID.
* @param availableCheckpoints A list of available checkpoint IDs and their timestamps, sorted
* by timestamp in descending order. Note, this method must return the ID of one checkpoint in
* this list, or null. Otherwise, an exception will be thrown.
* @return The ID of the checkpoint from which the actor was resumed, or null if the actor should
* restart from the beginning.
*/
UniqueId loadCheckpoint(UniqueId actorId, List<Checkpoint> availableCheckpoints);

/**
* Delete an expired checkpoint;
*
* This method will be called when an checkpoint is expired. You should implement this method to
* delete your application checkpoint data. Note, the maximum number of checkpoints kept in the
* backend can be configured at `RayConfig.num_actor_checkpoints_to_keep`.
*
* @param actorId ID of the actor.
* @param checkpointId ID of the checkpoint that has expired.
*/
void checkpointExpired(UniqueId actorId, UniqueId checkpointId);
}
1 change: 1 addition & 0 deletions java/checkstyle-suppressions.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

<suppressions>
<suppress checks="OperatorWrap" files=".*" />
<suppress checks="JavadocParagraph" files=".*" />
<suppress checks="MemberNameCheck" files="PathConfig.java"/>
<suppress checks="MemberNameCheck" files="RayParameters.java"/>
<suppress checks="AbbreviationAsWordInNameCheck" files="RayParameters.java"/>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
package org.ray.runtime;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.generated.ActorCheckpointIdData;
import org.ray.runtime.generated.TablePrefix;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -21,7 +31,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime {

private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);

private RedisClient redisClient = null;
/**
* Redis client of the primary shard.
*/
private RedisClient redisClient;
/**
* Redis clients of all shards.
*/
private List<RedisClient> redisClients;
private RunManager manager = null;

public RayNativeRuntime(RayConfig rayConfig) {
Expand Down Expand Up @@ -69,7 +86,8 @@ public void start() throws Exception {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
}
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);

initRedisClients();

// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
Expand All @@ -88,6 +106,16 @@ public void start() throws Exception {
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
}

private void initRedisClients() {
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
int numRedisShards = Integer.valueOf(redisClient.get("NumRedisShards", null));
List<String> addresses = redisClient.lrange("RedisShards", 0, -1);
Preconditions.checkState(numRedisShards == addresses.size());
redisClients = addresses.stream().map(RedisClient::new)
.collect(Collectors.toList());
redisClients.add(redisClient);
}

@Override
public void shutdown() {
if (null != manager) {
Expand Down Expand Up @@ -116,4 +144,33 @@ private void registerWorker() {
}
}

/**
* Get the available checkpoints for the given actor ID, return a list sorted by checkpoint
* timestamp in descending order.
*/
List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
List<Checkpoint> checkpoints = new ArrayList<>();
// TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over
// all redis shards..
String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
for (RedisClient client : redisClients) {
byte[] result = client.get(key, null);
if (result == null) {
continue;
}
ActorCheckpointIdData data = ActorCheckpointIdData
.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));

UniqueId[] checkpointIds
= UniqueIdUtil.getUniqueIdsFromByteBuffer(data.checkpointIdsAsByteBuffer());

for (int i = 0; i < checkpointIds.length; i++) {
checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
}
break;
}
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
return checkpoints;
}
}
86 changes: 86 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/Worker.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package org.ray.runtime;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import org.ray.api.Checkpointable;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
Expand All @@ -17,6 +23,9 @@ public class Worker {

private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);

// TODO(hchen): Use the C++ config.
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;

private final AbstractRayRuntime runtime;

/**
Expand All @@ -34,6 +43,22 @@ public class Worker {
*/
private Exception actorCreationException = null;

/**
* Number of tasks executed since last actor checkpoint.
*/
private int numTasksSinceLastCheckpoint = 0;

/**
* IDs of this actor's previous checkpoints.
*/
private List<UniqueId> checkpointIds;

/**
* Timestamp of the last actor checkpoint.
*/
private long lastCheckpointTimestamp = 0;


public Worker(AbstractRayRuntime runtime) {
this.runtime = runtime;
}
Expand Down Expand Up @@ -80,8 +105,12 @@ public void execute(TaskSpec spec) {
}
// Set result
if (!spec.isActorCreationTask()) {
if (spec.isActorTask()) {
maybeSaveCheckpoint(actor, spec.actorId);
}
runtime.put(returnId, result);
} else {
maybeLoadCheckpoint(result, returnId);
currentActor = result;
currentActorId = returnId;
}
Expand All @@ -98,4 +127,61 @@ public void execute(TaskSpec spec) {
Thread.currentThread().setContextClassLoader(oldLoader);
}
}

private void maybeSaveCheckpoint(Object actor, UniqueId actorId) {
if (!(actor instanceof Checkpointable)) {
return;
}
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
return;
}
CheckpointContext checkpointContext = new CheckpointContext(actorId,
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
Checkpointable checkpointable = (Checkpointable) actor;
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
UniqueId checkpointId = runtime.rayletClient.prepareCheckpoint(actorId);
checkpointIds.add(checkpointId);
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
checkpointIds.remove(0);
}
checkpointable.saveCheckpoint(actorId, checkpointId);
}

private void maybeLoadCheckpoint(Object actor, UniqueId actorId) {
if (!(actor instanceof Checkpointable)) {
return;
}
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
checkpointIds = new ArrayList<>();
List<Checkpoint> availableCheckpoints = ((RayNativeRuntime) runtime)
.getCheckpointsForActor(actorId);
if (availableCheckpoints.isEmpty()) {
return;
}
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
if (checkpointId != null) {
boolean checkpointValid = false;
for (Checkpoint checkpoint : availableCheckpoints) {
if (checkpoint.checkpointId.equals(checkpointId)) {
checkpointValid = true;
break;
}
}
Preconditions.checkArgument(checkpointValid,
"'loadCheckpoint' must return a checkpoint ID that exists in the "
+ "'availableCheckpoints' list, or null.");
runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, checkpointId);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.ray.runtime.gcs;

import java.util.List;
import java.util.Map;

import org.ray.runtime.util.StringUtil;
Expand Down Expand Up @@ -77,7 +78,11 @@ public byte[] get(byte[] key, byte[] field) {
return jedis.hget(key, field);
}
}

}

public List<String> lrange(String key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.lrange(key, start, end);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.NotImplementedException;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
Expand Down Expand Up @@ -94,4 +95,14 @@ public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
return;
}

@Override
public UniqueId prepareCheckpoint(UniqueId actorId) {
throw new NotImplementedException("Not implemented.");
}

@Override
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
throw new NotImplementedException("Not implemented.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,8 @@ <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
timeoutMs, UniqueId currentTaskId);

void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);

UniqueId prepareCheckpoint(UniqueId actorId);

void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId);
}
Loading