A PyTorch implementation of ViTGAN: Training GANs with Vision Transformers
The implementation does not yet produce satisfactory results.
At the moment it suffers from mode collapse, which causes training to cycle indefinitely and hinders the learning process.
You are welcome to examine the code and suggest improvements!
- Use vectorized L2 distance in attention for Discriminator
- Overlapping Image Patches
- DiffAugment
- Self-modulated LayerNorm (SLN)
- Implicit Neural Representation for Patch Generation
- ExponentialMovingAverage (EMA)
- Balanced Consistency Regularization (bCR)
- Improved Spectral Normalization (ISN)
- Python3
- einops
- pytorch_ema
- stylegan2-pytorch
- tensorboard
pip install einops
pip install git+https://github.com/fadel/pytorch_ema
pip install stylegan2-pytorch
pip install tensorboard
Train the model with the proposed parameters:
python train.py
Tensorboard
tensorboard --logdir runs/
The following are the parameters proposed in the paper for the CIFAR-10 dataset:
python train.py \
--image_size 32 \
--patch_size 4 \
--latent_dim 32 \
--hidden_features 384 \
--sln_paremeter_size 384 \
--depth 4 \
--num_heads 4 \
--combine_patch_embeddings false \
--batch_size 128 \
--device "cuda" \
--discriminator_type "stylegan2" \
--batch_size_history_discriminator false \
--lr 0.002 \
--epochs 200 \
--lambda_bCR_real 10 \
--lambda_bCR_fake 10 \
--lambda_lossD_noise 0 \
--lambda_lossD_history 0
The Generator has the following architecture:
For debugging purposes, the Generator is separated into a Vision Transformer (ViT) model and a SIREN model.
Given a seed, the dimensionality of which is controlled by latent_dim, the ViT model creates an embedding for each patch of the final image. Those embeddings, combined with a Fourier Position Encoding (Jupyter Notebook), are fed to a SIREN network, which outputs the image patches; these are then stitched together into the final image.
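A minimal sketch of this two-stage flow is shown below; the module and attribute names (`Generator`, `vit`, `siren`) are illustrative and not necessarily the ones used in this repository.

```python
import torch
import torch.nn as nn

# Illustrative sketch of the two-stage Generator: ViT -> SIREN -> stitched image.
class Generator(nn.Module):
    def __init__(self, vit, siren, image_size=32, patch_size=4):
        super().__init__()
        self.vit = vit        # maps a latent seed to one embedding per patch
        self.siren = siren    # maps patch embeddings (+ position encodings) to pixels
        self.image_size = image_size
        self.patch_size = patch_size

    def forward(self, z):
        # z: (batch, latent_dim)
        patch_embeddings = self.vit(z)          # (batch, num_patches, hidden_features)
        patches = self.siren(patch_embeddings)  # (batch, num_patches, patch_size**2 * 3)
        b, n, _ = patches.shape
        grid = self.image_size // self.patch_size
        # Stitch the generated patches back into a full image.
        patches = patches.view(b, grid, grid, self.patch_size, self.patch_size, 3)
        img = patches.permute(0, 5, 1, 3, 2, 4).reshape(b, 3, self.image_size, self.image_size)
        return img
```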
The ViT part of the Generator differs from a standard Vision Transformer in the following ways:
- The input to the Transformer consists only of the position embeddings
- Self-Modulated Layer Norm (SLN) is used in place of LayerNorm
- There is no classification head
SLN is the only place where the seed is fed into the network.
SLN consists of a regular LayerNorm, the result of which is multiplied by gamma and added to beta.
Both gamma and beta are computed by fully connected layers, a separate one for each place where SLN is applied.
The input dimension of each of those fully connected layers is equal to hidden_dimension, and the output dimension is either hidden_dimension or 1.
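A minimal sketch of SLN, assuming the (mapped) seed `w` has the same dimensionality as the hidden activations; the class and argument names are illustrative:

```python
import torch
import torch.nn as nn

# Illustrative Self-Modulated LayerNorm: gamma and beta come from the seed,
# which is the only path through which the latent enters the ViT generator.
class SLN(nn.Module):
    def __init__(self, hidden_dim, parameter_size=None):
        super().__init__()
        parameter_size = parameter_size or hidden_dim   # hidden_dim or 1
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.to_gamma = nn.Linear(hidden_dim, parameter_size)
        self.to_beta = nn.Linear(hidden_dim, parameter_size)

    def forward(self, h, w):
        # h: (batch, num_patches, hidden_dim) activations
        # w: (batch, hidden_dim) seed / mapped latent
        gamma = self.to_gamma(w).unsqueeze(1)   # broadcast over patches
        beta = self.to_beta(w).unsqueeze(1)
        return gamma * self.norm(h) + beta
```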
A description of SIREN: [Blog Post] [Paper] [Colab Notebook]
In contrast to a regular SIREN, the desired output is not a single image. For this purpose, the patch embedding is combined with a position embedding.
The positional encoding used in ViTGAN is the Fourier Position Encoding, the code for which was taken from here: (Jupyter Notebook)
In my implementation, the input to the SIREN is the sum of a patch embedding and a position embedding.
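The sketch below illustrates this idea; the layer sizes, the `w0` frequency, and the per-pixel Fourier encoding shape are assumptions rather than the exact code of this repository.

```python
import torch
import torch.nn as nn

# A SIREN layer uses a sine activation with a frequency factor w0.
class SineLayer(nn.Module):
    def __init__(self, in_features, out_features, w0=30.0):
        super().__init__()
        self.w0 = w0
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        return torch.sin(self.w0 * self.linear(x))

# Maps (patch embedding + Fourier position encoding) to RGB values for each pixel.
class PatchSiren(nn.Module):
    def __init__(self, hidden_features=384, out_channels=3):
        super().__init__()
        self.net = nn.Sequential(
            SineLayer(hidden_features, hidden_features),
            SineLayer(hidden_features, hidden_features),
            nn.Linear(hidden_features, out_channels),
        )

    def forward(self, patch_embedding, position_encoding):
        # patch_embedding:   (batch, num_patches, hidden_features)
        # position_encoding: (pixels_per_patch, hidden_features), one Fourier
        #                    encoding per pixel coordinate inside a patch
        x = patch_embedding.unsqueeze(2) + position_encoding  # broadcast to every pixel
        return self.net(x)  # (batch, num_patches, pixels_per_patch, out_channels)
```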
After examining the Generator, I believe that it is implemented correctly.
I found no significant difference between using sln_paremeter_size=384 and sln_paremeter_size=1.
The Discriminator has the following architecture:
The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:
- DiffAugment
- Overlapping Image Patches
- Use vectorized L2 distance in attention for Discriminator (see the sketch after this list)
- Improved Spectral Normalization (ISN)
- Balanced Consistency Regularization (bCR)
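A rough sketch of the vectorized L2-distance attention (with tied query/key projections, as required by the Lipschitz argument in the referenced paper); this is a simplification, not the exact layer used here:

```python
import torch
import torch.nn as nn

# Attention scores are negative squared L2 distances instead of dot products.
class L2Attention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        # Query and key share one projection, which the Lipschitz bound requires.
        self.to_qk = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        qk = self.to_qk(x)                    # (batch, tokens, dim)
        v = self.to_v(x)
        # Vectorized pairwise squared L2 distances between all tokens.
        d = torch.cdist(qk, qk, p=2) ** 2     # (batch, tokens, tokens)
        attn = torch.softmax(-d * self.scale, dim=-1)
        return attn @ v
```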
For implementing DiffAugment, I used the code linked below:
[GitHub] [Paper]
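As a usage illustration (assuming the `DiffAugment` function and its `policy` string from the linked repository), the same differentiable augmentation is applied to real and generated images before the Discriminator scores them:

```python
# Assumes DiffAugment_pytorch.py from the linked repository is on the path.
from DiffAugment_pytorch import DiffAugment

policy = 'color,translation,cutout'

def discriminator_logits(discriminator, real, fake):
    # Identical differentiable augmentations for real and fake images,
    # so gradients still flow back to the Generator through the fakes.
    real_logits = discriminator(DiffAugment(real, policy=policy))
    fake_logits = discriminator(DiffAugment(fake, policy=policy))
    return real_logits, fake_logits
```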
Creation of the overlapping image patches is implemented with the use of a convolution layer.
[Paper]
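A small sketch of this idea, with illustrative numbers rather than the exact configuration used here: the convolution kernel is larger than its stride, so neighbouring patches share pixels.

```python
import torch
import torch.nn as nn

patch_size, overlap, hidden_dim = 4, 2, 384
to_patch_embedding = nn.Conv2d(
    in_channels=3,
    out_channels=hidden_dim,
    kernel_size=patch_size + overlap,  # each patch sees `overlap` extra pixels
    stride=patch_size,
    padding=overlap // 2,
)

x = torch.randn(1, 3, 32, 32)
tokens = to_patch_embedding(x)              # (1, hidden_dim, 8, 8)
tokens = tokens.flatten(2).transpose(1, 2)  # (1, 64, hidden_dim) token sequence
```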
The ISN implementation is based on the following implementation of Spectral Normalization:
[GitHub]
[Paper]
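The idea, sketched below, is to spectrally normalize the weight and then rescale it by the spectral norm it had at initialization, so the layer keeps its initial scale while staying Lipschitz-controlled. The exact SVD here is only for clarity; the referenced implementation estimates the spectral norm with power iteration.

```python
import torch
import torch.nn as nn

# Illustrative ISN linear layer (not the repository's exact implementation).
class ISNLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Spectral norm of the weight at initialization, kept as a fixed buffer.
        init_sigma = torch.linalg.matrix_norm(self.linear.weight.detach(), ord=2)
        self.register_buffer('init_sigma', init_sigma)

    def forward(self, x):
        sigma = torch.linalg.matrix_norm(self.linear.weight, ord=2)
        weight = self.linear.weight * (self.init_sigma / sigma)
        return nn.functional.linear(x, weight, self.linear.bias)
```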
Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]
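As an illustration, the bCR terms penalize the Discriminator for scoring an image and its augmented version differently, for both real and generated images; `augment` and the function below are placeholders, and the lambdas correspond to the `--lambda_bCR_real` / `--lambda_bCR_fake` options above.

```python
import torch

def bcr_loss(discriminator, real, fake, augment, lambda_real=10.0, lambda_fake=10.0):
    # Consistency on real images and on generated images, weighted separately.
    loss_real = (discriminator(real) - discriminator(augment(real))).pow(2).mean()
    loss_fake = (discriminator(fake) - discriminator(augment(fake))).pow(2).mean()
    return lambda_real * loss_real + lambda_fake * loss_fake
```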
SIREN: Implicit Neural Representations with Periodic Activation Functions
Vision Transformer: [Blog Post]
L2 distance attention: The Lipschitz Constant of Self-Attention
Spectral Normalization reference code: [GitHub] [Paper]
Diff Augment: [GitHub] [Paper]
Fourier Position Embedding: [Jupyter Notebook]
Exponential Moving Average: [GitHub]
Balanced Consistency Regularization (bCR): [Paper]
StyleGAN2 Discriminator: [GitHub]