Revise: redactional edits on VQ-VAE example (keras-team#938)
* explain PixelCNN shape

Signed-off-by: reinvantveer <rein@vantveer.me>

* update metadata, attach extra author

Signed-off-by: reinvantveer <rein@vantveer.me>

* fix explanation error

The code book _is_ optimized during training of the VQVAE first stage and remains unchanged during the PixelCNN training.

Signed-off-by: reinvantveer <rein@vantveer.me>

* small textual changes

Signed-off-by: reinvantveer <rein@vantveer.me>

* reword: the encoder outputs sizes equal to the number of filters.

Signed-off-by: reinvantveer <rein@vantveer.me>

* more concise

Signed-off-by: reinvantveer <rein@vantveer.me>

* drop redundant parentheses, move comment

Signed-off-by: reinvantveer <rein@vantveer.me>

* reword comments

The purpose of `input_size` is to be able to reshape the quantized output back to its original shape

Signed-off-by: reinvantveer <rein@vantveer.me>

* add explicative comment

Signed-off-by: reinvantveer <rein@vantveer.me>

* indent/format codebook loss like commitment loss statement to deflummox the reader

Signed-off-by: reinvantveer <rein@vantveer.me>

* more concise

Signed-off-by: reinvantveer <rein@vantveer.me>

* add caveat on relu-activated encoder and decoder layers

Signed-off-by: reinvantveer <rein@vantveer.me>

* more concise, re-word explanation

Signed-off-by: reinvantveer <rein@vantveer.me>

* whitespace

Signed-off-by: reinvantveer <rein@vantveer.me>

* formatting

Signed-off-by: reinvantveer <rein@vantveer.me>

* re-word, better shape clarification

Signed-off-by: reinvantveer <rein@vantveer.me>

* revert co-authorship

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* Revert "reword comments"

This reverts commit a2e2b45.

* unfold multi-line comment

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix accidental tabs instead of spaces

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* be more helpful: show what should've been formatted differently

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* more readable loss calculations whilst keeping style checker happy

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* whitespace

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* update trade-off description for quantizer output size and code book size

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* fold lines for readability

Signed-off-by: reinvantveer <reinvantveer@gmail.com>

* add contributor notes

Signed-off-by: reinvantveer <rein@vantveer.me>

* mirror edits from ../vq_vae.py

Signed-off-by: reinvantveer <rein@vantveer.me>

* update to keras-cv latest version

Signed-off-by: reinvantveer <rein@vantveer.me>

* update to keras-nlp latest version

Signed-off-by: reinvantveer <rein@vantveer.me>

* mirror edits from ../vq_vae.py

Signed-off-by: reinvantveer <rein@vantveer.me>

* fix unescaped double quotes

Signed-off-by: reinvantveer <rein@vantveer.me>

* fix unescaped double quotes

Signed-off-by: reinvantveer <rein@vantveer.me>

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
reinvantveer and sayakpaul authored Jul 5, 2022
1 parent a9b0bc5 commit b223b4a
Showing 4 changed files with 213 additions and 123 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous_integration.yml
@@ -14,7 +14,7 @@ jobs:
run: |
pip install --upgrade pip
pip install black==22.1.0
black --check ./examples
black --check --diff ./examples
docker-image:
runs-on: ubuntu-latest
steps:
114 changes: 72 additions & 42 deletions examples/generative/ipynb/vq_vae.ipynb
@@ -10,7 +10,7 @@
"\n",
"**Author:** [Sayak Paul](https://twitter.com/RisingSayak)<br>\n",
"**Date created:** 2021/07/21<br>\n",
"**Last modified:** 2021/07/21<br>\n",
"**Last modified:** 2022/06/27<br>\n",
"**Description:** Training a VQ-VAE for image reconstruction and codebook sampling for generation."
]
},
@@ -20,10 +20,10 @@
"colab_type": "text"
},
"source": [
"In this example, we will develop a Vector Quantized Variational Autoencoder (VQ-VAE).\n",
"In this example, we develop a Vector Quantized Variational Autoencoder (VQ-VAE).\n",
"VQ-VAE was proposed in\n",
"[Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)\n",
"by van der Oord et al. In traditional VAEs, the latent space is continuous and is sampled\n",
"by van der Oord et al. In standard VAEs, the latent space is continuous and is sampled\n",
"from a Gaussian distribution. It is generally harder to learn such a continuous\n",
"distribution via gradient descent. VQ-VAEs, on the other hand,\n",
"operate on a discrete latent space, making the optimization problem simpler. It does so\n",
@@ -32,16 +32,20 @@
"outputs. These discrete code words are then fed to the decoder, which is trained\n",
"to generate reconstructed samples.\n",
"\n",
"For a detailed overview of VQ-VAEs, please refer to the original paper and\n",
"For an overview of VQ-VAEs, please refer to the original paper and\n",
"[this video explanation](https://www.youtube.com/watch?v=VZFVUrYcig0).\n",
"If you need a refresher on VAEs, you can refer to\n",
"[this book chapter](https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-12/).\n",
"VQ-VAEs are one of the main recipes behind [DALL-E](https://openai.com/blog/dall-e/)\n",
"and the idea of a codebook is used in [VQ-GANs](https://arxiv.org/abs/2012.09841).\n",
"\n",
"This example uses references from the\n",
"This example uses implementation details from the\n",
"[official VQ-VAE tutorial](https://github.com/deepmind/sonnet/blob/master/sonnet/examples/vqvae_example.ipynb)\n",
"from DeepMind. To run this example, you will need TensorFlow 2.5 or higher, as well as\n",
"from DeepMind.",
"\n",
"## Requirements",
"\n",
"To run this example, you will need TensorFlow 2.5 or higher, as well as\n",
"TensorFlow Probability, which can be installed using the command below."
]
},
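As a quick, hedged aside (not shown in this diff): the requirements mentioned above can be sanity-checked from a notebook cell along these lines; the actual install command lives in the notebook itself.

```python
# Environment check only; not part of the commit.
import tensorflow as tf
import tensorflow_probability as tfp

print("TensorFlow:", tf.__version__)            # the example expects 2.5 or higher
print("TensorFlow Probability:", tfp.__version__)
```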
@@ -90,13 +94,12 @@
"source": [
"## `VectorQuantizer` layer\n",
"\n",
"Here, we will implement a custom layer to encapsulate the vector\n",
"quantizer logic, which is the central component of VQ-VAEs.\n",
"Consider an output from the encoder, with shape `(batch_size, height, width, num_channels)`.\n",
"The vector quantizer will first\n",
"flatten this output, only keeping the `num_channels` dimension intact. So, the shape would\n",
"become `(batch_size * height * width, num_channels)`. The rationale behind this is to\n",
"treat the total number of channels as the space for the latent embeddings.\n",
"First, we implement a custom layer for the vector quantizer, which is the layer in between\n",
"the encoder and decoder. Consider an output from the encoder, with shape `(batch_size, height, width,\n",
"num_filters)`. The vector quantizer will first flatten this output, only keeping the\n",
"`num_filters` dimension intact. So, the shape would become `(batch_size * height * width,\n",
"num_filters)`. The rationale behind this is to treat the total number of filters as the size for\n",
"the latent embeddings.\n",
"\n",
"An embedding table is then initialized to learn a codebook. We measure the L2-normalized\n",
"distance between the flattened encoder outputs and code words of this codebook. We take the\n",
@@ -107,8 +110,8 @@
"Since the quantization process is not differentiable, we apply a\n",
"[straight-through estimator](https://www.hassanaskary.com/python/pytorch/deep%20learning/2020/09/19/intuitive-explanation-of-straight-through-estimators.html)\n",
"in between the decoder and the encoder, so that the decoder gradients are directly propagated\n",
"to the encoder. As the encoder and decoder share the same channel space, the hope is that the\n",
"decoder gradients will still be meaningful to the encoder."
"to the encoder. As the encoder and decoder share the same channel space, the decoder gradients are\n",
"still meaningful to the encoder."
]
},
{
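For reference, a minimal sketch of the nearest-code lookup described in the cell above, assuming the codebook is stored as an `(embedding_dim, num_embeddings)` table; the function and argument names are placeholders, not taken from the diff.

```python
import tensorflow as tf

def get_code_indices_sketch(flattened_inputs, embeddings):
    # flattened_inputs: (batch_size * height * width, embedding_dim)
    # embeddings:       (embedding_dim, num_embeddings) codebook table
    similarity = tf.matmul(flattened_inputs, embeddings)
    # Squared Euclidean (L2) distance between each flattened output and each code word.
    distances = (
        tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
        + tf.reduce_sum(embeddings ** 2, axis=0)
        - 2 * similarity
    )
    # Index of the closest code word for every flattened encoder output.
    return tf.argmin(distances, axis=1)
```

The resulting indices are what get one-hot encoded and multiplied against the codebook in the `call` method further down.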
@@ -125,9 +128,9 @@
" super().__init__(**kwargs)\n",
" self.embedding_dim = embedding_dim\n",
" self.num_embeddings = num_embeddings\n",
" self.beta = (\n",
" beta # This parameter is best kept between [0.25, 2] as per the paper.\n",
" )\n",
"\n",
" # The `beta` parameter is best kept between [0.25, 2] as per the paper.\n",
" self.beta = beta\n",
"\n",
" # Initialize the embeddings which we will quantize.\n",
" w_init = tf.random_uniform_initializer()\n",
@@ -149,17 +152,17 @@
" encoding_indices = self.get_code_indices(flattened)\n",
" encodings = tf.one_hot(encoding_indices, self.num_embeddings)\n",
" quantized = tf.matmul(encodings, self.embeddings, transpose_b=True)\n",
"\n",
" # Reshape the quantized values back to the original input shape\n",
" quantized = tf.reshape(quantized, input_shape)\n",
"\n",
" # Calculate vector quantization loss and add that to the layer. You can learn more\n",
" # about adding losses to different layers here:\n",
" # https://keras.io/guides/making_new_layers_and_models_via_subclassing/. Check\n",
" # the original paper to get a handle on the formulation of the loss function.\n",
" commitment_loss = self.beta * tf.reduce_mean(\n",
" (tf.stop_gradient(quantized) - x) ** 2\n",
" )\n",
" commitment_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)\n",
" codebook_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)\n",
" self.add_loss(commitment_loss + codebook_loss)\n",
" self.add_loss(self.beta * commitment_loss + codebook_loss)\n",
"\n",
" # Straight-through estimator.\n",
" quantized = x + tf.stop_gradient(quantized - x)\n",
@@ -203,10 +206,14 @@
"source": [
"## Encoder and decoder\n",
"\n",
"We will now implement the encoder and the decoder for the VQ-VAE. We will keep them small so\n",
"that their capacity is a good fit for the MNIST dataset, which we will use to demonstrate\n",
"the results. The definitions of the encoder and decoder come from\n",
"[this example](https://keras.io/examples/generative/vae)."
"Now for the encoder and the decoder for the VQ-VAE. We will keep them small so\n",
"that their capacity is a good fit for the MNIST dataset. The implementation of the encoder and\n",
"come from\n",
"[this example](https://keras.io/examples/generative/vae).\n",
"\n",
"Note that activations _other than ReLU_ may not work for the encoder and decoder layers in the\n",
"quantization architecture: Leaky ReLU activated layers, for example, have proven difficult to\n",
"train, resulting in intermittent loss spikes that the model has trouble recovering from."
]
},
{
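For orientation, a minimal sketch of the kind of small, ReLU-activated encoder/decoder pair described above; the filter counts and `latent_dim` are illustrative assumptions, not values from the diff.

```python
from tensorflow import keras
from tensorflow.keras import layers

latent_dim = 16  # assumed embedding size, for illustration only

# Two stride-2 convolutions: 28x28 MNIST inputs map to a 7x7 grid of latent vectors.
encoder = keras.Sequential(
    [
        keras.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, 3, strides=2, padding="same", activation="relu"),
        layers.Conv2D(64, 3, strides=2, padding="same", activation="relu"),
        layers.Conv2D(latent_dim, 1, padding="same"),
    ],
    name="encoder",
)

# The decoder mirrors the encoder with transposed convolutions back up to 28x28.
decoder = keras.Sequential(
    [
        keras.Input(shape=(7, 7, latent_dim)),
        layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu"),
        layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu"),
        layers.Conv2DTranspose(1, 3, padding="same"),
    ],
    name="decoder",
)
```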
@@ -497,21 +504,20 @@
},
"source": [
"The figure above shows that the discrete codes have been able to capture some\n",
"regularities from the dataset. Now, you might wonder, ***how do we use these codes to\n",
"generate new samples?*** Specifically, how do we sample from this codebook to create\n",
"novel examples? Since these codes are discrete and we imposed a categorical distribution\n",
"on them, we cannot use them yet to generate anything meaningful. These codes were not\n",
"updated during the training process as well. So, they need to be adjusted further so that\n",
"we can use for them the subsequent image generation task. The authors use a PixelCNN to\n",
"train these codes so that they can be used as powerful priors to generate novel examples.\n",
"\n",
"PixelCNN was proposed in\n",
"regularities from the dataset. Now, how do we sample from this codebook to create\n",
"novel images? Since these codes are discrete and we imposed a categorical distribution\n",
"on them, we cannot use them yet to generate anything meaningful until we can generate likely\n",
"sequences of codes that we can give to the decoder.\n",
"\n",
"The authors use a PixelCNN to train these codes so that they can be used as powerful priors to\n",
"generate novel examples. PixelCNN was proposed in\n",
"[Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)\n",
"by van der Oord et al. We will borrow code from\n",
"[this example](https://keras.io/examples/generative/pixelcnn/)\n",
"to develop a PixelCNN. It's an auto-regressive generative model where the current outputs\n",
"are conditioned on the prior ones. In other words, a PixelCNN generates an image on a\n",
"pixel-by-pixel basis."
"by van der Oord et al. We borrow the implementation from\n",
"[this PixelCNN example](https://keras.io/examples/generative/pixelcnn/). It's an auto-regressive\n",
"generative model where the outputs are conditional on the prior ones. In other words, a PixelCNN\n",
"generates an image on a pixel-by-pixel basis. For the purpose in this example, however, its task\n",
"is to generate code book indices instead of pixels directly. The trained VQ-VAE decoder is used\n",
"to map the indices generated by the PixelCNN back into the pixel space."
]
},
{
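As a hedged sketch of the sampling procedure just described (`pixel_cnn`, `quantizer` and `decoder` are placeholders for trained components; nothing here is taken verbatim from the diff): indices are generated autoregressively over a 7x7 grid, mapped through the codebook, and decoded back to pixels.

```python
import numpy as np
import tensorflow as tf

def sample_images_sketch(pixel_cnn, quantizer, decoder, batch=4, rows=7, cols=7):
    # Fill an empty grid of code indices one position at a time, each draw
    # conditioned on the indices generated so far.
    priors = np.zeros((batch, rows, cols), dtype="int32")
    for r in range(rows):
        for c in range(cols):
            logits = pixel_cnn(priors, training=False)[:, r, c]
            priors[:, r, c] = tf.random.categorical(logits, 1)[:, 0].numpy()

    # Map the sampled indices to codebook vectors, then let the trained
    # VQ-VAE decoder turn them back into pixels.
    one_hot = tf.one_hot(priors, quantizer.num_embeddings)
    quantized = tf.matmul(one_hot, quantizer.embeddings, transpose_b=True)
    return decoder(quantized, training=False)
```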
@@ -543,7 +549,27 @@
"colab_type": "text"
},
"source": [
"Don't worry about the input shape. It'll become clear in the following sections."
"This input shape represents the reduction in the resolution performed by the encoder. With \"same\" padding,\n",
"this exactly halves the \"resolution\" of the output shape for each stride-2 convolution layer. So, with these\n",
"two layers, we end up with an encoder output tensor of 7x7 on axes 2 and 3, with the first axis as the batch\n",
"size and the last axis being the code book embedding size. Since the quantization layer in the autoencoder\n",
"maps these 7x7 tensors to indices of the code book, these output layer axis sizes must be matched by the\n",
"PixelCNN as the input shape. The task of the PixelCNN for this architecture is to generate _likely_ 7x7\n",
"arrangements of codebook indices.\n",
"\n",
"Note that this shape is something to optimize for in larger-sized image domains, along with the code\n",
"book sizes. Since the PixelCNN is autoregressive, it needs to pass over each codebook index sequentially\n",
"in order to generate novel images from the codebook. Each stride-2 (or rather more correctly a\n",
"stride (2, 2)) convolution layer will divide the image generation time by four. Note, however, that there\n",
"is probably a lower bound on this part: when the number of codes for the image to reconstruct is too small,\n",
"it has insufficient information for the decoder to represent the level of detail in the image, so the\n",
"output quality will suffer. This can be amended at least to some extent by using a larger code book.\n",
"Since the autoregressive part of the image generation procedure uses codebook indices, there is far less of\n",
"a performance penalty on using a larger code book as the lookup time for a larger-sized code from a larger\n",
"code book is much smaller in comparison to iterating over a larger sequence of code book indices, although\n",
"the size of the code book does impact on the batch size that can pass through the image generation procedure.\n",
"Finding the sweet spot for this trade-off can require some architecture tweaking and could very well differ\n",
"per dataset."
]
},
{
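A quick check of the shape arithmetic above, with arbitrary filter counts (only the strides and "same" padding matter for the spatial sizes):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Two stride-2, "same"-padded convolutions: 28x28 -> 14x14 -> 7x7.
x = tf.zeros((1, 28, 28, 1))
x = layers.Conv2D(32, 3, strides=2, padding="same")(x)
print(x.shape)  # (1, 14, 14, 32)
x = layers.Conv2D(16, 3, strides=2, padding="same")(x)
print(x.shape)  # (1, 7, 7, 16) -- hence the PixelCNN works on 7x7 grids of indices
```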
@@ -555,7 +581,11 @@
"## PixelCNN model\n",
"\n",
"Majority of this comes from\n",
"[this example](https://keras.io/examples/generative/pixelcnn/)."
"[this example](https://keras.io/examples/generative/pixelcnn/).\n",
"\n",
"## Notes\n",
"Thanks to [Rein van 't Veer](https://github.com/reinvantveer) for improving this example with\n",
"copy-edits and minor code clean-ups."
]
},
{