Skip to content

Commit f31a79f

Browse files
authored
Implement actor checkpointing (#3839)
* Implement Actor checkpointing * docs * fix * fix * fix * move restore-from-checkpoint to HandleActorStateTransition * Revert "move restore-from-checkpoint to HandleActorStateTransition" This reverts commit 9aa4447c1e3e321f42a1d895d72f17098b72de12. * resubmit waiting tasks when actor frontier restored * add doc about num_actor_checkpoints_to_keep=1 * add num_actor_checkpoints_to_keep to Cython * add checkpoint_expired api * check if actor class is abstract * change checkpoint_ids to long string * implement java * Refactor to delay actor creation publish until checkpoint is resumed * debug, lint * Erase from checkpoints to restore if task fails * fix lint * update comments * avoid duplicated actor notification log * fix unintended change * add actor_id to checkpoint_expired * small java updates * make checkpoint info per actor * lint * Remove logging * Remove old actor checkpointing Python code, move new checkpointing code to FunctionActionManager * Replace old actor checkpointing tests * Fix test and lint * address comments * consolidate kill_actor * Remove __ray_checkpoint__ * fix non-ascii char * Loosen test checks * fix java * fix sphinx-build
1 parent 57dcd30 commit f31a79f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1710
-492
lines changed

doc/source/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
"tensorflow.python.client",
3737
"tensorflow.python.util",
3838
"ray.core.generated",
39+
"ray.core.generated.ActorCheckpointIdData",
3940
"ray.core.generated.ClientTableData",
4041
"ray.core.generated.GcsTableEntry",
4142
"ray.core.generated.HeartbeatTableData",
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package org.ray.api;
2+
3+
import java.util.List;
4+
import org.ray.api.id.UniqueId;
5+
6+
public interface Checkpointable {
7+
8+
class CheckpointContext {
9+
10+
/**
11+
* Actor's ID.
12+
*/
13+
public final UniqueId actorId;
14+
/**
15+
* Number of tasks executed since last checkpoint.
16+
*/
17+
public final int numTasksSinceLastCheckpoint;
18+
/**
19+
* Time elapsed since last checkpoint, in milliseconds.
20+
*/
21+
public final long timeElapsedMsSinceLastCheckpoint;
22+
23+
public CheckpointContext(UniqueId actorId, int numTasksSinceLastCheckpoint,
24+
long timeElapsedMsSinceLastCheckpoint) {
25+
this.actorId = actorId;
26+
this.numTasksSinceLastCheckpoint = numTasksSinceLastCheckpoint;
27+
this.timeElapsedMsSinceLastCheckpoint = timeElapsedMsSinceLastCheckpoint;
28+
}
29+
}
30+
31+
class Checkpoint {
32+
33+
/**
34+
* Checkpoint's ID.
35+
*/
36+
public final UniqueId checkpointId;
37+
/**
38+
* Checkpoint's timestamp.
39+
*/
40+
public final long timestamp;
41+
42+
public Checkpoint(UniqueId checkpointId, long timestamp) {
43+
this.checkpointId = checkpointId;
44+
this.timestamp = timestamp;
45+
}
46+
}
47+
48+
/**
49+
* Whether this actor needs to be checkpointed.
50+
*
51+
* This method will be called after every task. You should implement this callback to decide
52+
* whether this actor needs to be checkpointed at this time, based on the checkpoint context, or
53+
* any other factors.
54+
*
55+
* @param checkpointContext An object that contains info about last checkpoint.
56+
* @return A boolean value that indicates whether this actor needs to be checkpointed.
57+
*/
58+
boolean shouldCheckpoint(CheckpointContext checkpointContext);
59+
60+
/**
61+
* Save a checkpoint to persistent storage.
62+
*
63+
* If `shouldCheckpoint` returns true, this method will be called. You should implement this
64+
* callback to save actor's checkpoint and the given checkpoint id to persistent storage.
65+
*
66+
* @param actorId Actor's ID.
67+
* @param checkpointId An ID that represents this actor's current state in GCS. You should
68+
* save this checkpoint ID together with actor's checkpoint data.
69+
*/
70+
void saveCheckpoint(UniqueId actorId, UniqueId checkpointId);
71+
72+
/**
73+
* Load actor's previous checkpoint, and restore actor's state.
74+
*
75+
* This method will be called when an actor is reconstructed, after actor's constructor. If the
76+
* actor needs to restore from previous checkpoint, this function should restore actor's state and
77+
* return the checkpoint ID. Otherwise, it should do nothing and return null.
78+
*
79+
* @param actorId Actor's ID.
80+
* @param availableCheckpoints A list of available checkpoint IDs and their timestamps, sorted
81+
* by timestamp in descending order. Note, this method must return the ID of one checkpoint in
82+
* this list, or null. Otherwise, an exception will be thrown.
83+
* @return The ID of the checkpoint from which the actor was resumed, or null if the actor should
84+
* restart from the beginning.
85+
*/
86+
UniqueId loadCheckpoint(UniqueId actorId, List<Checkpoint> availableCheckpoints);
87+
88+
/**
89+
* Delete an expired checkpoint;
90+
*
91+
* This method will be called when an checkpoint is expired. You should implement this method to
92+
* delete your application checkpoint data. Note, the maximum number of checkpoints kept in the
93+
* backend can be configured at `RayConfig.num_actor_checkpoints_to_keep`.
94+
*
95+
* @param actorId ID of the actor.
96+
* @param checkpointId ID of the checkpoint that has expired.
97+
*/
98+
void checkpointExpired(UniqueId actorId, UniqueId checkpointId);
99+
}

java/checkstyle-suppressions.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
<suppressions>
66
<suppress checks="OperatorWrap" files=".*" />
7+
<suppress checks="JavadocParagraph" files=".*" />
78
<suppress checks="MemberNameCheck" files="PathConfig.java"/>
89
<suppress checks="MemberNameCheck" files="RayParameters.java"/>
910
<suppress checks="AbbreviationAsWordInNameCheck" files="RayParameters.java"/>

java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,26 @@
11
package org.ray.runtime;
22

3+
import com.google.common.base.Preconditions;
34
import com.google.common.base.Strings;
45
import java.lang.reflect.Field;
6+
import java.nio.ByteBuffer;
7+
import java.util.ArrayList;
58
import java.util.HashMap;
9+
import java.util.List;
610
import java.util.Map;
711
import java.util.stream.Collectors;
12+
import org.apache.commons.lang3.ArrayUtils;
13+
import org.ray.api.Checkpointable.Checkpoint;
14+
import org.ray.api.id.UniqueId;
815
import org.ray.runtime.config.RayConfig;
916
import org.ray.runtime.config.WorkerMode;
1017
import org.ray.runtime.gcs.RedisClient;
18+
import org.ray.runtime.generated.ActorCheckpointIdData;
19+
import org.ray.runtime.generated.TablePrefix;
1120
import org.ray.runtime.objectstore.ObjectStoreProxy;
1221
import org.ray.runtime.raylet.RayletClientImpl;
1322
import org.ray.runtime.runner.RunManager;
23+
import org.ray.runtime.util.UniqueIdUtil;
1424
import org.slf4j.Logger;
1525
import org.slf4j.LoggerFactory;
1626

@@ -21,7 +31,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
2131

2232
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
2333

24-
private RedisClient redisClient = null;
34+
/**
35+
* Redis client of the primary shard.
36+
*/
37+
private RedisClient redisClient;
38+
/**
39+
* Redis clients of all shards.
40+
*/
41+
private List<RedisClient> redisClients;
2542
private RunManager manager = null;
2643

2744
public RayNativeRuntime(RayConfig rayConfig) {
@@ -69,7 +86,8 @@ public void start() throws Exception {
6986
manager = new RunManager(rayConfig);
7087
manager.startRayProcesses(true);
7188
}
72-
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
89+
90+
initRedisClients();
7391

7492
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
7593
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
@@ -88,6 +106,16 @@ public void start() throws Exception {
88106
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
89107
}
90108

109+
private void initRedisClients() {
110+
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
111+
int numRedisShards = Integer.valueOf(redisClient.get("NumRedisShards", null));
112+
List<String> addresses = redisClient.lrange("RedisShards", 0, -1);
113+
Preconditions.checkState(numRedisShards == addresses.size());
114+
redisClients = addresses.stream().map(RedisClient::new)
115+
.collect(Collectors.toList());
116+
redisClients.add(redisClient);
117+
}
118+
91119
@Override
92120
public void shutdown() {
93121
if (null != manager) {
@@ -116,4 +144,33 @@ private void registerWorker() {
116144
}
117145
}
118146

147+
/**
148+
* Get the available checkpoints for the given actor ID, return a list sorted by checkpoint
149+
* timestamp in descending order.
150+
*/
151+
List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
152+
List<Checkpoint> checkpoints = new ArrayList<>();
153+
// TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over
154+
// all redis shards..
155+
String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
156+
byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
157+
for (RedisClient client : redisClients) {
158+
byte[] result = client.get(key, null);
159+
if (result == null) {
160+
continue;
161+
}
162+
ActorCheckpointIdData data = ActorCheckpointIdData
163+
.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
164+
165+
UniqueId[] checkpointIds
166+
= UniqueIdUtil.getUniqueIdsFromByteBuffer(data.checkpointIdsAsByteBuffer());
167+
168+
for (int i = 0; i < checkpointIds.length; i++) {
169+
checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
170+
}
171+
break;
172+
}
173+
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
174+
return checkpoints;
175+
}
119176
}

java/runtime/src/main/java/org/ray/runtime/Worker.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
package org.ray.runtime;
22

33
import com.google.common.base.Preconditions;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import org.ray.api.Checkpointable;
7+
import org.ray.api.Checkpointable.Checkpoint;
8+
import org.ray.api.Checkpointable.CheckpointContext;
49
import org.ray.api.exception.RayException;
510
import org.ray.api.id.UniqueId;
11+
import org.ray.runtime.config.RunMode;
612
import org.ray.runtime.functionmanager.RayFunction;
713
import org.ray.runtime.task.ArgumentsBuilder;
814
import org.ray.runtime.task.TaskSpec;
@@ -17,6 +23,9 @@ public class Worker {
1723

1824
private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);
1925

26+
// TODO(hchen): Use the C++ config.
27+
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
28+
2029
private final AbstractRayRuntime runtime;
2130

2231
/**
@@ -34,6 +43,22 @@ public class Worker {
3443
*/
3544
private Exception actorCreationException = null;
3645

46+
/**
47+
* Number of tasks executed since last actor checkpoint.
48+
*/
49+
private int numTasksSinceLastCheckpoint = 0;
50+
51+
/**
52+
* IDs of this actor's previous checkpoints.
53+
*/
54+
private List<UniqueId> checkpointIds;
55+
56+
/**
57+
* Timestamp of the last actor checkpoint.
58+
*/
59+
private long lastCheckpointTimestamp = 0;
60+
61+
3762
public Worker(AbstractRayRuntime runtime) {
3863
this.runtime = runtime;
3964
}
@@ -80,8 +105,12 @@ public void execute(TaskSpec spec) {
80105
}
81106
// Set result
82107
if (!spec.isActorCreationTask()) {
108+
if (spec.isActorTask()) {
109+
maybeSaveCheckpoint(actor, spec.actorId);
110+
}
83111
runtime.put(returnId, result);
84112
} else {
113+
maybeLoadCheckpoint(result, returnId);
85114
currentActor = result;
86115
currentActorId = returnId;
87116
}
@@ -98,4 +127,61 @@ public void execute(TaskSpec spec) {
98127
Thread.currentThread().setContextClassLoader(oldLoader);
99128
}
100129
}
130+
131+
private void maybeSaveCheckpoint(Object actor, UniqueId actorId) {
132+
if (!(actor instanceof Checkpointable)) {
133+
return;
134+
}
135+
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
136+
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
137+
return;
138+
}
139+
CheckpointContext checkpointContext = new CheckpointContext(actorId,
140+
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
141+
Checkpointable checkpointable = (Checkpointable) actor;
142+
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
143+
return;
144+
}
145+
numTasksSinceLastCheckpoint = 0;
146+
lastCheckpointTimestamp = System.currentTimeMillis();
147+
UniqueId checkpointId = runtime.rayletClient.prepareCheckpoint(actorId);
148+
checkpointIds.add(checkpointId);
149+
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
150+
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
151+
checkpointIds.remove(0);
152+
}
153+
checkpointable.saveCheckpoint(actorId, checkpointId);
154+
}
155+
156+
private void maybeLoadCheckpoint(Object actor, UniqueId actorId) {
157+
if (!(actor instanceof Checkpointable)) {
158+
return;
159+
}
160+
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
161+
// Actor checkpointing isn't implemented for SINGLE_PROCESS mode yet.
162+
return;
163+
}
164+
numTasksSinceLastCheckpoint = 0;
165+
lastCheckpointTimestamp = System.currentTimeMillis();
166+
checkpointIds = new ArrayList<>();
167+
List<Checkpoint> availableCheckpoints = ((RayNativeRuntime) runtime)
168+
.getCheckpointsForActor(actorId);
169+
if (availableCheckpoints.isEmpty()) {
170+
return;
171+
}
172+
UniqueId checkpointId = ((Checkpointable) actor).loadCheckpoint(actorId, availableCheckpoints);
173+
if (checkpointId != null) {
174+
boolean checkpointValid = false;
175+
for (Checkpoint checkpoint : availableCheckpoints) {
176+
if (checkpoint.checkpointId.equals(checkpointId)) {
177+
checkpointValid = true;
178+
break;
179+
}
180+
}
181+
Preconditions.checkArgument(checkpointValid,
182+
"'loadCheckpoint' must return a checkpoint ID that exists in the "
183+
+ "'availableCheckpoints' list, or null.");
184+
runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, checkpointId);
185+
}
186+
}
101187
}

