
Commit d1a00ec

[train][doc] New checkpointing user guide (ray-project#39505)

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
1 parent f438fe5, commit d1a00ec

File tree

11 files changed: +821 −550 lines

doc/source/train/distributed-tensorflow-keras.rst (+39 −36)
@@ -230,7 +230,9 @@ appropriately in distributed training.
 
 
 .. code-block:: python
-    :emphasize-lines: 23
+
+    import os
+    import tempfile
 
     from ray import train
     from ray.train import Checkpoint, ScalingConfig
@@ -254,24 +256,24 @@ appropriately in distributed training.
             model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
 
         for epoch in range(config["num_epochs"]):
-            model.fit(X, Y, batch_size=20)
-            checkpoint = Checkpoint.from_dict(
-                dict(epoch=epoch, model_weights=model.get_weights())
-            )
-            train.report({}, checkpoint=checkpoint)
+            history = model.fit(X, Y, batch_size=20)
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
+                checkpoint_dict = os.path.join(temp_checkpoint_dir, "checkpoint.json")
+                with open(checkpoint_dict, "w") as f:
+                    json.dump({"epoch": epoch}, f)
+                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+
+                train.report({"loss": history.history["loss"][0]}, checkpoint=checkpoint)
 
     trainer = TensorflowTrainer(
         train_func,
         train_loop_config={"num_epochs": 5},
         scaling_config=ScalingConfig(num_workers=2),
     )
     result = trainer.fit()
-
-    print(result.checkpoint.to_dict())
-    # {'epoch': 4, 'model_weights': [array([[-0.31858477],
-    #  [ 0.03747174],
-    #  [ 0.28266194],
-    #  [ 0.8626015 ]], dtype=float32), array([0.02230084], dtype=float32)], '_timestamp': 1656107383, '_preprocessor': None, '_current_checkpoint_id': 4}
+    print(result.checkpoint)
 
 By default, checkpoints will be persisted to local disk in the :ref:`log
 directory <train-log-dir>` of each run.
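
For reference, a minimal sketch of reading the reported checkpoint back after ``trainer.fit()`` returns, assuming the file names used in the new example above (``model.keras`` and ``checkpoint.json``) and that the full script also imports ``json``, which the snippet in the hunk calls but does not add:

.. code-block:: python

    import json
    import os

    import tensorflow as tf

    # result.checkpoint is a directory-backed ray.train.Checkpoint.
    with result.checkpoint.as_directory() as checkpoint_dir:
        # Restore the Keras model saved by the training function.
        model = tf.keras.models.load_model(os.path.join(checkpoint_dir, "model.keras"))
        # Read back the extra metadata written next to the model.
        with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
            extra = json.load(f)

    print("Last reported epoch:", extra["epoch"])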
@@ -280,7 +282,9 @@ Loading checkpoints
 ~~~~~~~~~~~~~~~~~~~
 
 .. code-block:: python
-    :emphasize-lines: 15, 21, 22, 25, 26, 27, 30
+
+    import os
+    import tempfile
 
     from ray import train
     from ray.train import Checkpoint, ScalingConfig
@@ -297,37 +301,42 @@ Loading checkpoints
         X = np.random.normal(0, 1, size=(n, 4))
         Y = np.random.uniform(0, 1, size=(n, 1))
 
-        start_epoch = 0
         strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
-
         with strategy.scope():
             # toy neural network : 1-layer
-            model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])
             checkpoint = train.get_checkpoint()
             if checkpoint:
-                # assume that we have run the train.report() example
-                # and successfully save some model weights
-                checkpoint_dict = checkpoint.to_dict()
-                model.set_weights(checkpoint_dict.get("model_weights"))
-                start_epoch = checkpoint_dict.get("epoch", -1) + 1
+                with checkpoint.as_directory() as checkpoint_dir:
+                    model = tf.keras.models.load_model(
+                        os.path.join(checkpoint_dir, "model.keras")
+                    )
+            else:
+                model = tf.keras.Sequential(
+                    [tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))]
+                )
             model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
 
-        for epoch in range(start_epoch, config["num_epochs"]):
-            model.fit(X, Y, batch_size=20)
-            checkpoint = Checkpoint.from_dict(
-                dict(epoch=epoch, model_weights=model.get_weights())
-            )
-            train.report({}, checkpoint=checkpoint)
+        for epoch in range(config["num_epochs"]):
+            history = model.fit(X, Y, batch_size=20)
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
+                extra_json = os.path.join(temp_checkpoint_dir, "checkpoint.json")
+                with open(extra_json, "w") as f:
+                    json.dump({"epoch": epoch}, f)
+                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+
+                train.report({"loss": history.history["loss"][0]}, checkpoint=checkpoint)
 
     trainer = TensorflowTrainer(
         train_func,
-        train_loop_config={"num_epochs": 2},
+        train_loop_config={"num_epochs": 5},
        scaling_config=ScalingConfig(num_workers=2),
     )
-    # save a checkpoint
     result = trainer.fit()
+    print(result.checkpoint)
 
-    # load a checkpoint
+    # Start a new run from a loaded checkpoint
     trainer = TensorflowTrainer(
         train_func,
         train_loop_config={"num_epochs": 5},
@@ -336,12 +345,6 @@ Loading checkpoints
     )
     result = trainer.fit()
 
-    print(result.checkpoint.to_dict())
-    # {'epoch': 4, 'model_weights': [array([[-0.70056134],
-    #  [-0.8839263 ],
-    #  [-1.0043601 ],
-    #  [-0.61634773]], dtype=float32), array([0.01889327], dtype=float32)], '_timestamp': 1656108446, '_preprocessor': None, '_current_checkpoint_id': 3}
-
 
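
The second ``TensorflowTrainer`` above is meant to start a new run from the checkpoint produced by the first; the argument that passes the checkpoint in sits in context lines the hunk does not show. A hedged sketch of that pattern, assuming the trainer accepts a ``resume_from_checkpoint`` argument and reusing ``result.checkpoint`` from the first run:

.. code-block:: python

    from ray.train import ScalingConfig
    from ray.train.tensorflow import TensorflowTrainer

    # Assumption: resume_from_checkpoint makes train.get_checkpoint() return this
    # checkpoint inside train_func, so the saved model.keras is reloaded instead
    # of building a fresh model.
    trainer = TensorflowTrainer(
        train_func,
        train_loop_config={"num_epochs": 5},
        scaling_config=ScalingConfig(num_workers=2),
        resume_from_checkpoint=result.checkpoint,
    )
    result = trainer.fit()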
 Further reading
 ---------------
