`docs_nnx/mnist_tutorial.md`
Let's get started!
If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):
```{code-cell} ipython3
# !pip install flax
```
## 2. Load the MNIST dataset
First, you need to load the MNIST dataset and prepare the training and test sets via TensorFlow Datasets (TFDS). You normalize the image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance.
```{code-cell} ipython3
import tensorflow_datasets as tfds # TFDS to download MNIST.
import tensorflow as tf # TensorFlow / `tf.data` operations.
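
# The rest of this cell is cut off in the diff. Below is a minimal sketch of the
# pipeline described above; the constants (train_steps, eval_every, batch_size,
# shuffle buffer size) are assumptions, not the file's exact values.
tf.random.set_seed(0)  # Make shuffling reproducible.

train_steps = 1200  # Assumed number of training batches.
eval_every = 200    # Assumed evaluation interval.
batch_size = 32     # Assumed batch size.

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

# Normalize pixel values to [0, 1] and keep the integer labels.
normalize = lambda sample: {
  'image': tf.cast(sample['image'], tf.float32) / 255,
  'label': sample['label'],
}
train_ds = train_ds.map(normalize)
test_ds = test_ds.map(normalize)

# Shuffle, batch, and prefetch to enhance performance.
train_ds = train_ds.repeat().shuffle(1024)
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
```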
  def __call__(self, x, rngs: nnx.Rngs):
    x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs))))
    x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs))
    x = self.linear2(x)
    return x
Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results.
```{code-cell} ipython3
import jax.numpy as jnp # JAX NumPy

y = model(jnp.ones((1, 28, 28, 1)), nnx.Rngs(0))
y
```
## 4. Create the optimizer and define some metrics
In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives a reference to the model, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer that defines the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.
```{code-cell} ipython3
import optax
learning_rate = 0.005
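
# The diff cuts this cell off here. A typical completion, sketched from the
# standard MNIST tutorial pattern (the momentum value is an assumption):
momentum = 0.9

# The optimizer holds a reference to the model so it can update its parameters in place.
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))

# Track accuracy and the running average of the loss.
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)
```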
During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics.
```{code-cell} ipython3
def loss_fn(model: CNN, batch, rngs):
  logits = model(batch['image'], rngs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch['label']
  ).mean()
  return loss, logits
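
# The diff truncates this cell here. What follows is a minimal sketch of the
# training and evaluation steps described above, assuming the `optimizer` and
# `metrics` objects from section 4 and threading `rngs` through `loss_fn`;
# the file's exact code may differ.
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, rngs):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch, rngs)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place metric updates.
  optimizer.update(grads)  # In-place parameter updates.


@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch, rngs):
  loss, logits = loss_fn(model, batch, rngs)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])
```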
  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # One training epoch has passed.
    # Log the training metrics.
Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.
```{code-cell} ipython3
model.eval() # Switch to evaluation mode.
@nnx.jit
def pred_step(model: CNN, batch):
  logits = model(batch['image'], None)
  return logits.argmax(axis=1)
```
We call `.eval()` before inference so that `Dropout` is disabled and `BatchNorm` uses its stored running statistics, which makes the predictions deterministic. Note that it only switches the behavior of these layers; it does not affect gradient computation.
```{code-cell} ipython3
test_batch = test_ds.as_numpy_iterator().next()
pred = pred_step(model, test_batch)
  ax.axis('off')
```
## 8. Export the model
+++
Flax models are great for research, but they aren't meant to be deployed directly. Instead, high-performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping the model and its prediction method.
```{code-cell} ipython3
from orbax.export import JaxModule, ExportManager, ServingConfig
```
```{code-cell} ipython3
def exported_predict(model, y):
  return model(y, None)

jax_module = JaxModule(model, exported_predict)
```
We also need to tell TensorFlow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`.
```{code-cell} ipython3
sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]
```
Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class.
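The snippet below is a minimal sketch of that final step, following the Orbax export guide linked above; the `serving_default` signature key and the output path are placeholders rather than values taken from this tutorial.

```{code-cell} ipython3
serving_config = ServingConfig(
  'serving_default',    # Placeholder signature key for TensorFlow Serving.
  input_signature=sig,  # The `tf.TensorSpec` signature defined above.
)
export_mgr = ExportManager(jax_module, [serving_config])

# Write the SavedModel to disk (placeholder path).
export_mgr.save('/tmp/mnist_savedmodel')
```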
Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.
Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html).