
Commit f467b56

danielsuo authored and samanklesaria committed
[flax:examples:mnist] Update mnist example to use NNX.
PiperOrigin-RevId: 815862166
1 parent 75fd8fa commit f467b56

File tree

6 files changed: +369 −174 lines changed


docs_nnx/mnist_tutorial.ipynb

Lines changed: 161 additions & 45 deletions
Large diffs are not rendered by default.

docs_nnx/mnist_tutorial.md

Lines changed: 59 additions & 23 deletions
@@ -26,15 +26,15 @@ Let’s get started!

 If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):

-```{code-cell}
+```{code-cell} ipython3
 # !pip install flax
 ```

 ## 2. Load the MNIST dataset

 First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance.

-```{code-cell}
+```{code-cell} ipython3
 import tensorflow_datasets as tfds  # TFDS to download MNIST.
 import tensorflow as tf  # TensorFlow / `tf.data` operations.
@@ -72,7 +72,7 @@ test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

 Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:

-```{code-cell}
+```{code-cell} ipython3
 from flax import nnx  # The Flax NNX API.
 from functools import partial
@@ -82,19 +82,19 @@ class CNN(nnx.Module):
   def __init__(self, *, rngs: nnx.Rngs):
     self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
     self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)
-    self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs)
+    self.dropout1 = nnx.Dropout(rate=0.025)
     self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
     self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)
     self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
     self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
-    self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs)
+    self.dropout2 = nnx.Dropout(rate=0.025)
     self.linear2 = nnx.Linear(256, 10, rngs=rngs)

-  def __call__(self, x):
-    x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x)))))
+  def __call__(self, x, rngs: nnx.Rngs):
+    x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs))))
     x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
     x = x.reshape(x.shape[0], -1)  # flatten
-    x = nnx.relu(self.dropout2(self.linear1(x)))
+    x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs))
     x = self.linear2(x)
     return x
@@ -108,18 +108,18 @@ nnx.display(model)

 Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results.

-```{code-cell}
+```{code-cell} ipython3
 import jax.numpy as jnp  # JAX NumPy

-y = model(jnp.ones((1, 28, 28, 1)))
+y = model(jnp.ones((1, 28, 28, 1)), nnx.Rngs(0))
 y
 ```

 ## 4. Create the optimizer and define some metrics

 In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.

-```{code-cell}
+```{code-cell} ipython3
 import optax

 learning_rate = 0.005
@@ -144,25 +144,25 @@ In addition to the `loss`, during training and testing you will also get the `lo

 During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics.

-```{code-cell}
-def loss_fn(model: CNN, batch):
-  logits = model(batch['image'])
+```{code-cell} ipython3
+def loss_fn(model: CNN, batch, rngs):
+  logits = model(batch['image'], rngs)
   loss = optax.softmax_cross_entropy_with_integer_labels(
     logits=logits, labels=batch['label']
   ).mean()
   return loss, logits

 @nnx.jit
-def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
+def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, rngs):
   """Train for a single step."""
   grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
-  (loss, logits), grads = grad_fn(model, batch)
+  (loss, logits), grads = grad_fn(model, batch, rngs)
   metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
-  optimizer.update(grads)  # In-place updates.
+  optimizer.update(model, grads)  # In-place updates.

 @nnx.jit
 def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
-  loss, logits = loss_fn(model, batch)
+  loss, logits = loss_fn(model, batch, None)
   metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place updates.
 ```
@@ -177,7 +177,7 @@ Now, you can train the CNN model using batches of data for 10 epochs, evaluate t
 on the test set after each epoch, and log the training and testing metrics (the loss and
 the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy.

-```{code-cell}
+```{code-cell} ipython3
 from IPython.display import clear_output
 import matplotlib.pyplot as plt
@@ -188,13 +188,15 @@ metrics_history = {
   'test_accuracy': [],
 }

+rngs = nnx.Rngs(0)
+
 for step, batch in enumerate(train_ds.as_numpy_iterator()):
   # Run the optimization for one step and make a stateful update to the following:
   # - The train state's model parameters
   # - The optimizer state
   # - The training loss and accuracy batch metrics
   model.train()  # Switch to train mode
-  train_step(model, optimizer, metrics, batch)
+  train_step(model, optimizer, metrics, batch, rngs)

   if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # One training epoch has passed.
     # Log the training metrics.
@@ -229,18 +231,18 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):

 Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.

-```{code-cell}
+```{code-cell} ipython3
 model.eval()  # Switch to evaluation mode.

 @nnx.jit
 def pred_step(model: CNN, batch):
-  logits = model(batch['image'])
+  logits = model(batch['image'], None)
   return logits.argmax(axis=1)
 ```

 We call .eval() before inference so Dropout is disabled and BatchNorm uses stored running stats. It is used during inference to suppress gradients and ensure deterministic, resource-efficient output.

-```{code-cell}
+```{code-cell} ipython3
 test_batch = test_ds.as_numpy_iterator().next()
 pred = pred_step(model, test_batch)
@@ -251,6 +253,40 @@ for i, ax in enumerate(axs.flatten()):
   ax.axis('off')
 ```

+# 8. Export the model
+
++++
+
+Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method.
+
+```{code-cell} ipython3
+from orbax.export import JaxModule, ExportManager, ServingConfig
+```
+
+```{code-cell} ipython3
+def exported_predict(model, y):
+  return model(y, None)
+
+jax_module = JaxModule(model, exported_predict)
+```
+
+We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`.
+
+```{code-cell} ipython3
+sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]
+```
+
+Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class.
+
+```{code-cell} ipython3
+export_mgr = ExportManager(jax_module, [
+  ServingConfig('mnist_server', input_signature=sig)
+])
+
+output_dir='/tmp/mnist_export'
+export_mgr.save(output_dir)
+```
+

 Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.

 Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html).
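
For readers skimming the diff rather than the rendered notebook, the sketch below condenses the pattern this commit moves the tutorial to: `nnx.Dropout` is constructed without `rngs`, randomness is threaded through `__call__` at call time, and gradients are applied with the two-argument `optimizer.update(model, grads)`. `TinyModel`, the shapes, and the hyperparameters are illustrative stand-ins rather than part of the commit, and the `nnx.Optimizer(..., wrt=nnx.Param)` construction is an assumption about the Flax release in use.

```python
# Illustrative sketch only; it mirrors the diff above but is not code from the commit.
import jax.numpy as jnp
import optax
from flax import nnx


class TinyModel(nnx.Module):  # hypothetical stand-in for the tutorial's CNN
  def __init__(self, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 2, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1)  # no `rngs` at construction time...

  def __call__(self, x, rngs: nnx.Rngs):
    # ...dropout randomness is supplied per call instead.
    return self.linear(self.dropout(x, rngs=rngs))


model = TinyModel(rngs=nnx.Rngs(0))
# Assumes a Flax release whose `nnx.Optimizer` takes `wrt` and whose `update`
# takes the model explicitly, matching the `optimizer.update(model, grads)` call above.
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)


def loss_fn(model, x, labels, rngs):
  logits = model(x, rngs)
  return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()


@nnx.jit
def train_step(model, optimizer, x, labels, rngs):
  loss, grads = nnx.value_and_grad(loss_fn)(model, x, labels, rngs)
  optimizer.update(model, grads)  # in-place update, new two-argument form
  return loss


loss = train_step(model, optimizer,
                  jnp.ones((8, 4)),                  # dummy inputs
                  jnp.zeros((8,), dtype=jnp.int32),  # dummy labels
                  nnx.Rngs(1))                       # per-call dropout RNG
```

On the export side, the directory written by `export_mgr.save(output_dir)` in the new section is a standard SavedModel, so if you want to sanity-check it you should be able to inspect the exported signature with stock TensorFlow tooling, e.g. `tf.saved_model.load(output_dir).signatures`.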

examples/mnist/README.md

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn
 [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

 ```
-I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
-I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
+I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25
 ```

 ### How to run

examples/mnist/main.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from ml_collections import config_flags
 import tensorflow as tf

-import train
+import train  # pylint: disable=g-bad-import-order


 FLAGS = flags.FLAGS
