This repository contains the code for FCGAN, trained and tested on the MNIST and CIFAR-10 datasets. It is built on the PyTorch framework.
GANs are generally made up of two models: the Artist (Generator) and the Critic (Discriminator). The generator creates an image from random noise, and the discriminator compares the generated image against images from the given dataset. We train the two models by playing a minimax game over their cost functions: the generator tries to fool the discriminator by producing realistic-looking images, while the discriminator gets better at telling real images apart from fake ones. This two-player game improves both models until the generator produces realistic images or the system reaches a Nash equilibrium.
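Formally, this is the minimax objective from the original GAN paper (Goodfellow et al., 2014), where D(x) is the discriminator's estimate that x is real and G(z) is the image generated from noise z:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)]
  + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
```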
- Setup Instructions and Dependencies
- Training Model from Scratch
- Generating Images from Trained Models
- Model Architecture
- Repository Overview
- Results Obtained
- Observations
- Credits
You may set up the repository on your local machine by either downloading it or running the following command in a terminal:

```
git clone https://github.com/h3lio5/gan-pytorch.git
```
The trained models are large, so Google Drive links to them are provided in the `model.txt` file.

The data required for training is downloaded automatically when running `train.py`.
All dependencies required by this repo can be installed by creating a virtual environment with Python 3.7 and running:

```
pip install -r requirements.txt
```
Make sure to have CUDA 10.0.130 and cuDNN 7.6.0 installed in the virtual environment. In a conda environment, this can be done with the following commands:

```
conda install cudatoolkit=10.0
conda install cudnn=7.6.0
```
To train your own model from scratch, run:

```
python train.py -config path/to/config.ini
```
- The parameters for your experiment are all set by default, but you are free to set them yourself.
- The training script will create a folder `exp_name` as specified in your `config.ini` file (a hypothetical example is sketched below).
- This folder will contain all data related to your experiment, such as TensorBoard logs, images generated during training, and training checkpoints.
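For illustration, a hypothetical `config.ini` might look like the sketch below. The key names here are assumptions inferred from the settings listed under Observations; check the default config shipped with the repository for the actual keys.

```ini
[experiment]
; folder created under experiments/ to hold logs and checkpoints
exp_name = mnist_fcgan_run1

[training]
dataset = mnist          ; or cifar10
epochs = 100
batch_size = 128
learning_rate = 0.0002
beta1 = 0.5
noise_dim = 100
```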
To generate images from trained models, run:

```
python generate.py --dataset mnist/cifar10 --load_path path/to/checkpoint --grid_size n --save_path directory/where/images/are/saved
```
The arguments used are explained as follows:

- `--dataset` requires either `mnist` or `cifar10`, according to the dataset the model was trained on.
- `--load_path` requires the path to the training checkpoint to load. Point this towards the `*.index` file without the extension, for example `--load_path training_checkpoints/ckpt-1`.
- `--grid_size` requires an integer `n` and will generate an n×n grid of images.
- `--save_path` requires the path to the directory where the generated images will be saved. If the directory doesn't exist, the script will create it.
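For example, to sample a 5x5 grid of MNIST images from a checkpoint (the paths here are illustrative):

```
python generate.py --dataset mnist --load_path training_checkpoints/ckpt-1 --grid_size 5 --save_path generated_images
```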
To generate images from pre-trained models, download the checkpoint files from the Google Drive links given in the `model.txt` file.
- MNIST: The generator model is a 5-layer MLP with LeakyReLU activations, followed by a Tanh non-linearity in the final layer.
- CIFAR-10: The generator model is a 6-layer MLP with LeakyReLU activations, followed by a Tanh non-linearity in the final layer.
- The input is a 100-dimensional noise vector, which is passed through the network to produce either a 28x28x1 (MNIST) or 32x32x3 (CIFAR-10) image.
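For concreteness, here is a minimal PyTorch sketch of the MNIST generator described above. The hidden-layer widths are assumptions for illustration; the actual architecture lives in `model.py`.

```python
import torch
import torch.nn as nn

class Generator(nn.Module):
    """5-layer MLP: 100-d noise -> 28x28x1 image (hidden widths assumed)."""

    def __init__(self, noise_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        out_dim = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 128), nn.LeakyReLU(0.2),
            nn.Linear(128, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 1024), nn.LeakyReLU(0.2),
            nn.Linear(1024, out_dim), nn.Tanh(),  # Tanh maps pixels to [-1, 1]
        )

    def forward(self, z):
        # z: (batch, noise_dim) -> (batch, 1, 28, 28)
        return self.net(z).view(z.size(0), *self.img_shape)
```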
- MNIST: The discriminator model is a 3-layer MLP with LeakyReLU activations, followed by a Sigmoid non-linearity in the final layer.
- CIFAR-10: The discriminator model is a 4-layer MLP with LeakyReLU activations, followed by a Sigmoid non-linearity in the final layer.
- The output is a single number that indicates whether the image is real or fake/generated.
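A matching sketch of the MNIST discriminator, again with assumed hidden widths:

```python
import torch.nn as nn

class Discriminator(nn.Module):
    """3-layer MLP: flattened image -> a single score in [0, 1]."""

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        in_dim = img_shape[0] * img_shape[1] * img_shape[2]
        self.net = nn.Sequential(
            nn.Linear(in_dim, 512), nn.LeakyReLU(0.2),
            nn.Linear(512, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 1), nn.Sigmoid(),  # probability-like score
        )

    def forward(self, img):
        # img: (batch, 1, 28, 28) -> (batch, 1)
        return self.net(img.view(img.size(0), -1))
```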
This repository contains the following files and folders:

- `experiments`: Contains data for different runs.
- `resources`: Contains media for `readme.md`.
- `data_loader.py`: Contains helper functions that load and preprocess data.
- `generate.py`: Used to generate and save images from trained models.
- `model.py`: Contains helper functions that create the generator and discriminator models.
- `model.txt`: Contains Google Drive links to trained models.
- `requirements.txt`: Lists dependencies for easy setup in virtual environments.
- `train.py`: Contains code to train models from scratch.
Samples generated after training the model for 100 epochs on MNIST.

Samples generated after training the model for 200 epochs on CIFAR-10.
- The optimizer used is Adam.
- Learning rate: 0.0002, beta-1: 0.5.
- Trained for 100 epochs (MNIST) and 200 epochs (CIFAR-10).
- Batch size is 128 for both MNIST and CIFAR-10.
- The model uses label flipping, i.e. real images are assigned the label 0 and fake images the label 1 (see the training-step sketch below).
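For illustration, a minimal sketch of one training step using these settings, built on the `Generator` and `Discriminator` sketches above (the actual loop lives in `train.py`):

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
G, D = Generator().to(device), Discriminator().to(device)

# Adam with lr=0.0002 and beta1=0.5, as listed above.
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()

def train_step(real_imgs):
    batch = real_imgs.size(0)
    # Label flipping: real images get 0, fake images get 1.
    real_labels = torch.zeros(batch, 1, device=device)
    fake_labels = torch.ones(batch, 1, device=device)

    # Discriminator step: score a real batch and a (detached) fake batch.
    z = torch.randn(batch, 100, device=device)
    fake_imgs = G(z).detach()
    d_loss = bce(D(real_imgs), real_labels) + bce(D(fake_imgs), fake_labels)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: push D to assign the "real" label (0) to fakes.
    z = torch.randn(batch, 100, device=device)
    g_loss = bce(D(G(z)), real_labels)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```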
On MNIST, the model took around 12 minutes to train for 100 epochs on a GPU. The generated images are not very sharp, but they somewhat resemble the real data. The model is also prone to mode collapse.

Training for longer durations (150+ epochs) does not seem to improve the model's performance and sometimes even degrades the quality of the images produced.
Training on the CIFAR-10 dataset was more challenging: the dataset is more varied, and the network has more parameters to train. The model was trained for 200 epochs, which took about 30 minutes.

However, the main difficulty was inspecting 32x32 images and judging whether they were 'good enough'. The images are too low-resolution to properly make out the subject, but they are easily passable since they look quite similar to the real data.

Some images contain noise, but most are largely free of artifacts. This is partly due to the network training on all 10 classes of the CIFAR-10 dataset at once. Better results could be obtained by training the network on one particular class at a time, but this takes away the robustness of the model.
In making this repository, I referenced multiple sources: