|
14 | 14 | "\n", |
15 | 15 | "Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n", |
16 | 16 | "\n", |
17 | | - "Let’s get started!" |
18 | | - ] |
19 | | - }, |
20 | | - { |
21 | | - "cell_type": "markdown", |
22 | | - "id": "1", |
23 | | - "metadata": {}, |
24 | | - "source": [ |
| 17 | + "Let’s get started!\n", |
| 18 | + "\n", |
25 | 19 | "## 1. Install Flax\n", |
26 | 20 | "\n", |
27 | 21 | "If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):" |
|
263 | 257 | "\n", |
264 | 258 | "In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n", |
265 | 259 | "\n", |
266 | | - "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric. \n", |
| 260 | + "In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.\n", |
267 | 261 | "\n", |
268 | 262 | "During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics." |
269 | 263 | ] |
|
401 | 395 | "\n", |
402 | 396 | "@nnx.jit\n", |
403 | 397 | "def pred_step(model: CNN, batch):\n", |
404 | | - " logits = model(batch['image'])\n", |
| 398 | + " logits = model(batch['image'], None)\n", |
405 | 399 | " return logits.argmax(axis=1)" |
406 | 400 | ] |
407 | 401 | }, |
|
441 | 435 | " ax.axis('off')" |
442 | 436 | ] |
443 | 437 | }, |
| 438 | + { |
| 439 | + "cell_type": "markdown", |
| 440 | + "id": "65342ab4", |
| 441 | + "metadata": {}, |
| 442 | + "source": [ |
| 443 | + "# 8. Export the model\n", |
| 444 | + "\n", |
| 445 | + "Flax models are great for research, but aren't meant to be deployed directly. Instead, high performance inference runtimes like LiteRT or TensorFlow Serving operate on a special [SavedModel](https://www.tensorflow.org/guide/saved_model) format. The [Orbax](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library makes it easy to export Flax models to this format. First, we must create a `JaxModule` object wrapping a model and its prediction method." |
| 446 | + ] |
| 447 | + }, |
| 448 | + { |
| 449 | + "cell_type": "code", |
| 450 | + "execution_count": null, |
| 451 | + "id": "49cace09", |
| 452 | + "metadata": {}, |
| 453 | + "outputs": [], |
| 454 | + "source": [ |
| 455 | + "from orbax.export import JaxModule, ExportManager, ServingConfig" |
| 456 | + ] |
| 457 | + }, |
| 458 | + { |
| 459 | + "cell_type": "code", |
| 460 | + "execution_count": null, |
| 461 | + "id": "421309d4", |
| 462 | + "metadata": {}, |
| 463 | + "outputs": [], |
| 464 | + "source": [ |
| 465 | + "def exported_predict(model, y):\n", |
| 466 | + " return model(y, None)\n", |
| 467 | + "\n", |
| 468 | + "jax_module = JaxModule(model, exported_predict)" |
| 469 | + ] |
| 470 | + }, |
| 471 | + { |
| 472 | + "cell_type": "markdown", |
| 473 | + "id": "787136af", |
| 474 | + "metadata": {}, |
| 475 | + "source": [ |
| 476 | + "We also need to tell Tensorflow Serving what input type `exported_predict` expects in its second argument. The export machinery expects type signature arguments to be PyTrees of `tf.TensorSpec`." |
| 477 | + ] |
| 478 | + }, |
| 479 | + { |
| 480 | + "cell_type": "code", |
| 481 | + "execution_count": null, |
| 482 | + "id": "9f2ad72e", |
| 483 | + "metadata": {}, |
| 484 | + "outputs": [], |
| 485 | + "source": [ |
| 486 | + "sig = [tf.TensorSpec(shape=(1, 28, 28, 1), dtype=tf.float32)]" |
| 487 | + ] |
| 488 | + }, |
| 489 | + { |
| 490 | + "cell_type": "markdown", |
| 491 | + "id": "31e9668a", |
| 492 | + "metadata": {}, |
| 493 | + "source": [ |
| 494 | + "Finally, we can bundle up the input signature and the `JaxModule` together using the `ExportManager` class." |
| 495 | + ] |
| 496 | + }, |
| 497 | + { |
| 498 | + "cell_type": "code", |
| 499 | + "execution_count": null, |
| 500 | + "id": "18cdf9ad", |
| 501 | + "metadata": {}, |
| 502 | + "outputs": [], |
| 503 | + "source": [ |
| 504 | + "export_mgr = ExportManager(jax_module, [\n", |
| 505 | + " ServingConfig('mnist_server', input_signature=sig)\n", |
| 506 | + "])\n", |
| 507 | + "\n", |
| 508 | + "output_dir='/tmp/mnist_export'\n", |
| 509 | + "export_mgr.save(output_dir)" |
| 510 | + ] |
| 511 | + }, |
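 | 512 | + { |
 | 513 | + "cell_type": "markdown", |
 | 514 | + "id": "7c1d2e3f", |
 | 515 | + "metadata": {}, |
 | 516 | + "source": [ |
 | 517 | + "As a quick sanity check (a minimal sketch, not part of the original tutorial, assuming `tf` is imported as above), you can reload the exported SavedModel with `tf.saved_model.load` and confirm that the `mnist_server` serving signature was written." |
 | 518 | + ] |
 | 519 | + }, |
 | 520 | + { |
 | 521 | + "cell_type": "code", |
 | 522 | + "execution_count": null, |
 | 523 | + "id": "8d4e5f6a", |
 | 524 | + "metadata": {}, |
 | 525 | + "outputs": [], |
 | 526 | + "source": [ |
 | 527 | + "# Reload the exported SavedModel and list its serving signatures;\n", |
 | 528 | + "# the key should match the ServingConfig name defined above.\n", |
 | 529 | + "loaded = tf.saved_model.load(output_dir)\n", |
 | 530 | + "print(list(loaded.signatures.keys()))  # expected: ['mnist_server']" |
 | 531 | + ] |
 | 532 | + }, |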
444 | 512 | { |
445 | 513 | "cell_type": "markdown", |
446 | 514 | "id": "28", |
|