Change the Loop API to support multitask training in the future.
PiperOrigin-RevId: 322661439
koz4k authored and copybara-github committed Jul 22, 2020
1 parent 3147439 commit 423d664
Showing 3 changed files with 23 additions and 12 deletions.
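
In short, `Loop` now takes a list of `TrainTask` instances and an optional list of `EvalTask` instances instead of a single task each, so existing call sites wrap their tasks in one-element lists. A minimal before/after sketch of the call-site update, using the argument names from the diffs below:

  # Before this commit: a single task and eval task, passed directly.
  loop = training.Loop(model, task, eval_task=eval_task,
                       eval_at=lambda step_n: step_n % 50 == 0)

  # After this commit: tasks are wrapped in lists. Only one-element lists are
  # accepted for now; the constructor asserts until multitask training lands.
  loop = training.Loop(model, [task], eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n % 50 == 0)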
3 changes: 2 additions & 1 deletion trax/supervised/mnist_test.py
@@ -53,7 +53,8 @@ def test_train_mnist(self):
[tl.CrossEntropyLoss(), tl.Accuracy()],
n_eval_batches=10)

- training_session = training.Loop(mnist_model, task, eval_task=eval_task,
+ training_session = training.Loop(mnist_model, [task],
+ eval_tasks=[eval_task],
eval_at=lambda step_n: step_n % 50 == 0)

training_session.run(n_steps=1000)
18 changes: 14 additions & 4 deletions trax/supervised/training.py
@@ -80,7 +80,7 @@ class Loop:
repeatable generation of pseudo-random numbers
"""

- def __init__(self, model, task, eval_model=None, eval_task=None,
+ def __init__(self, model, tasks, eval_model=None, eval_tasks=None,
output_dir=None, checkpoint_at=None, eval_at=None):
"""Configures a training `Loop`, including a random initialization.
@@ -89,12 +89,14 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
functions and eval functions (a.k.a. metrics) are considered to be
outside the core model, taking core model output and data labels as
their two inputs.
- task: TrainTask instance, which defines the training data, loss function,
- and optimizer to be used in this training loop.
+ tasks: List of TrainTask instances, which define the training data, loss
+ function, and optimizer to be used in respective tasks in this
+ training loop.
eval_model: Optional Trax layer, representing model used for evaluation,
e.g., with dropout turned off. If None, the training model (model)
will be used.
- eval_task: EvalTask instance or None. If None, don't do any evals.
+ eval_tasks: List of EvalTask instances or None. If None, don't do any
+ evals.
output_dir: Path telling where to save outputs (evals and checkpoints).
Can be None if both `eval_task` and `checkpoint_at` are None.
checkpoint_at: Function (integer --> boolean) telling, for step n, whether
@@ -103,6 +105,14 @@ def __init__(self, model, task, eval_model=None, eval_task=None,
eval_at: Function (integer --> boolean) that says, for training step n,
whether that step should run evals. If None, run when checkpointing.
"""
+ assert len(tasks) == 1, 'Multitask training not supported yet.'
+ task = tasks[0]
+ if eval_tasks is None:
+   eval_task = None
+ else:
+   assert len(eval_tasks) == 1, 'Multitask training not supported yet.'
+   eval_task = eval_tasks[0]

self._task = task
self._model = model
self._eval_model = eval_model or model
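
For reference, a self-contained sketch of how the new constructor is driven after this change, modeled on the tests updated below. The `_toy_data` generator is an assumption standing in for the tests' `_very_simple_data` helper; everything else follows the test code shown in this commit.

  import numpy as np

  from trax import layers as tl
  from trax import optimizers
  from trax.supervised import training

  def _toy_data():
    # Endless stream of (inputs, targets, weights) batches mapping small
    # integers to a constant target, mirroring the tests' toy regression data.
    inputs = np.arange(8).reshape((8, 1)).astype(np.float32)
    targets = np.pi * np.ones((8, 1), dtype=np.float32)
    while True:
      yield (inputs, targets, np.ones_like(targets))

  model = tl.Serial(tl.Dense(1))
  task = training.TrainTask(_toy_data(), tl.L2Loss(), optimizers.SGD(.01))
  eval_task = training.EvalTask(_toy_data(), [tl.L2Loss()],
                                metric_names=['SGD.L2Loss'])

  # Tasks are now passed as (one-element) lists; passing more than one task
  # trips the assertion added in this commit.
  loop = training.Loop(model, [task], eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n % 2 == 0)
  loop.run(n_steps=15)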
14 changes: 7 additions & 7 deletions trax/supervised/training_test.py
@@ -52,7 +52,7 @@ def test_loop_no_eval_task(self):
model = tl.Serial(tl.Dense(1))
task = training.TrainTask(
_very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
- training_session = training.Loop(model, task)
+ training_session = training.Loop(model, [task])
# Loop should initialize and run successfully, even with no eval task.
training_session.run(n_steps=5)

@@ -65,7 +65,7 @@ def test_train_dense_layer(self):
_very_simple_data(), # deliberately re-using training data
[tl.L2Loss()],
metric_names=['SGD.L2Loss'])
- training_session = training.Loop(model, task, eval_task=eval_task,
+ training_session = training.Loop(model, [task], eval_tasks=[eval_task],
eval_at=lambda step_n: step_n % 2 == 0)
self.assertEqual(0, training_session.current_step)
training_session.run(n_steps=15)
@@ -82,7 +82,7 @@ def test_train_dense_layer_with_momentum(self):
_very_simple_data(), # deliberately re-using training data
[tl.L2Loss()],
metric_names=['Momentum.L2Loss'])
- training_session = training.Loop(model, task, eval_task=eval_task,
+ training_session = training.Loop(model, [task], eval_tasks=[eval_task],
eval_at=lambda step_n: step_n % 2 == 0)
self.assertEqual(0, training_session.current_step)
training_session.run(n_steps=20)
@@ -96,7 +96,7 @@ def test_train_dense_layer_evals(self):
eval_task = training.EvalTask(
_very_simple_data(), # deliberately re-using training data
[tl.L2Loss()])
- training_session = training.Loop(model, task, eval_task=eval_task,
+ training_session = training.Loop(model, [task], eval_tasks=[eval_task],
eval_at=lambda step_n: False)
self.assertEqual(0, training_session.current_step)
training_session.run(n_steps=10)
@@ -114,7 +114,7 @@ def test_summaries_are_written(self):
[tl.L2Loss()],
metric_names=['SGD.L2Loss'])
tmp_dir = self.create_tempdir().full_path
- training_session = training.Loop(model, task, eval_task=eval_task,
+ training_session = training.Loop(model, [task], eval_tasks=[eval_task],
eval_at=lambda step_n: step_n % 2 == 0,
output_dir=tmp_dir)
expected_train_metric_dir = os.path.join(tmp_dir, 'train')
@@ -141,11 +141,11 @@ def test_restores_step(self):
task = training.TrainTask(
_very_simple_data(), tl.L2Loss(), optimizers.SGD(.01))
tmp_dir = self.create_tempdir().full_path
- loop = training.Loop(model, task,
+ loop = training.Loop(model, [task],
checkpoint_at=lambda step_n: step_n % 2 == 0,
output_dir=tmp_dir)
loop.run(4)
- loop2 = training.Loop(model, task, output_dir=tmp_dir)
+ loop2 = training.Loop(model, [task], output_dir=tmp_dir)
self.assertEqual(4, loop2.current_step)


