-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a keras_recipe on using a keras callback to evaluate a
non-tensorflow metric
- Loading branch information
Showing
3 changed files
with
573 additions
and
0 deletions.
There are no files selected for viewing
223 changes: 223 additions & 0 deletions
223
examples/keras_recipes/ipynb/sklearn_metric_callbacks.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"# Evaluationg and Exporting scikit-learn Metrics in a Keras callback\n", | ||
"\n", | ||
"**Author:** [lukewood](https://lukewood.xyz)<br>\n", | ||
"**Date created:** 10/07/2021<br>\n", | ||
"**Last modified:** 10/07/2021<br>\n", | ||
"**Description:** Example shows how to use Keras callbacks to evaluate and export non-TensorFlow based metrics." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"[Keras callbacks](https://keras.io/api/callbacks/) allow for the execution of arbitrary\n", | ||
"code at various stages of the Keras training process. While Keras offers first class\n", | ||
"support for metric evaluation, Keras [metrics](https://keras.io/api/metrics/) may only\n", | ||
"rely on TensorFlow code internally.\n", | ||
"\n", | ||
"While there are TensorFlow implementations of many metrics online, many metrics are\n", | ||
"implemented using [NumPy](https://numpy.org/) or another numerical computation library.\n", | ||
"By performing metric evaluation inside of a Keras callback, we can leverage any existing\n", | ||
"metric, and ultimately export the result to TensorBoard." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Jaccard Score Metric\n", | ||
"This example makes use of a sklearn metric,\n", | ||
"[`sklearn.metrics.jarrard_score`](https://scikit-learn.org/stable/modules/generated/sklear\n", | ||
"n.metrics.jaccard_score.html#sklearn.metrics.jaccard_score), and writes the result to a\n", | ||
"TensorBoard using the `tf.summary` API.\n", | ||
"\n", | ||
"This template can be modified slightly to make it work with any existing sklearn metric." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"from sklearn.metrics import jaccard_score\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"\n", | ||
"class JaccardScoreCallback(tf.keras.callbacks.Callback):\n", | ||
" \"\"\"Computes the jaccard score and logs the results to TensorBoard.\"\"\"\n", | ||
"\n", | ||
" def __init__(self, model_fn, x_test, y_test, summary_writer):\n", | ||
" self.model_fn = model_fn\n", | ||
" self.x_test = x_test\n", | ||
" self.y_test = y_test\n", | ||
" self.summary_writer = summary_writer\n", | ||
" self.keras_metric = tf.keras.metrics.Mean(\"jaccard_score\")\n", | ||
" self.epoch = 0\n", | ||
"\n", | ||
" def on_epoch_end(self, batch, logs=None):\n", | ||
" self.epoch += 1\n", | ||
" self.keras_metric.reset_state()\n", | ||
" predictions = self.model_fn(self.x_test)\n", | ||
" jaccard_value = jaccard_score(\n", | ||
" np.argmax(predictions, axis=-1), self.y_test, average=None\n", | ||
" )\n", | ||
" self.keras_metric.update_state(jaccard_value)\n", | ||
" self._write_metric()\n", | ||
"\n", | ||
" def _write_metric(self):\n", | ||
" with self.summary_writer.as_default():\n", | ||
" tf.summary.scalar(\n", | ||
" self.keras_metric.name,\n", | ||
" self.keras_metric.result().numpy().astype(float),\n", | ||
" step=self.epoch,\n", | ||
" )\n", | ||
" self.summary_writer.flush()\n", | ||
"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Sample Usage\n", | ||
"Let's test our `JaccardScoreCallback` class with a real keras model." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow.keras as keras\n", | ||
"\n", | ||
"# Model / data parameters\n", | ||
"num_classes = 10\n", | ||
"input_shape = (28, 28, 1)\n", | ||
"\n", | ||
"# the data, split between train and test sets\n", | ||
"(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", | ||
"\n", | ||
"# Scale images to the [0, 1] range\n", | ||
"x_train = x_train.astype(\"float32\") / 255\n", | ||
"x_test = x_test.astype(\"float32\") / 255\n", | ||
"# Make sure images have shape (28, 28, 1)\n", | ||
"x_train = np.expand_dims(x_train, -1)\n", | ||
"x_test = np.expand_dims(x_test, -1)\n", | ||
"print(\"x_train shape:\", x_train.shape)\n", | ||
"print(x_train.shape[0], \"train samples\")\n", | ||
"print(x_test.shape[0], \"test samples\")\n", | ||
"\n", | ||
"\n", | ||
"# convert class vectors to binary class matrices\n", | ||
"y_train = keras.utils.to_categorical(y_train, num_classes)\n", | ||
"y_test = keras.utils.to_categorical(y_test, num_classes)\n", | ||
"\n", | ||
"import tensorflow.keras.layers as layers\n", | ||
"\n", | ||
"model = keras.Sequential(\n", | ||
" [\n", | ||
" keras.Input(shape=input_shape),\n", | ||
" layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n", | ||
" layers.MaxPooling2D(pool_size=(2, 2)),\n", | ||
" layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n", | ||
" layers.MaxPooling2D(pool_size=(2, 2)),\n", | ||
" layers.Flatten(),\n", | ||
" layers.Dropout(0.5),\n", | ||
" layers.Dense(num_classes, activation=\"softmax\"),\n", | ||
" ]\n", | ||
")\n", | ||
"\n", | ||
"model.summary()\n", | ||
"\n", | ||
"batch_size = 128\n", | ||
"epochs = 15\n", | ||
"\n", | ||
"model.compile(loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n", | ||
"summary_writer = tf.summary.create_file_writer(\"logs/traditional_classifier\")\n", | ||
"callbacks = [\n", | ||
" JaccardScoreCallback(model, x_test, np.argmax(y_test, axis=-1), summary_writer)\n", | ||
"]\n", | ||
"model.fit(\n", | ||
" x_train,\n", | ||
" y_train,\n", | ||
" batch_size=batch_size,\n", | ||
" epochs=epochs,\n", | ||
" validation_split=0.1,\n", | ||
" callbacks=callbacks,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"If you now launch a TensorBoard instance using `tensorboard --logdir=logs`, you will now\n", | ||
"see the jaccard_score metric alongside any other exported metrics!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Conclusion\n", | ||
"Many ML practitioners and researchers rely on metrics that may not yet have a TensorFlow\n", | ||
"implementation. Keras users can still leverage the wide variety of existing metric\n", | ||
"implementations in other frameworks by using a Keras callback. These metrics can be\n", | ||
"exported, viewed and analyzed in the TensorBoard like any other metric." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"collapsed_sections": [], | ||
"name": "sklearn_metric_callbacks", | ||
"private_outputs": false, | ||
"provenance": [], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.