add train_step, fit, and compile to the tutorial (keras-team#631)
* add train_step

* addressing comments

* Copyedits

* fix typo

Co-authored-by: Haifeng Jin <haifeng-jin@users.noreply.github.com>
Co-authored-by: François Chollet <francois.chollet@gmail.com>
3 people authored Sep 18, 2021
1 parent 615f06b commit c65596b
Showing 1 changed file with 89 additions and 5 deletions.
guides/intro_to_keras_for_researchers.py
@@ -773,11 +773,95 @@ def call(self, inputs, training=None):
In your research workflows, you may often find yourself mix-and-matching OO models and
Functional models.
Note that the `Model` class also features built-in training & evaluation loops:
`fit()`, `predict()` and `evaluate()` (configured via the `compile()` method).
These methods give you access to the
following built-in training infrastructure features:

* [Callbacks](/api/callbacks/). You can leverage built-in
callbacks for early-stopping, model checkpointing,
and monitoring training with TensorBoard. You can also
[implement custom callbacks](/guides/writing_your_own_callbacks/) if needed.
* [Distributed training](https://keras.io/guides/distributed_training/). You
can easily scale up your training to multiple GPUs, TPUs, or even multiple machines
with the `tf.distribute` API -- with no changes to your code.
* [Step fusing](https://keras.io/api/models/model_training_apis/#compile-method).
With the `steps_per_execution` argument in `Model.compile()`, you can process
multiple batches in a single `tf.function` call, which greatly improves
device utilization on TPUs.

We won't go into the details, but we provide a simple code example
below. It leverages the built-in training infrastructure to implement the MNIST
example above, and it is followed by short sketches of callbacks, distributed
training, and step fusing.
"""

inputs = tf.keras.Input(shape=(784,), dtype="float32")
x = keras.layers.Dense(32, activation="relu")(inputs)
x = keras.layers.Dense(32, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = tf.keras.Model(inputs, outputs)

# Specify the loss, optimizer, and metrics with `compile()`.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train the model with the dataset for 2 epochs.
model.fit(dataset, epochs=2)
model.predict(dataset)
model.evaluate(dataset)
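
"""
For instance, built-in callbacks plug directly into `fit()`. Below is a minimal sketch
that reuses the `model` and `dataset` from above; the monitored metric, patience, and
log directory are illustrative choices.
"""

model.fit(
    dataset,
    epochs=2,
    callbacks=[
        # Stop training early if the loss stops improving.
        keras.callbacks.EarlyStopping(monitor="loss", patience=1),
        # Write logs that can be visualized with TensorBoard.
        keras.callbacks.TensorBoard(log_dir="./logs"),
    ],
)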
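
"""
Distributed training follows the same workflow. Here is a minimal sketch, assuming at
least one visible device: the model is built and compiled inside a `tf.distribute`
strategy scope, with no other changes (the `distributed_model` name is just for
illustration).
"""

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Same model-building and `compile()` code as above, just inside the scope.
    inputs = tf.keras.Input(shape=(784,), dtype="float32")
    x = keras.layers.Dense(32, activation="relu")(inputs)
    x = keras.layers.Dense(32, activation="relu")(x)
    outputs = keras.layers.Dense(10)(x)
    distributed_model = tf.keras.Model(inputs, outputs)
    distributed_model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )
distributed_model.fit(dataset, epochs=2)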
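
"""
Step fusing is a single extra argument to `compile()`. A minimal sketch, reusing the
`model` above; the value `steps_per_execution=8` is an arbitrary illustration.
"""

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    # Process 8 batches per `tf.function` call.
    steps_per_execution=8,
)
model.fit(dataset, epochs=2)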

"""
You can always subclass the `Model` class (it works exactly like subclassing
`Layer`) if you want to leverage built-in training loops for your OO models.
Just override `Model.train_step()` to
customize what happens in `fit()` while retaining support
for the built-in infrastructure features outlined above -- callbacks,
zero-code distribution support, and step fusing support.

You may also override `test_step()` to customize what happens in `evaluate()`,
and override `predict_step()` to customize what happens in `predict()`. A short
`test_step()` sketch follows the example below. For more
information, please refer to
"""

class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.optimizer = keras.optimizers.Adam(learning_rate=1e-3)

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            loss = self.loss_fn(y, y_pred)
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        # Update metrics (includes the metric that tracks the loss)
        self.loss_tracker.update_state(loss)
        self.accuracy.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {"loss": self.loss_tracker.result(), "accuracy": self.accuracy.result()}

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch.
        return [self.loss_tracker, self.accuracy]

inputs = tf.keras.Input(shape=(784,), dtype="float32")
x = keras.layers.Dense(32, activation="relu")(inputs)
x = keras.layers.Dense(32, activation="relu")(x)
outputs = keras.layers.Dense(10)(x)
model = CustomModel(inputs, outputs)
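# No loss or optimizer is passed to `compile()` below: this model defines its own
# loss, optimizer, and metrics in `__init__` and uses them directly in `train_step()`.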
model.compile()
model.fit(dataset, epochs=2)
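
"""
The same pattern extends to evaluation. Below is a minimal sketch of a `test_step()`
override that customizes what happens in `evaluate()`; the `CustomModelWithEval` and
`eval_model` names are just for illustration, and the class reuses the loss, metrics,
and functional inputs/outputs defined above.
"""

class CustomModelWithEval(CustomModel):
    def test_step(self, data):
        x, y = data
        # Forward pass in inference mode, then update the same metrics as in training.
        y_pred = self(x, training=False)
        loss = self.loss_fn(y, y_pred)
        self.loss_tracker.update_state(loss)
        self.accuracy.update_state(y, y_pred)
        return {"loss": self.loss_tracker.result(), "accuracy": self.accuracy.result()}

eval_model = CustomModelWithEval(inputs, outputs)
eval_model.compile()
eval_model.fit(dataset, epochs=2)
eval_model.evaluate(dataset)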

"""
## End-to-end experiment example 1: variational autoencoders.
