This repository provides a PyTorch implementation of SAGAN. Both the wgan-gp and wgan-hinge losses are ready to use, but note that wgan-gp appears to be incompatible with spectral normalization; to train with wgan-gp, remove all spectral normalization layers from the model. Self-attention modules are applied before the CNN layers of both the discriminator and the generator.
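The self-attention computation can be sketched as follows. This is a minimal NumPy illustration of the SAGAN-style attention (1×1 convolutions replaced by plain projection matrices; names and shapes are illustrative, not the repo's actual code):

```python
import numpy as np

def self_attention(x, w_q, w_k, w_v, gamma=0.0):
    """SAGAN-style self-attention over a feature map x of shape (C, H, W).

    w_q, w_k: (C, C//8) projections (stand-ins for the 1x1 query/key convs);
    w_v: (C, C) value projection. Returns gamma * attention_output + x.
    """
    c, h, w = x.shape
    flat = x.reshape(c, h * w).T                 # (N, C), N = H*W positions
    q = flat @ w_q                               # (N, C//8) queries
    k = flat @ w_k                               # (N, C//8) keys
    v = flat @ w_v                               # (N, C)    values
    energy = q @ k.T                             # (N, N) attention energies
    energy -= energy.max(axis=1, keepdims=True)  # softmax numerical stability
    attn = np.exp(energy)
    attn /= attn.sum(axis=1, keepdims=True)      # row-wise softmax over positions
    out = (attn @ v).T.reshape(c, h, w)
    return gamma * out + x                       # residual; gamma is learned from 0
```

Note the (N, N) attention matrix: for a 64 × 64 feature map, N = 4096, so the map is 4096 × 4096, which is what drives the memory warning below.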
- Unsupervised setting (no labels used yet)
- Applied: spectral normalization (code from here)
- Implemented: self-attention module, two-timescale update rule (TTUR), wgan-hinge loss, wgan-gp loss
- Parallel computation on multiple GPUs
- TensorBoard logging
- Attention visualization on 64 * 64 images
- Attention-map generation for 64 * 64 images (a 4096 * 4096 attention matrix)
- Custom (Hearthstone) dataset support
- 256 * 256 image generation [branch pix256]
Warning: 64 * 64 is the largest power-of-two attention-map size that fits for training on 2 Nvidia GTX 1080 Ti GPUs (24GB RAM)
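The wgan-hinge losses listed above can be sketched in a few lines. This is a NumPy illustration of the standard hinge formulation for GANs, not the repo's training code:

```python
import numpy as np

def d_hinge_loss(d_real, d_fake):
    """Discriminator hinge loss: push real scores above +1 and fake scores below -1."""
    return (np.mean(np.maximum(0.0, 1.0 - d_real))
            + np.mean(np.maximum(0.0, 1.0 + d_fake)))

def g_hinge_loss(d_fake):
    """Generator loss: raise the discriminator's score on generated samples."""
    return -np.mean(d_fake)
```

Scores already beyond the margins incur zero discriminator loss, which is what keeps hinge training stable compared to an unbounded Wasserstein critic.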
- Python 3.5+
- PyTorch 0.3.0
- opencv-python
- Details in requirements.txt
$ git clone https://github.com/heykeetae/Self-Attention-GAN.git
$ cd Self-Attention-GAN
# for conda user
$ conda create -n sagan python=3.5
$ conda activate sagan
$ conda install pytorch=0.3.0
$ pip install -r requirements.txt
$ cd data
$ bash download.sh CelebA  # note: this download link currently returns 404
# or
$ bash download.sh LSUN
# For the Hearthstone dataset
$ mkdir hearthstone-card-images
$ cd hearthstone-card-images
$ wget -O hearthstone_card.zip https://www.dropbox.com/s/vvaxb4maoj4ri34/hearthstone_card.zip?dl=0
$ unzip hearthstone_card.zip
$ python main.py --batch_size 64 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
# or
$ python main.py --batch_size 64 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun
$ python main.py --batch_size 16 --imsize 64 --dataset hearthstone --adv_loss hinge --version sagan_hearth_at1 --num_workers 16 --use_tensorboard True --parallel True --total_step 100000 --log_step 100
For argument details, please read parameter.py
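The TTUR mentioned in the feature list simply means the generator and discriminator are updated with different learning rates. A minimal sketch of the idea (plain gradient descent as a stand-in for the actual optimizers; the rates shown are illustrative — see parameter.py for the defaults used here):

```python
# TTUR: two separate learning rates, one per network.
G_LR = 1e-4  # illustrative generator rate
D_LR = 4e-4  # illustrative discriminator rate (typically larger under TTUR)

def sgd_step(params, grads, lr):
    """One plain gradient-descent step; each network uses its own lr."""
    return [p - lr * g for p, g in zip(params, grads)]
```

During training, the discriminator's parameters would be updated with `D_LR` and the generator's with `G_LR`, letting the discriminator learn faster without extra update steps.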
tensorboard --logdir ./logs/sagan_hearth_at1
$ cd samples/sagan_celeb
# or
$ cd samples/sagan_lsun
# or
$ cd samples/sagan_hearth_at1
Generated samples are saved there every 100 iterations. The sampling rate can be controlled via --sample_step (e.g., --sample_step 100).
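The sampling schedule is just a modulus check on the step counter; roughly (an illustrative sketch, not the repo's exact code):

```python
def should_sample(step, sample_step=100):
    """Return True on the iterations where generated samples should be saved."""
    return step % sample_step == 0
```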
- Colormap from OpenCV (https://docs.opencv.org/2.4/modules/contrib/doc/facerec/colormaps.html)
- The most attended regions appear in red (1); the least attended regions appear in blue (0)
- Scores lie in the range [0, 1]
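Before the colormap is applied, each attention map is rescaled to [0, 1] so that red marks the most attended positions and blue the least. A minimal NumPy sketch of that normalization (illustrative; the repo's visualization code may differ):

```python
import numpy as np

def normalize_attention(attn):
    """Min-max scale an attention map to [0, 1]: 1 = most attended, 0 = least."""
    lo, hi = attn.min(), attn.max()
    if hi == lo:                      # uniform attention: nothing stands out
        return np.zeros_like(attn)
    return (attn - lo) / (hi - lo)
```

The normalized values can then be fed to `cv2.applyColorMap` (with a 0-255 uint8 scaling) to produce the red-to-blue heatmaps.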