@@ -230,7 +230,9 @@ appropriately in distributed training.
 
 
 .. code-block:: python
-    :emphasize-lines: 23
+
+    import os
+    import tempfile
 
     from ray import train
     from ray.train import Checkpoint, ScalingConfig
@@ -254,24 +256,24 @@ appropriately in distributed training.
             model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
 
         for epoch in range(config["num_epochs"]):
-            model.fit(X, Y, batch_size=20)
-            checkpoint = Checkpoint.from_dict(
-                dict(epoch=epoch, model_weights=model.get_weights())
-            )
-            train.report({}, checkpoint=checkpoint)
+            history = model.fit(X, Y, batch_size=20)
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
+                checkpoint_dict = os.path.join(temp_checkpoint_dir, "checkpoint.json")
+                with open(checkpoint_dict, "w") as f:
+                    json.dump({"epoch": epoch}, f)
+                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+
+                train.report({"loss": history.history["loss"][0]}, checkpoint=checkpoint)
 
     trainer = TensorflowTrainer(
         train_func,
         train_loop_config={"num_epochs": 5},
         scaling_config=ScalingConfig(num_workers=2),
     )
     result = trainer.fit()
-
-    print(result.checkpoint.to_dict())
-    # {'epoch': 4, 'model_weights': [array([[-0.31858477],
-    #        [ 0.03747174],
-    #        [ 0.28266194],
-    #        [ 0.8626015 ]], dtype=float32), array([0.02230084], dtype=float32)], '_timestamp': 1656107383, '_preprocessor': None, '_current_checkpoint_id': 4}
+    print(result.checkpoint)
 
 By default, checkpoints will be persisted to local disk in the :ref:`log
 directory <train-log-dir>` of each run.
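
After ``trainer.fit()`` returns, the files reported above can be read back from ``result.checkpoint``. A minimal sketch, assuming the ``model.keras`` and ``checkpoint.json`` names used in this example (and that ``json`` is imported alongside ``os`` and ``tempfile``):

.. code-block:: python

    import json
    import os

    import tensorflow as tf

    # `result` is the ray.train.Result returned by trainer.fit() above.
    with result.checkpoint.as_directory() as checkpoint_dir:
        # Restore the Keras model saved by the training workers.
        model = tf.keras.models.load_model(os.path.join(checkpoint_dir, "model.keras"))
        # Read back the extra metadata written next to the model file.
        with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
            metadata = json.load(f)

    print(metadata)  # e.g. {"epoch": 4}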
@@ -280,7 +282,9 @@ Loading checkpoints
 ~~~~~~~~~~~~~~~~~~~
 
 .. code-block:: python
-    :emphasize-lines: 15, 21, 22, 25, 26, 27, 30
+
+    import os
+    import tempfile
 
     from ray import train
     from ray.train import Checkpoint, ScalingConfig
@@ -297,37 +301,42 @@ Loading checkpoints
         X = np.random.normal(0, 1, size=(n, 4))
         Y = np.random.uniform(0, 1, size=(n, 1))
 
-        start_epoch = 0
         strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
-
         with strategy.scope():
             # toy neural network : 1-layer
-            model = tf.keras.Sequential([tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))])
             checkpoint = train.get_checkpoint()
             if checkpoint:
-                # assume that we have run the train.report() example
-                # and successfully save some model weights
-                checkpoint_dict = checkpoint.to_dict()
-                model.set_weights(checkpoint_dict.get("model_weights"))
-                start_epoch = checkpoint_dict.get("epoch", -1) + 1
+                with checkpoint.as_directory() as checkpoint_dir:
+                    model = tf.keras.models.load_model(
+                        os.path.join(checkpoint_dir, "model.keras")
+                    )
+            else:
+                model = tf.keras.Sequential(
+                    [tf.keras.layers.Dense(1, activation="linear", input_shape=(4,))]
+                )
             model.compile(optimizer="Adam", loss="mean_squared_error", metrics=["mse"])
 
-        for epoch in range(start_epoch, config["num_epochs"]):
-            model.fit(X, Y, batch_size=20)
-            checkpoint = Checkpoint.from_dict(
-                dict(epoch=epoch, model_weights=model.get_weights())
-            )
-            train.report({}, checkpoint=checkpoint)
+        for epoch in range(config["num_epochs"]):
+            history = model.fit(X, Y, batch_size=20)
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
+                extra_json = os.path.join(temp_checkpoint_dir, "checkpoint.json")
+                with open(extra_json, "w") as f:
+                    json.dump({"epoch": epoch}, f)
+                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
+
+                train.report({"loss": history.history["loss"][0]}, checkpoint=checkpoint)
 
     trainer = TensorflowTrainer(
         train_func,
-        train_loop_config={"num_epochs": 2},
+        train_loop_config={"num_epochs": 5},
         scaling_config=ScalingConfig(num_workers=2),
     )
-    # save a checkpoint
     result = trainer.fit()
+    print(result.checkpoint)
 
-    # load a checkpoint
+    # Start a new run from a loaded checkpoint
     trainer = TensorflowTrainer(
         train_func,
         train_loop_config={"num_epochs": 5},
@@ -336,12 +345,6 @@ Loading checkpoints
     )
     result = trainer.fit()
 
-    print(result.checkpoint.to_dict())
-    # {'epoch': 4, 'model_weights': [array([[-0.70056134],
-    #        [-0.8839263 ],
-    #        [-1.0043601 ],
-    #        [-0.61634773]], dtype=float32), array([0.01889327], dtype=float32)], '_timestamp': 1656108446, '_preprocessor': None, '_current_checkpoint_id': 3}
-
 
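The directory-based example above no longer tracks a ``start_epoch``; if a resumed run should also skip epochs that already completed, the ``checkpoint.json`` written by each report can be read back inside ``train_func``. A minimal sketch under that assumption (file name and key taken from the example above):

.. code-block:: python

    import json
    import os

    from ray import train

    def train_func(config):
        # ... build or restore the model as shown above ...
        start_epoch = 0
        checkpoint = train.get_checkpoint()
        if checkpoint:
            with checkpoint.as_directory() as checkpoint_dir:
                with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
                    # Resume after the last epoch that was checkpointed.
                    start_epoch = json.load(f)["epoch"] + 1

        for epoch in range(start_epoch, config["num_epochs"]):
            ...  # fit, save a checkpoint, and call train.report() as shown above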
Further reading
---------------