java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.ray.runtime.gcs;
22

3+
import java.util.List;
34
import java.util.Map;
45

56
import org.ray.runtime.util.StringUtil;
@@ -77,7 +78,11 @@ public byte[] get(byte[] key, byte[] field) {
7778
return jedis.hget(key, field);
7879
}
7980
}
80-
8181
}
8282

83+
public List<String> lrange(String key, long start, long end) {
84+
try (Jedis jedis = jedisPool.getResource()) {
85+
return jedis.lrange(key, start, end);
86+
}
87+
}
8388
}

java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import java.util.List;
55
import java.util.Map;
66
import java.util.concurrent.ConcurrentHashMap;
7+
import org.apache.commons.lang3.NotImplementedException;
78
import org.ray.api.RayObject;
89
import org.ray.api.WaitResult;
910
import org.ray.api.id.UniqueId;
@@ -94,4 +95,14 @@ public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
9495
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
9596
return;
9697
}
98+
99+
@Override
100+
public UniqueId prepareCheckpoint(UniqueId actorId) {
101+
throw new NotImplementedException("Not implemented.");
102+
}
103+
104+
@Override
105+
public void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId) {
106+
throw new NotImplementedException("Not implemented.");
107+
}
97108
}

java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,8 @@ <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
2525
timeoutMs, UniqueId currentTaskId);
2626

2727
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);
28+
29+
UniqueId prepareCheckpoint(UniqueId actorId);
30+
31+
void notifyActorResumedFromCheckpoint(UniqueId actorId, UniqueId checkpointId);
2832
}

0 commit comments

Comments
 (0)