88 changes: 78 additions & 10 deletions docs_nnx/mnist_tutorial.ipynb
@@ -14,14 +14,8 @@
"\n",
"Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n",
"\n",
"Let’s get started!"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"Let’s get started!\n",
"\n",
"## 1. Install Flax\n",
"\n",
"If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):"
Expand Down Expand Up @@ -263,7 +257,7 @@
"\n",
"In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n",
"\n",
"In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n",
"In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.\n",
"\n",
"During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics."
]
@@ -401,7 +395,7 @@
"\n",
"@nnx.jit\n",
"def pred_step(model: CNN, batch):\n",
" logits = model(batch['image'])\n",
" logits = model(batch['image'], None)\n",
" return logits.argmax(axis=1)"
]
},
@@ -441,6 +435,80 @@
" ax.axis('off')"
]
},
{
"cell_type": "markdown",
"id": "65342ab4",
"metadata": {},
"source": [
"# 8. Export the model\n",
"\n",
"Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "49cace09",
"metadata": {},
"outputs": [],
"source": [
"from orbax.export import JaxModule, ExportManager, ServingConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "421309d4",
"metadata": {},
"outputs": [],
"source": [
"def exported_predict(model, y):\n",
" return model(y, None)\n",
"\n",
"jax_module = JaxModule(model, exported_predict)"
]
},
{
"cell_type": "markdown",
"id": "787136af",
"metadata": {},
"source": [
"We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f2ad72e",
"metadata": {},
"outputs": [],
"source": [
"sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]"
]
},
{
"cell_type": "markdown",
"id": "31e9668a",
"metadata": {},
"source": [
"Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18cdf9ad",
"metadata": {},
"outputs": [],
"source": [
"export_mgr = ExportManager(jax_module, [\n",
" ServingConfig('mnist_server', input_signature=sig)\n",
"])\n",
"\n",
"output_dir='/tmp/mnist_export'\n",
"export_mgr.save(output_dir)"
]
},
{
"cell_type": "markdown",
"id": "28",
38 changes: 34 additions & 4 deletions docs_nnx/mnist_tutorial.md
@@ -20,8 +20,6 @@ Flax NNX is a Python neural network library built upon [JAX](https://github.com/

Let’s get started!

+++

## 1. Install Flax

If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (uncomment the code in the cell below if you are working in Google Colab or a Jupyter notebook):
@@ -141,7 +139,7 @@ nnx.display(optimizer)

In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.

In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.

During training (the `train_step`) you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. During both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics.
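
For reference, here is a minimal sketch of what such a loss function can look like (the tutorial's actual definition is in the diff's unchanged context; passing `None` as the model's second call argument mirrors `pred_step` below and is an assumption here):

```{code-cell} ipython3
# A sketch, not the tutorial's exact code: compute the cross-entropy loss and
# also return the logits so the accuracy metric can reuse them.
def loss_fn(model: CNN, batch):
  logits = model(batch['image'], None)
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']
  ).mean()
  return loss, logits
```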

@@ -237,7 +235,7 @@ model.eval()  # Switch to evaluation mode.

@nnx.jit
def pred_step(model: CNN, batch):
-  logits = model(batch['image'])
+  logits = model(batch['image'], None)
  return logits.argmax(axis=1)
```

@@ -254,6 +252,38 @@ for i, ax in enumerate(axs.flatten()):
ax.axis('off')
```

## 8. Export the model

Flax models are great for research, but they aren't meant to be deployed directly. Instead, high-performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, create a `JaxModule` object that wraps the model and its prediction method.

```{code-cell} ipython3
import tensorflow as tf

from orbax.export import JaxModule, ExportManager, ServingConfig
```

```{code-cell} ipython3
def exported_predict(model, x):
  # Pass None for the model's second call argument, as in `pred_step` above.
  return model(x, None)

jax_module = JaxModule(model, exported_predict)
```

We also need to tell TensorFlow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type-signature arguments to be PyTrees of `tf.TensorSpec`.

```{code-cell} ipython3
sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]
```
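
The fixed batch size of 1 above is one choice; `tf.TensorSpec` also accepts a dynamic leading dimension if the exported model should handle any batch size (a sketch; `sig_dynamic` is a hypothetical name and is not used below):

```{code-cell} ipython3
# Assumption: a dynamic (None) batch dimension lets the served model accept any batch size.
sig_dynamic = [tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)]
```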

Finally, we can bundle the input signature and the `JaxModule` together using the `ExportManager` class and save the result as a SavedModel.

```{code-cell} ipython3
export_mgr = ExportManager(jax_module, [
  ServingConfig('mnist_server', input_signature=sig)
])

output_dir = '/tmp/mnist_export'
export_mgr.save(output_dir)
```
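
To sanity-check the export, you can load the SavedModel back with TensorFlow and inspect its serving signature (a sketch; it assumes the signature is registered under the `ServingConfig` key `'mnist_server'`):

```{code-cell} ipython3
# Load the SavedModel back and look up the serving signature by its key.
loaded = tf.saved_model.load(output_dir)
serving_fn = loaded.signatures['mnist_server']

# The input spec should match the `tf.TensorSpec` exported above.
print(serving_fn.structured_input_signature)
```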

Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.

Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html).
3 changes: 1 addition & 2 deletions examples/mnist/README.md
@@ -20,8 +20,7 @@ https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mn
[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default

```
-I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69
-I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14
+I1009 17:56:42.674334 3280981 train.py:175] epoch: 10, train_loss: 0.0073, train_accuracy: 99.75, test_loss: 0.0294, test_accuracy: 99.25
```

### How to run
2 changes: 1 addition & 1 deletion examples/mnist/main.py
@@ -26,7 +26,7 @@
from ml_collections import config_flags
import tensorflow as tf

-import train
+import train  # pylint: disable=g-bad-import-order


FLAGS = flags.FLAGS