Commit e931e0c

danielsuosamanklesaria
authored and committed
[flax:examples:mnist] Update mnist example to use NNX.
PiperOrigin-RevId: 815862166
1 parent 74985b2 commit e931e0c

File tree

7 files changed

+260
-124
lines changed

docs_nnx/mnist_tutorial.ipynb

Lines changed: 78 additions & 10 deletions
@@ -14,14 +14,8 @@
 "\n",
 "Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
 "\n",
-"Let’s get started!"
-]
-},
-{
-"cell_type": "markdown",
-"id": "1",
-"metadata": {},
-"source": [
+"Let’s get started!\n",
+"\n",
 "## 1. Install Flax\n",
 "\n",
 "If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):"
@@ -263,7 +257,7 @@
 "\n",
 "In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n",
 "\n",
-"In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n",
+"In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.\n",
 "\n",
 "During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics."
 ]
@@ -401,7 +395,7 @@
 "\n",
 "@nnx.jit\n",
 "def pred_step(model: CNN, batch):\n",
-"  logits = model(batch['image'])\n",
+"  logits = model(batch['image'], None)\n",
 "  return logits.argmax(axis=1)"
 ]
 },
@@ -441,6 +435,80 @@
 "  ax.axis('off')"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "65342ab4",
+"metadata": {},
+"source": [
+"# 8. Export the model\n",
+"\n",
+"Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "49cace09",
+"metadata": {},
+"outputs": [],
+"source": [
+"from orbax.export import JaxModule, ExportManager, ServingConfig"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "421309d4",
+"metadata": {},
+"outputs": [],
+"source": [
+"def exported_predict(model, y):\n",
+"  return model(y, None)\n",
+"\n",
+"jax_module = JaxModule(model, exported_predict)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "787136af",
+"metadata": {},
+"source": [
+"We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "9f2ad72e",
+"metadata": {},
+"outputs": [],
+"source": [
+"sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]"
+]
+},
+{
+"cell_type": "markdown",
+"id": "31e9668a",
+"metadata": {},
+"source": [
+"Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "18cdf9ad",
+"metadata": {},
+"outputs": [],
+"source": [
+"export_mgr = ExportManager(jax_module, [\n",
+"  ServingConfig('mnist_server', input_signature=sig)\n",
+"])\n",
+"\n",
+"output_dir='/tmp/mnist_export'\n",
+"export_mgr.save(output_dir)"
+]
+},
 {
 "cell_type": "markdown",
 "id": "28",

docs_nnx/mnist_tutorial.md

Lines changed: 34 additions & 4 deletions
@@ -20,8 +20,6 @@ Flax NNX is a Python neural network library built upon [JAX](https://github.com/
 
 Let’s get started!
 
-+++
-
 ## 1. Install Flax
 
 If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):
@@ -141,7 +139,7 @@ nnx.display(optimizer)
 
 In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.
 
-In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. 
+In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.
 
 During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics.
 
@@ -237,7 +235,7 @@ model.eval() # Switch to evaluation mode.
 
 @nnx.jit
 def pred_step(model: CNN, batch):
-  logits = model(batch['image'])
+  logits = model(batch['image'], None)
   return logits.argmax(axis=1)
 ```
 
@@ -254,6 +252,38 @@ for i, ax in enumerate(axs.flatten()):
   ax.axis('off')
 ```
 
+# 8. Export the model
+
+Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method.
+
+```{code-cell} ipython3
+from orbax.export import JaxModule, ExportManager, ServingConfig
+```
+
+```{code-cell} ipython3
+def exported_predict(model, y):
+  return model(y, None)
+
+jax_module = JaxModule(model, exported_predict)
+```
+
+We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`.
+
+```{code-cell} ipython3
+sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]
+```
+
+Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class.
+
+```{code-cell} ipython3
+export_mgr = ExportManager(jax_module, [
+  ServingConfig('mnist_server', input_signature=sig)
+])
+
+output_dir='/tmp/mnist_export'
+export_mgr.save(output_dir)
+```
+
 Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.
 
 Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html).
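
Note on the export signature: the serving signature added in section 8 is `tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)`, which declares exactly one 28x28 single-channel image per call. As a rough, dependency-free illustration of what a fixed shape like this implies, the sketch below mimics the TensorFlow convention that a `None` dimension matches any size. The helper `shape_matches` is hypothetical and is not part of TensorFlow or Orbax; it only models the shape comparison, not the export itself.

```python
from typing import Optional, Sequence

# Shape from the tutorial's serving signature: batch of 1, 28x28 pixels, 1 channel.
MNIST_SERVING_SHAPE = (1, 28, 28, 1)


def shape_matches(spec: Sequence[Optional[int]], shape: Sequence[int]) -> bool:
  """Returns True if `shape` fits `spec`; a None entry in `spec` matches any size."""
  if len(spec) != len(shape):
    return False
  return all(s is None or s == d for s, d in zip(spec, shape))


print(shape_matches(MNIST_SERVING_SHAPE, (1, 28, 28, 1)))   # True
print(shape_matches(MNIST_SERVING_SHAPE, (32, 28, 28, 1)))  # False
print(shape_matches((None, 28, 28, 1), (32, 28, 28, 1)))    # True
```

Under this reading, a client that sends a batch of 32 images would not match the signature in the diff. Whether a variable batch dimension can actually be exported also depends on how the `JaxModule` is configured, which this commit does not show.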

examples/mnist/README.md

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn
 [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default
 
 ```
-I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
-I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
+I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25
 ```
 
 ### How to run
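
Note on the log format: the updated README sample reports train and test metrics for an epoch on a single line, as comma-separated `name: value` pairs, where the old output used separate train and eval lines. A minimal sketch of parsing such a line follows, assuming the format matches the sample in the diff exactly; `parse_metrics` is a hypothetical helper and is not part of the example's `train.py`.

```python
import re

# Sample line copied from the updated README.
LOG_LINE = (
    "I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, "
    "train_loss: 0.0073, train_accuracy: 99.75, "
    "test_loss: 0.0294, test_accuracy: 99.25"
)

# Matches `name: number` pairs such as `train_loss: 0.0073`.
_PAIR = re.compile(r"(\w+): (-?\d+(?:\.\d+)?)")


def parse_metrics(line: str) -> dict[str, float]:
  """Extracts the metric pairs that follow the `] ` of the log prefix."""
  _, _, payload = line.partition("] ")
  return {name: float(value) for name, value in _PAIR.findall(payload)}


metrics = parse_metrics(LOG_LINE)
print(metrics["epoch"], metrics["test_accuracy"])  # 10.0 99.25
```

Splitting on the `] ` that ends the logging prefix keeps the timestamp and the `train.py:175` source location from being mistaken for metrics.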

examples/mnist/main.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@
 from ml_collections import config_flags
 import tensorflow as tf
 
-import train
+import train # pylint: disable=g-bad-import-order
 
 
 FLAGS = flags.FLAGS
