Adding an option to `train` to load a checkpoint for fine-tuning.
PiperOrigin-RevId: 347476621
henrykmichalewski authored and copybara-github committed Dec 14, 2020
1 parent 2f490e8 commit 0349c3f
Showing 1 changed file with 22 additions and 6 deletions.
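
A minimal usage sketch of the new argument, assuming model, inputs, optimizer,
and learning-rate schedule are configured elsewhere (e.g. via gin, as is usual
for Trax); the output directory, checkpoint path, and step count here are
hypothetical, not taken from the commit:

from trax.supervised import trainer_lib

# Fine-tune from pre-trained weights: init_checkpoint seeds the model,
# while new checkpoints for this run are written under output_dir.
loop = trainer_lib.train(
    output_dir='/tmp/finetune_run',
    init_checkpoint='/path/to/pretrained/model.pkl.gz',
    steps=1000,
    use_loop=True,
)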
trax/supervised/trainer_lib.py (22 additions, 6 deletions)
@@ -82,7 +82,9 @@ def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
                output_dir=None, random_seed=None, n_devices=None,
                checkpoints_at=None, should_save_checkpoints=True,
                should_write_summaries=True,
-               metrics=None, checkpoint_highest=None, checkpoint_lowest=None):
+               metrics=None, checkpoint_highest=None,
+               checkpoint_lowest=None,
+               init_checkpoint=None):
 
     self._is_chief, _, self._n_devices, rng = (
         training.init_host_and_devices(n_devices, random_seed))
@@ -105,6 +107,10 @@ def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
     # Setup the model.
     model_train = model(mode='train')
     model_predict_eval = model(mode='eval')
+    # Should work for fine-tuning of T5.
+    if init_checkpoint:
+      model_train.init_from_file(init_checkpoint, weights_only=True)
+      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
     self._model_with_loss = tl.Serial(model_train, loss_fn)
 
     # Setup state.
@@ -523,7 +529,8 @@ def train(output_dir,
           checkpoint_lowest=None,
           use_loop=True,
           loss_chunk_size=0,
-          use_memory_efficient_trainer=False):
+          use_memory_efficient_trainer=False,
+          init_checkpoint=None):
   """Train the model on the inputs.
 
   Args:
@@ -554,7 +561,8 @@ def train(output_dir,
     checkpoint_lowest: save the checkpoint lowest at this metric.
     use_loop: whether to use training.Loop instead of Trainer.
     loss_chunk_size: int, if > 0 chunk loss into these sizes to save memory.
     use_memory_efficient_trainer: whether to use memory-efficient trainer.
+    init_checkpoint: a checkpoint for fine-tuning.
 
   Returns:
     trax.TrainerState or training.Loop if use_loop is True
@@ -594,10 +602,17 @@
     permanent_checkpoint_at = None
     if permanent_checkpoints_at is not None:
       permanent_checkpoint_at = (lambda step: step in permanent_checkpoints_at)
+
+    # Setup the model.
+    model_train = model(mode='train')
+    model_predict_eval = model(mode='eval')
+    if init_checkpoint:
+      model_train.init_from_file(init_checkpoint, weights_only=True)
+      model_predict_eval.init_from_file(init_checkpoint, weights_only=True)
     loop = training.Loop(
-        model(mode='train'),
+        model_train,
         [train_task],
-        eval_model=model(mode='eval'),
+        eval_model=model_predict_eval,
         eval_tasks=[eval_task],
         output_dir=output_dir,
         checkpoint_at=checkpoint_at,
@@ -624,7 +639,8 @@
       checkpoints_at=checkpoints_at,
       metrics=metrics,
       checkpoint_lowest=checkpoint_lowest,
-      checkpoint_highest=checkpoint_highest)
+      checkpoint_highest=checkpoint_highest,
+      init_checkpoint=init_checkpoint)
 
   epoch_steps = [steps]  # Only training if eval_frequency is 0 or None
   if eval_frequency and eval_steps > 0:
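
Both hunks restore weights via init_from_file(..., weights_only=True), which
loads the saved weights but not the training step or optimizer state, so the
run starts fresh, seeded from the checkpoint. A stand-alone sketch of that
call, with an illustrative model configuration and a hypothetical path:

from trax import models

model_train = models.Transformer(input_vocab_size=32000, mode='train')
model_eval = models.Transformer(input_vocab_size=32000, mode='eval')

# weights_only=True: take the weights, skip step counter/optimizer slots,
# so fine-tuning begins at step 0 with a freshly initialized optimizer.
model_train.init_from_file('/path/to/pretrained/model.pkl.gz', weights_only=True)
model_eval.init_from_file('/path/to/pretrained/model.pkl.gz', weights_only=True)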