forked from keras-team/keras-io
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added PixelCNN example, added pathlib to requirements (keras-team#45)
* added PixelCNN example, added pathlib to requirements * reformatted * attempt to compile to website * added example * example done * necessary changes * fixed formatting * added cleaned example * cleanup * Style nits * issue resolution * issue resolution * added cleaned example * full functionality * remerged nits * fixed formatting * fixed formatting * blackened file * Style nits * removed redundant relu * added cleaned example * cleanup * added cleaned example * cleanup Co-authored-by: François Chollet <francois.chollet@gmail.com>
- Loading branch information
Showing
18 changed files
with
3,757 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,293 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"# PixelCNN\n", | ||
"\n", | ||
"**Author:** [ADMoreau](https://github.com/ADMoreau)<br>\n", | ||
"**Date created:** 2020/05/17<br>\n", | ||
"**Last modified:** 2020/05/23<br>\n", | ||
"**Description:** PixelCNN implemented in Keras." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Introduction\n", | ||
"\n", | ||
"PixelCNN is a generative model proposed in 2016 by van den Oord et al.\n", | ||
"(reference: [https://arxiv.org/abs/1606.05328](https://arxiv.org/abs/1606.05328)).\n", | ||
"It is designed to generate images (or other data types) iteratively,\n", | ||
"from an input vector where the probability distribution of prior elements dictates the\n", | ||
"probability distribution of later elements. In the following example, images are generated\n", | ||
"in this fashion, pixel-by-pixel, via a masked convolution kernel that only looks at data\n", | ||
"from previously generated pixels (origin at the top left) to generate later pixels.\n", | ||
"During inference, the output of the network is used as a probability distribution\n", | ||
"from which new pixel values are sampled to generate a new image\n", | ||
"(here, with MNIST, the pixel values are either black or white).\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import tensorflow as tf\n", | ||
"from tensorflow import keras\n", | ||
"from tensorflow.keras import layers\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Getting the Data\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
n_residual_blocks = 5

# The data, split between train and test sets. Labels are unused for a
# generative model, so they are discarded.
(train_images, _), (test_images, _) = keras.datasets.mnist.load_data()

# Use every available image: concatenate the train and test splits.
data = np.concatenate((train_images, test_images), axis=0)

# Binarize: any pixel below 33% of the max value (256) becomes 0, the
# rest become 1, so every pixel is either black or white.
data = np.where(data < (0.33 * 256), 0, 1)
data = data.astype(np.float32)
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Create two classes for the requisite Layers for the model\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
# The first layer is the PixelCNN layer. This layer simply
# builds on the 2D convolutional layer, but includes masking.
class PixelConvLayer(layers.Layer):
    """Masked 2D convolution used by PixelCNN.

    Wraps a regular `Conv2D` and zeroes out the part of the kernel that
    would let a pixel see itself or any "future" pixel (pixels below,
    or to the right on the same row, in raster order).

    Args:
        mask_type: "A" hides the kernel's center tap as well (used for
            the first layer so a pixel never sees its own input value);
            "B" keeps the center visible (used for all later layers).
        **kwargs: forwarded verbatim to the wrapped `layers.Conv2D`.
    """

    def __init__(self, mask_type, **kwargs):
        super(PixelConvLayer, self).__init__()
        self.mask_type = mask_type
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # Build the conv2d layer to initialize kernel variables
        self.conv.build(input_shape)
        # Use the initialized kernel to create the mask
        kernel_shape = self.conv.kernel.get_shape()
        self.mask = np.zeros(shape=kernel_shape)
        # All rows strictly above the center row are visible...
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        # ...plus the columns left of center on the center row.
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        if self.mask_type == "B":
            # Type "B" also exposes the center tap itself.
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        # Re-apply the 0/1 mask on every call so masked weights stay
        # zero even after optimizer updates (multiplying by a 0/1 mask
        # is idempotent, so repeated calls are harmless).
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)
"\n", | ||
"\n", | ||
# Next, we build our residual block layer.
# This is just a normal residual block, but based on the PixelConvLayer.
class ResidualBlock(keras.layers.Layer):
    """Residual block: 1x1 conv -> masked 3x3 conv -> 1x1 conv + skip.

    The bottleneck halves the channel count for the masked convolution
    and restores it before the skip connection.
    """

    def __init__(self, filters, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        self.conv1 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )
        self.pixel_conv = PixelConvLayer(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        self.conv2 = keras.layers.Conv2D(
            filters=filters, kernel_size=1, activation="relu"
        )

    def call(self, inputs):
        """Run the bottleneck stack and add the input back (skip path)."""
        hidden = self.conv1(inputs)
        hidden = self.pixel_conv(hidden)
        hidden = self.conv2(hidden)
        return keras.layers.add([inputs, hidden])
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Build the model based on the original paper\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
inputs = keras.Input(shape=input_shape)

# First layer uses an "A" mask so a pixel never sees its own value.
h = PixelConvLayer(
    mask_type="A", filters=128, kernel_size=7, activation="relu", padding="same"
)(inputs)

# Stack of residual blocks built around "B"-masked convolutions.
for _ in range(n_residual_blocks):
    h = ResidualBlock(filters=128)(h)

# Two final 1x1 "B"-masked layers ahead of the output head.
for _ in range(2):
    h = PixelConvLayer(
        mask_type="B",
        filters=128,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(h)

# Per-pixel sigmoid output: the images were binarized to 0/1, so each
# output is the Bernoulli probability of that pixel being white.
out = keras.layers.Conv2D(
    filters=1, kernel_size=1, strides=1, activation="sigmoid", padding="valid"
)(h)

pixel_cnn = keras.Model(inputs, out)
adam = keras.optimizers.Adam(learning_rate=0.0001)
pixel_cnn.compile(optimizer=adam, loss="binary_crossentropy")

pixel_cnn.summary()
pixel_cnn.fit(x=data, y=data, batch_size=64, epochs=50, validation_split=0.1)
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"colab_type": "text" | ||
}, | ||
"source": [ | ||
"## Demonstration\n", | ||
"\n", | ||
"The PixelCNN cannot generate the full image at once, and must instead generate each pixel in\n", | ||
"order, append the last generated pixel to the current image, and feed the image back into the\n", | ||
"model to repeat the process.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 0, | ||
"metadata": { | ||
"colab_type": "code" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
from IPython.display import Image, display
from tqdm import tqdm
import tensorflow_probability as tfp

# Create an empty array of pixels.
batch = 4
pixels = np.zeros(shape=(batch,) + (pixel_cnn.input_shape)[1:])
batch, rows, cols, channels = pixels.shape

# Iterate the pixels because generation has to be done sequentially pixel by pixel.
# The masked convolutions guarantee each prediction only depends on pixels that
# were already generated (above / to the left in raster order).
# NOTE(review): this calls predict() once per pixel (rows * cols times), which
# is slow but inherent to PixelCNN sampling.
for row in tqdm(range(rows)):
    for col in range(cols):
        for channel in range(channels):
            # Feed the whole array and retrieve the pixel value probabilities
            # for the next pixel.
            probs = pixel_cnn.predict(pixels)[:, row, col, channel]
            # Use the probabilities to sample binary pixel values and write
            # them back into the image frame.
            pixels[:, row, col, channel] = tfp.distributions.Bernoulli(
                probs=probs
            ).sample()
"\n", | ||
def deprocess_image(x):
    """Convert a single-channel [0, 1] image into an RGB uint8 image.

    Args:
        x: 2D array of floats in [0, 1] (one grayscale channel).

    Returns:
        Array of shape `x.shape + (3,)`, dtype uint8, values in [0, 255].
    """
    # Replicate the grayscale channel into R, G and B.
    rgb = np.stack((x, x, x), 2)
    # Undo the [0, 1] preprocessing.
    scaled = rgb * 255.0
    # Convert to uint8 after clipping to the valid range [0, 255].
    return np.clip(scaled, 0, 255).astype("uint8")
"\n", | ||
"\n", | ||
# Save each generated image to disk, then render all four inline.
for i, pic in enumerate(pixels):
    keras.preprocessing.image.save_img(
        "generated_image_{}.png".format(i), deprocess_image(np.squeeze(pic, -1))
    )

for i in range(4):
    display(Image("generated_image_{}.png".format(i)))
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"collapsed_sections": [], | ||
"name": "pixelcnn", | ||
"private_outputs": false, | ||
"provenance": [], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |
Oops, something went wrong.