forked from keras-team/keras-io
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* RealNVP We are three italian data scientists to be and we would love to know what you think if our project could be a good example! * Update and rename Density estimation using RealNVP to Density estimation using RealNVP.py Good evening! We have made the changes you requested, if there is any other mistake or if it is not well formatted please let us know. About the dataset, we have looked for a more realistic one for the density estimation task; we found a dataset with multiple clusters that we think might more accurately represent a real-life scenario. Best regards, Sanna Daniele, Zannini Quirini Giorgio, Mandolini Giorgio Maria * Copyediting * Fixed code Good evening, We have modified the code according to your suggestions. Should there be anything we need to change, please let us know. Best regards, Mandolini Giorgio Maria, Sanna Daniele, Zannini Quirini Giorgio * fetched old copyedit, changed self.call(x) to self(x) Sorry we somehow missed that copyedit. Here is the corrected version. * fixed some errors and changed dataset We have updated the code as requested. About your last question the batch size with the full length of the dataset and the abscence of a validation split were mistakes on our part we hope we fixed. Finally, we tested various datasets but we could not reproduce a better result than the double moon example shown in the paper: for this reason we have fallen back to the old dataset from sklearn. Is it okay? Maybe we should implement it from scratch without the help of the library? Please let us know. Best regards, Zannini Quirini Giorgio, Mandolini Giorgio Maria, Sanna Daniele * Fixed small mistake on commit Sorry that was an old version of the code, this should be the right one. * fixed the loss tracker Thanks for the suggestions which are helping us learning aspects of Keras we didn't know about! About the part of ```z, _ = model(normalized_data)```, we don't understand how to use ```predict``` here since it gives some weird result if we use it. * added missing comment on `metrics` Sorry about that! * Update Density estimation using RealNVP.py * self.call to self We are very sorry we had some trouble with the versions. It won't happen again. * Rename Density estimation using RealNVP.py to real_nvp.py * removed summary * deleted summary * fetched old copyedit * Copyedits * uploaded real_nvp.md * uploaded real_nvp.ipynb * uploaded real_nvp images Co-authored-by: François Chollet <francois.chollet@gmail.com>
- Loading branch information
1 parent
aa45363
commit 9c4a040
Showing
5 changed files
with
1,431 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,350 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"# Density estimation using Real NVP\n", | ||
"\n", | ||
"**Authors:** [Mandolini Giorgio Maria](https://www.linkedin.com/in/giorgio-maria-mandolini-a2a1b71b4/), [Sanna Daniele](https://www.linkedin.com/in/daniele-sanna-338629bb/), [Zannini Quirini Giorgio](https://www.linkedin.com/in/giorgio-zannini-quirini-16ab181a0/)<br>\n", | ||
"**Date created:** 2020/08/10<br>\n", | ||
"**Last modified:** 2020/08/10<br>\n", | ||
"**Description:** Estimating the density distribution of the \"double moon\" dataset." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Introduction\n", | ||
"\n", | ||
"The aim of this work is to map a simple distribution - which is easy to sample\n", | ||
"and whose density is simple to estimate - to a more complex one learned from the data.\n", | ||
"This kind of generative model is also known as \"normalizing flow\".\n", | ||
"\n", | ||
"In order to do this, the model is trained via the maximum\n", | ||
"likelihood principle, using the \"change of variable\" formula.\n", | ||
"\n", | ||
"We will use an affine coupling function. We create it such that its inverse, as well as\n", | ||
"the determinant of the Jacobian, are easy to obtain (more details in the referenced paper).\n", | ||
"\n", | ||
"**Requirements:**\n", | ||
"\n", | ||
"* Tensorflow 2.3\n", | ||
"* Tensorflow probability 0.11.0\n", | ||
"\n", | ||
"**Reference:**\n", | ||
"\n", | ||
"[Density estimation using Real NVP](https://arxiv.org/pdf/1605.08803.pdf)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Setup" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow as tf\n", | ||
"from tensorflow import keras\n", | ||
"from tensorflow.keras import layers\n", | ||
"from tensorflow.keras import regularizers\n", | ||
"from sklearn.datasets import make_moons\n", | ||
"import numpy as np\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import tensorflow_probability as tfp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Load the data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"data = make_moons(3000, noise=0.05)[0].astype(\"float32\")\n", | ||
"norm = layers.experimental.preprocessing.Normalization()\n", | ||
"norm.adapt(data)\n", | ||
"normalized_data = norm(data)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Affine coupling layer" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Creating a custom layer with keras API.\n", | ||
"output_dim = 256\n", | ||
"reg = 0.01\n", | ||
"\n", | ||
"\n", | ||
"def Coupling(input_shape):\n", | ||
" input = keras.layers.Input(shape=input_shape)\n", | ||
"\n", | ||
" t_layer_1 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(input)\n", | ||
" t_layer_2 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(t_layer_1)\n", | ||
" t_layer_3 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(t_layer_2)\n", | ||
" t_layer_4 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(t_layer_3)\n", | ||
" t_layer_5 = keras.layers.Dense(\n", | ||
" input_shape, activation=\"linear\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(t_layer_4)\n", | ||
"\n", | ||
" s_layer_1 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(input)\n", | ||
" s_layer_2 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(s_layer_1)\n", | ||
" s_layer_3 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(s_layer_2)\n", | ||
" s_layer_4 = keras.layers.Dense(\n", | ||
" output_dim, activation=\"relu\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(s_layer_3)\n", | ||
" s_layer_5 = keras.layers.Dense(\n", | ||
" input_shape, activation=\"tanh\", kernel_regularizer=regularizers.l2(reg)\n", | ||
" )(s_layer_4)\n", | ||
"\n", | ||
" return keras.Model(inputs=input, outputs=[s_layer_5, t_layer_5])\n", | ||
"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Real NVP" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"class RealNVP(keras.Model):\n", | ||
" def __init__(self, num_coupling_layers):\n", | ||
" super(RealNVP, self).__init__()\n", | ||
"\n", | ||
" self.num_coupling_layers = num_coupling_layers\n", | ||
"\n", | ||
" # Distribution of the latent space.\n", | ||
" self.distribution = tfp.distributions.MultivariateNormalDiag(\n", | ||
" loc=[0.0, 0.0], scale_diag=[1.0, 1.0]\n", | ||
" )\n", | ||
" self.masks = np.array(\n", | ||
" [[0, 1], [1, 0]] * (num_coupling_layers // 2), dtype=\"float32\"\n", | ||
" )\n", | ||
" self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n", | ||
" self.layers_list = [Coupling(2) for i in range(num_coupling_layers)]\n", | ||
"\n", | ||
" @property\n", | ||
" def metrics(self):\n", | ||
" \"\"\"List of the model's metrics.\n", | ||
" We make sure the loss tracker is listed as part of `model.metrics`\n", | ||
" so that `fit()` and `evaluate()` are able to `reset()` the loss tracker\n", | ||
" at the start of each epoch and at the start of an `evaluate()` call.\n", | ||
" \"\"\"\n", | ||
" return [self.loss_tracker]\n", | ||
"\n", | ||
" def call(self, x, training=True):\n", | ||
" log_det_inv = 0\n", | ||
" direction = 1\n", | ||
" if training:\n", | ||
" direction = -1\n", | ||
" for i in range(self.num_coupling_layers)[::direction]:\n", | ||
" x_masked = x * self.masks[i]\n", | ||
" reversed_mask = 1 - self.masks[i]\n", | ||
" s, t = self.layers_list[i](x_masked)\n", | ||
" s *= reversed_mask\n", | ||
" t *= reversed_mask\n", | ||
" gate = (direction - 1) / 2\n", | ||
" x = (\n", | ||
" reversed_mask\n", | ||
" * (x * tf.exp(direction * s) + direction * t * tf.exp(gate * s))\n", | ||
" + x_masked\n", | ||
" )\n", | ||
" log_det_inv += gate * tf.reduce_sum(s, [1])\n", | ||
"\n", | ||
" return x, log_det_inv\n", | ||
"\n", | ||
" # Log likelihood of the normal distribution plus the log determinant of the jacobian.\n", | ||
"\n", | ||
" def log_loss(self, x):\n", | ||
" y, logdet = self(x)\n", | ||
" log_likelihood = self.distribution.log_prob(y) + logdet\n", | ||
" return -tf.reduce_mean(log_likelihood)\n", | ||
"\n", | ||
" def train_step(self, data):\n", | ||
" with tf.GradientTape() as tape:\n", | ||
"\n", | ||
" loss = self.log_loss(data)\n", | ||
"\n", | ||
" g = tape.gradient(loss, self.trainable_variables)\n", | ||
" self.optimizer.apply_gradients(zip(g, self.trainable_variables))\n", | ||
" self.loss_tracker.update_state(loss)\n", | ||
"\n", | ||
" return {\"loss\": self.loss_tracker.result()}\n", | ||
"\n", | ||
" def test_step(self, data):\n", | ||
" loss = self.log_loss(data)\n", | ||
" self.loss_tracker.update_state(loss)\n", | ||
"\n", | ||
" return {\"loss\": self.loss_tracker.result()}\n", | ||
"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Model training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"model = RealNVP(num_coupling_layers=6)\n", | ||
"\n", | ||
"model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.0001))\n", | ||
"\n", | ||
"history = model.fit(\n", | ||
" normalized_data, batch_size=256, epochs=300, verbose=2, validation_split=0.2\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Performance evaluation" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"plt.figure(figsize=(15, 10))\n", | ||
"plt.plot(history.history[\"loss\"])\n", | ||
"plt.plot(history.history[\"val_loss\"])\n", | ||
"plt.title(\"model loss\")\n", | ||
"plt.legend([\"train\", \"validation\"], loc=\"upper right\")\n", | ||
"plt.ylabel(\"loss\")\n", | ||
"plt.xlabel(\"epoch\")\n", | ||
"\n", | ||
"# From data to latent space.\n", | ||
"z, _ = model(normalized_data)\n", | ||
"\n", | ||
"# From latent space to data.\n", | ||
"samples = model.distribution.sample(3000)\n", | ||
"x, _ = model.predict(samples)\n", | ||
"\n", | ||
"f, axes = plt.subplots(2, 2)\n", | ||
"f.set_size_inches(20, 15)\n", | ||
"\n", | ||
"axes[0, 0].scatter(normalized_data[:, 0], normalized_data[:, 1], color=\"r\")\n", | ||
"axes[0, 0].set(title=\"Inference data space X\", xlabel=\"x\", ylabel=\"y\")\n", | ||
"axes[0, 1].scatter(z[:, 0], z[:, 1], color=\"r\")\n", | ||
"axes[0, 1].set(title=\"Inference latent space Z\", xlabel=\"x\", ylabel=\"y\")\n", | ||
"axes[0, 1].set_xlim([-3.5, 4])\n", | ||
"axes[0, 1].set_ylim([-4, 4])\n", | ||
"axes[1, 0].scatter(samples[:, 0], samples[:, 1], color=\"g\")\n", | ||
"axes[1, 0].set(title=\"Generated latent space Z\", xlabel=\"x\", ylabel=\"y\")\n", | ||
"axes[1, 1].scatter(x[:, 0], x[:, 1], color=\"g\")\n", | ||
"axes[1, 1].set(title=\"Generated data space X\", label=\"x\", ylabel=\"y\")\n", | ||
"axes[1, 1].set_xlim([-2, 2])\n", | ||
"axes[1, 1].set_ylim([-2, 2])" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"collapsed_sections": [], | ||
"name": "real_nvp", | ||
"private_outputs": false, | ||
"provenance": [], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.