ViTGAN-pytorch

A PyTorch implementation of ViTGAN: Training GANs with Vision Transformers

Open In Colab

The implementation does not yet produce satisfactory results.
At the moment it suffers from mode collapse, which traps training in a cycle and hinders learning.

You are welcome to examine the code and suggest improvements!

TODO:

  1. Use vectorized L2 distance in attention for Discriminator
  2. Overlapping Image Patches
  3. DiffAugment
  4. Self-modulated LayerNorm (SLN)
  5. Implicit Neural Representation for Patch Generation
  6. ExponentialMovingAverage (EMA)
  7. Balanced Consistency Regularization (bCR)
  8. Improved Spectral Normalization (ISN)

Dependencies

  • Python3
  • einops
  • pytorch_ema
  • stylegan2-pytorch
  • tensorboard
pip install einops
pip install git+https://github.com/fadel/pytorch_ema
pip install stylegan2-pytorch
pip install tensorboard

TLDR:

Train the model with the proposed parameters:

python train.py

Tensorboard

tensorboard --logdir runs/

The following are the parameters proposed in the paper for the CIFAR-10 dataset:

python train.py \
--image_size 32 \
--patch_size 4 \
--latent_dim 32 \
--hidden_features 384 \
--sln_paremeter_size 384 \
--depth 4 \
--num_heads 4 \
--combine_patch_embeddings false \
--batch_size 128 \
--device "cuda" \
--discriminator_type "stylegan2" \
--batch_size_history_discriminator false \
--lr 0.002 \
--epochs 200 \
--lambda_bCR_real 10 \
--lambda_bCR_fake 10 \
--lambda_lossD_noise 0 \
--lambda_lossD_history 0

Implementation Details

Generator

The Generator has the following architecture:

ViTGAN Generator architecture

For debugging purposes, the Generator is separated into a Vision Transformer (ViT) model and a SIREN model.

Given a seed, whose dimensionality is controlled by latent_dim, the ViT model creates an embedding for each patch of the final image. Those embeddings, combined with a Fourier position encoding (Jupyter Notebook), are fed to a SIREN network, which outputs the patches of the image; the patches are then stitched together into the final image.
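
A minimal sketch of the stitching step, assuming per-patch RGB outputs of shape (batch, num_patches, patch_size * patch_size * 3) and the CIFAR-10 settings above (tensor names are illustrative, not the repository's exact code):

```python
import torch
from einops import rearrange

batch, image_size, patch_size, channels = 8, 32, 4, 3
num_patches = (image_size // patch_size) ** 2               # 64 patches for 32x32 / 4x4

# Dummy per-patch pixel values, standing in for the SIREN output
patch_pixels = torch.rand(batch, num_patches, patch_size * patch_size * channels)

# Rearrange the flat patch grid back into a full image
image = rearrange(
    patch_pixels,
    'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
    h=image_size // patch_size, w=image_size // patch_size,
    p1=patch_size, p2=patch_size, c=channels,
)
print(image.shape)  # torch.Size([8, 3, 32, 32])
```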

The ViT part of the Generator differs from a standard Vision Transformer in the following ways:

  • The input to the Transformer consists only of the position embeddings
  • Self-Modulated Layer Norm (SLN) is used in place of LayerNorm
  • There is no classification head

SLN is the only place where the seed enters the network.
SLN applies a regular LayerNorm, the result of which is multiplied by gamma and shifted by beta.
Both gamma and beta are computed by a fully connected layer, a different one for each place SLN is applied.
The input dimension of each of those fully connected layers is equal to hidden_dimension, and the output dimension is either hidden_dimension or 1.
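
A minimal sketch of such an SLN module, assuming the seed has already been projected to hidden_dimension (module and argument names are illustrative, not the repository's exact code):

```python
import torch
import torch.nn as nn

class SLN(nn.Module):
    """Self-modulated LayerNorm: LayerNorm(h) scaled by gamma(w) and shifted by beta(w)."""
    def __init__(self, hidden_dim, parameter_size=1):
        super().__init__()
        # parameter_size is either hidden_dim or 1, as described above
        self.norm = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.gamma = nn.Linear(hidden_dim, parameter_size)
        self.beta = nn.Linear(hidden_dim, parameter_size)

    def forward(self, h, w):
        # h: (batch, num_patches, hidden_dim) transformer activations
        # w: (batch, hidden_dim) projected seed
        g = self.gamma(w).unsqueeze(1)   # (batch, 1, parameter_size)
        b = self.beta(w).unsqueeze(1)
        return g * self.norm(h) + b

sln = SLN(hidden_dim=384, parameter_size=1)
out = sln(torch.rand(8, 64, 384), torch.rand(8, 384))   # (8, 64, 384)
```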

SIREN

A description of SIREN: [Blog Post] [Paper] [Colab Notebook]

In contrast to a regular SIREN, the desired output is not a single image. For this purpose, each patch embedding is combined with a position embedding.

The positional encoding used in ViTGAN is a Fourier position encoding, the code for which was taken from here: (Jupyter Notebook)

In my implementation, the input to the SIREN is the sum of a patch embedding and a position embedding.
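
A minimal SIREN sketch under those assumptions: a stack of linear layers with sine activations, fed with the sum of a patch embedding and a position embedding of the same dimensionality (the layer sizes and names are illustrative, not the repository's exact code):

```python
import math
import torch
import torch.nn as nn

class SirenLayer(nn.Module):
    """One SIREN layer: linear map followed by sin(w0 * x), with the SIREN paper's init."""
    def __init__(self, in_features, out_features, w0=30.0, is_first=False):
        super().__init__()
        self.w0 = w0
        self.linear = nn.Linear(in_features, out_features)
        bound = 1 / in_features if is_first else math.sqrt(6 / in_features) / w0
        nn.init.uniform_(self.linear.weight, -bound, bound)

    def forward(self, x):
        return torch.sin(self.w0 * self.linear(x))

hidden_features = 384
siren = nn.Sequential(
    SirenLayer(hidden_features, hidden_features, is_first=True),
    SirenLayer(hidden_features, hidden_features),
    nn.Linear(hidden_features, 3),                      # RGB value per coordinate
)
patch_embedding = torch.rand(1, 16, hidden_features)    # one embedding per pixel of a 4x4 patch
position_embedding = torch.rand(1, 16, hidden_features) # Fourier features of the pixel coordinates
rgb = siren(patch_embedding + position_embedding)       # (1, 16, 3)
```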


After examining the Generator, I believe that it is implemented correctly.

I found that there is no significant difference between using sln_paremeter_size=384 and sln_paremeter_size=1.

Discriminator

The Discriminator has the following architecture:

ViTGAN Discriminator architecture

The ViTGAN Discriminator is mostly a standard Vision Transformer network, with the following modifications:

  • DiffAugment
  • Overlapping Image Patches
  • Vectorized L2 distance in attention
  • Improved Spectral Normalization (ISN)
  • Balanced Consistency Regularization (bCR)

DiffAugment

For implementing DiffAugment, I used the code from:
[GitHub] [Paper]
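
A minimal usage sketch, assuming DiffAugment_pytorch.py from the linked repository is available on the import path (the policy string is an assumption, not necessarily the one used here):

```python
import torch
# Assumes DiffAugment_pytorch.py from the linked repository is importable.
from DiffAugment_pytorch import DiffAugment

policy = 'color,translation,cutout'        # assumed augmentation policy
real_images = torch.rand(8, 3, 32, 32)     # dummy real batch
fake_images = torch.rand(8, 3, 32, 32)     # dummy generated batch

# The same differentiable augmentations are applied to real and generated
# batches before they are fed to the Discriminator.
real_aug = DiffAugment(real_images, policy=policy)
fake_aug = DiffAugment(fake_images, policy=policy)
```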

Overlapping Image Patches

Creation of the overlapping image patches is implemented with a convolution layer whose kernel is larger than its stride, so neighbouring patches share border pixels.
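
An illustrative sketch of that idea (the overlap size and layer name are assumptions, not the repository's exact values):

```python
import torch
import torch.nn as nn

image_size, patch_size, overlap, hidden_features = 32, 4, 2, 384
to_patch_embeddings = nn.Conv2d(
    in_channels=3,
    out_channels=hidden_features,
    kernel_size=patch_size + overlap,   # each patch also sees a border of neighbouring pixels
    stride=patch_size,                  # one token per 4x4 patch
    padding=overlap // 2,
)
x = torch.rand(8, 3, image_size, image_size)
patches = to_patch_embeddings(x)              # (8, 384, 8, 8)
tokens = patches.flatten(2).transpose(1, 2)   # (8, 64, 384) -> one token per patch
```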

Use vectorized L2 distance in attention for Discriminator

[Paper]
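
The paper replaces dot-product attention scores with negative squared L2 distances between queries and keys, which helps keep the Discriminator Lipschitz. A minimal sketch of the vectorized distance computation (single-head shapes; the paper additionally ties the query and key projection weights, which is omitted here):

```python
import torch
import torch.nn.functional as F

def l2_attention(q, k, v, scale):
    # q, k, v: (batch, tokens, head_dim)
    # torch.cdist computes all pairwise L2 distances in one vectorized call
    d = torch.cdist(q, k, p=2)                    # (batch, tokens, tokens)
    attn = F.softmax(-d.pow(2) * scale, dim=-1)   # distances replace dot products
    return attn @ v

q = k = v = torch.rand(2, 64, 96)
out = l2_attention(q, k, v, scale=96 ** -0.5)     # (2, 64, 96)
```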

Improved Spectral Normalization (ISN)

The ISN implementation is based on the following implementation of Spectral Normalization:
[GitHub] [Paper]
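
In the paper, ISN rescales the spectrally normalized weight by the spectral norm of the weight at initialization. A rough sketch of that idea on top of PyTorch's built-in spectral_norm (not the repository's exact code; rescaling the layer output also scales the bias, which is a simplification):

```python
import torch
import torch.nn as nn

class ISNLinear(nn.Module):
    """Improved Spectral Normalization sketch: W_ISN = sigma(W_init) * W / sigma(W)."""
    def __init__(self, in_features, out_features):
        super().__init__()
        linear = nn.Linear(in_features, out_features)
        with torch.no_grad():
            # sigma(W_init): largest singular value of the weight at initialization
            init_sigma = torch.linalg.matrix_norm(linear.weight, ord=2)
        self.register_buffer('init_sigma', init_sigma)
        # standard spectral normalization; the output is rescaled below
        self.linear = nn.utils.spectral_norm(linear)

    def forward(self, x):
        return self.init_sigma * self.linear(x)

layer = ISNLinear(384, 384)
out = layer(torch.rand(8, 64, 384))
```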

Balanced Consistency Regularization (bCR)

Zhengli Zhao, Sameer Singh, Honglak Lee, Zizhao Zhang, Augustus Odena, Han Zhang; Improved Consistency Regularization for GANs; AAAI 2021 [Paper]
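
bCR adds a consistency penalty to the Discriminator loss: the Discriminator's outputs should not change when a real or generated input is augmented. A hypothetical sketch of the two penalty terms, matching the lambda_bCR_real and lambda_bCR_fake arguments above (not the repository's exact code):

```python
import torch.nn.functional as F

def bcr_loss(discriminator, real, fake, augment, lambda_real=10.0, lambda_fake=10.0):
    # Consistency penalty between original and augmented batches,
    # applied to both real and generated images (hence "balanced").
    loss_real = F.mse_loss(discriminator(augment(real)), discriminator(real))
    loss_fake = F.mse_loss(discriminator(augment(fake)), discriminator(fake))
    return lambda_real * loss_real + lambda_fake * loss_fake
```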

References

SIREN: Implicit Neural Representations with Periodic Activation Functions
Vision Transformer: [Blog Post]
L2 distance attention: The Lipschitz Constant of Self-Attention
Spectral Normalization reference code: [GitHub] [Paper]
DiffAugment: [GitHub] [Paper]
Fourier Position Embedding: [Jupyter Notebook]
Exponential Moving Average: [GitHub]
Balanced Consistency Regularization (bCR): [Paper]
StyleGAN2 Discriminator: [GitHub]
