This notebook shows a basic implementation of a transformer (decoder) architecture for image generation in TensorFlow 2, written with an emphasis on code that is easy to relate to and follow from start to finish.
It demonstrates how to use a transformer decoder to learn a generative representation of the MNIST dataset and to perform autoregressive image reconstruction.
MNIST dataset examples:
To reduce the number of color values, we perform color quantization: we run k-means clustering to obtain 8 color clusters, which reduces our color palette to 8 values.
Quantized examples:
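The quantization step can be sketched as follows. This is a minimal illustration in plain NumPy, not the notebook's own code: the helper names `fit_palette`, `quantize` and `dequantize` are hypothetical, the k-means loop is a simple Lloyd's iteration on pixel intensities, and random images stand in for MNIST so the snippet runs without downloading data.

```python
import numpy as np

def fit_palette(pixels, k=8, iters=20, seed=0):
    """Lloyd's k-means on 1-D pixel intensities; returns k sorted centroids."""
    rng = np.random.default_rng(seed)
    pixels = pixels.astype(np.float32).reshape(-1, 1)
    # Initialise the centroids from k distinct intensity values.
    centroids = rng.choice(np.unique(pixels), size=k, replace=False).reshape(-1, 1)
    for _ in range(iters):
        # Assign every pixel to its nearest centroid.
        labels = np.argmin(np.abs(pixels - centroids.T), axis=1)
        # Move each centroid to the mean of its members (empty clusters stay put).
        for j in range(k):
            members = pixels[labels == j]
            if len(members) > 0:
                centroids[j] = members.mean()
    return np.sort(centroids.ravel())

def quantize(images, palette):
    """Map each pixel to the index (0..k-1) of its nearest palette entry."""
    return np.argmin(np.abs(images[..., None] - palette), axis=-1)

def dequantize(tokens, palette):
    """Revert the quantization by looking up each index in the palette."""
    return palette[tokens]

# Assumption: random 28x28 grayscale images as a stand-in for MNIST.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(16, 28, 28), dtype=np.uint8)

palette = fit_palette(images, k=8)
tokens = quantize(images, palette)
restored = dequantize(tokens, palette)
```

After this step every pixel is an integer token between 0 and 7, and `dequantize` maps the tokens back to the 8 palette intensities.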
Afterwards we serialize the images to obtain linear sequences of length 784 (28 × 28) per image, which can be fed into the model in the same way token sequences are in NLP.
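A minimal sketch of the serialization, assuming row-major (raster) order and quantized images of shape `(N, 28, 28)`. The helpers `serialize`, `deserialize` and `next_token_pairs` are hypothetical names, and the extra start token with id 8 is an assumption made here to form next-token training pairs; the notebook may handle this differently.

```python
import numpy as np

def serialize(tokens):
    """Flatten (N, H, W) token images into (N, H*W) sequences, row by row."""
    return tokens.reshape(tokens.shape[0], -1)

def deserialize(seqs, height=28, width=28):
    """Inverse of serialize: fold sequences back into images."""
    return seqs.reshape(-1, height, width)

def next_token_pairs(seqs, start_token):
    """Shift sequences right by one so position t predicts token t."""
    start = np.full((seqs.shape[0], 1), start_token, dtype=seqs.dtype)
    inputs = np.concatenate([start, seqs[:, :-1]], axis=1)
    return inputs, seqs

# Assumption: random token images (values 0..7) stand in for quantized MNIST.
rng = np.random.default_rng(0)
token_images = rng.integers(0, 8, size=(4, 28, 28))

seqs = serialize(token_images)
inputs, targets = next_token_pairs(seqs, start_token=8)
```

Each row of `seqs` now has length 784, and `inputs` and `targets` are offset by one position, which is the usual setup for training a decoder on next-token prediction.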
See the notebook to get an in-depth explanation of the model.
We perform image reconstruction: we take MNIST images, remove the bottom half of each image, quantize the remaining top half and let our model reconstruct the missing part. Afterwards we revert the quantization and obtain a newly generated MNIST image. Comparing the output with the input shows that the model does not memorize the inputs but creates new images.
Input data:
Bottom half removed:
Generated output:
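The reconstruction loop can be sketched like this. It is an illustration under stated assumptions rather than the notebook's implementation: `complete_image` is a hypothetical helper, the next token is drawn by sampling from a softmax (the notebook may decode differently), and `dummy_logits` is a placeholder returning uniform logits where the trained decoder would be called.

```python
import numpy as np

def complete_image(prefix, next_token_logits, total_len=784, rng=None):
    """Extend a token prefix one token at a time until total_len is reached."""
    rng = rng or np.random.default_rng()
    seq = list(prefix)
    while len(seq) < total_len:
        logits = next_token_logits(np.array(seq))
        # Numerically stable softmax over the colour vocabulary.
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        seq.append(int(rng.choice(len(probs), p=probs)))
    return np.array(seq)

def dummy_logits(prefix, vocab_size=8):
    """Hypothetical stand-in for the trained decoder: uniform logits."""
    return np.zeros(vocab_size)

# Assumption: a random quantized image (tokens 0..7) stands in for MNIST.
rng = np.random.default_rng(0)
full_sequence = rng.integers(0, 8, size=784)

# Keep the top 14 rows (392 tokens) and drop the bottom half.
top_half = full_sequence[: 14 * 28]
completed = complete_image(top_half, dummy_logits, rng=rng)
generated_image = completed.reshape(28, 28)
```

The top half of `generated_image` is the original input, while the bottom half is generated token by token, each conditioned on everything produced so far. Reverting the quantization then turns the tokens back into pixel intensities.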
- Transformers Tutorial - In-depth tutorial on transformers in TF2.
- Illustrated Transformers Guide - Quick and intuitive explanation of transformers.
- Image GPT Blog - The original ImageGPT by Chen et al.
- ImageGPT in PyTorch - An implementation of ImageGPT for PyTorch.
Autoregressive Image Generation, MNIST, Transformers, Transformer Decoder, ImageGPT, Generative Methods, Generative Loss, Deep Learning, Machine Learning, TensorFlow 2