Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Dey committed Dec 28, 2018
0 parents commit ff6c653
Show file tree
Hide file tree
Showing 56 changed files with 8,762 additions and 0 deletions.
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# RankGAN

### Abstract
We present a new stage-wise learning paradigm for training generative adversarial networks (GANs). The goal of our work is to progressively strengthen the discriminator and thus, the generators, with each subsequent stage without changing the network architecture. We call this proposed method the RankGAN. We first propose a margin-based loss for the GAN discriminator. We then extend it to a margin-based ranking loss to train the multiple stages of RankGAN. We focus on face images from the CelebA dataset in our work and show visual as well as quantitative improvements in face generation and completion tasks over other GAN approaches, including WGAN and LSGAN.

***

### Overview
![alt text](images/rankgan_idea.png "Stagewise Training using Ranking Mechanism")

Ranking based Progressive Training in RankGAN.

![alt text](images/flowchart_red.png "RankGAN Training Flowchart")

RankGAN Training Flowchart and Architecture.

***

### Contributions
* A progressive training framework where GANs at later stages improve upon their earlier versions.
* A margin-based Ranking Loss function to train GANs.
* Evaluation of GANs based on image completion tasks.

***

### Results
![alt text](images/generation_results.png "Image Generation Results")

Visual and quantitative results on face generation with RankGAN, WGAN and LSGAN.

![alt text](images/image_completion.png "Image Completion Results")

Image Completion Results.

### References

1. Rahul Dey, Felix Juefei-Xu, Vishnu Naresh Boddeti and Marios Savvides. [**RankGAN: A Maximum Margin Ranking GAN for Generating Faces.**](https://arxiv.org/abs/1812.08196) Asian Conference on Computer Vision (ACCV 2018.

47 changes: 47 additions & 0 deletions checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# checkpoints.py

import os
import torch

class Checkpoints:

def __init__(self,args):
self.dir_save = args.save
self.dir_load = args.resume
# self.prevmodel = args.prevmodel
self.prevmodel = None

if os.path.isdir(self.dir_save) == False:
os.makedirs(self.dir_save)

def latest(self, name):
output = {}
if self.dir_load == None:
output['resume'] = None
else:
output['resume'] = self.dir_load

if (self.prevmodel != None):
output['prevmodel'] = self.prevmodel
else:
output['prevmodel'] = None

return output[name]

def save(self, epoch, model, best):
if best == True:
output = {}
num = len(model)
for key, value in model[0].items():
output[key] = value.state_dict()
torch.save(output, '%s/model_%d_epoch_%d.pth' %
(self.dir_save, num, epoch))

def load(self, filename):
if os.path.isfile(filename):
print("=> loading checkpoint '{}'".format(filename))
model = torch.load(filename)
else:
print("=> no checkpoint found at '{}'".format(filename))

return model
55 changes: 55 additions & 0 deletions completion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#completion.py

import torch
import random
import torchvision
from model import Model
from config import parser
from dataloader import Dataloader
from checkpoints import Checkpoints
from evaluation import Evaluate
from generation import Generator
import os
import datetime
import utils
import copy

# parse the arguments
args = parser.parse_args()
random.seed(args.manual_seed)
torch.manual_seed(args.manual_seed)
args.save = os.path.join(args.result_path, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S/'), 'results')
args.logs = os.path.join(args.result_path, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S/'), 'logs')
utils.saveargs(args)

# initialize the checkpoint class
checkpoints = Checkpoints(args)

# Create Model
models = Model(args)
gogan_model, criterion = models.setup(checkpoints)
netD = gogan_model[0]
netG = gogan_model[1]
netE = gogan_model[2]

if args.netD is not '':
checkpointD = checkpoints.load(args.netD)
netD.load_state_dict(checkpointD)
if args.netG is not '':
checkpointG = checkpoints.load(args.netG)
netG.load_state_dict(checkpointG)
if args.netE is not '':
checkpointE = checkpoints.load(args.netE)
netE.load_state_dict(checkpointE)

# Data Loading
dataloader = Dataloader(args)
test_loader = dataloader.create("Test", shuffle=False)

# The trainer handles the training loop and evaluation on validation set
evaluate = Evaluate(args, netD, netG, netE)
# generator = Generator(args, netD, netG, netE)

# test for a single epoch
test_loss = evaluate.complete(test_loader)
# loss = generator.generate_one(test_loader)
140 changes: 140 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# config.py
import os
import datetime
import argparse

result_path = "results/"
result_path = os.path.join(result_path, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S/'))

parser = argparse.ArgumentParser(description='Your project title goes here')

# ======================== Data Setings ============================================
parser.add_argument('--dataset-test', type=str, default='CELEBA', metavar='', help='name of training dataset')
parser.add_argument('--dataset-train', type=str, default='CELEBA', metavar='', help='name of training dataset')
parser.add_argument('--split_test', type=float, default=None, metavar='', help='test split')
parser.add_argument('--split_train', type=float, default=None, metavar='', help='train split')
parser.add_argument('--dataroot', type=str, default=None, metavar='', help='path to the data')
parser.add_argument('--result-path', type=str, default=result_path, help='path where to save results')
parser.add_argument('--resume', type=str, default=None, metavar='', help='full path of models to resume training')
parser.add_argument('--nclasses', type=int, default=None, metavar='', help='number of classes for classification')
parser.add_argument('--input-filename-test', type=str, default=None, metavar='', help='input test filename for filelist and folderlist')
parser.add_argument('--label-filename-test', type=str, default=None, metavar='', help='label test filename for filelist and folderlist')
parser.add_argument('--input-filename-train', type=str, default=None, metavar='', help='input train filename for filelist and folderlist')
parser.add_argument('--label-filename-train', type=str, default=None, metavar='', help='label train filename for filelist and folderlist')
parser.add_argument('--loader-input', type=str, default=None, metavar='', help='input loader')
parser.add_argument('--loader-label', type=str, default=None, metavar='', help='label loader')

# ======================== Network Model Setings ===================================
parser.add_argument('--nchannels', type=int, default=3, metavar='', help='number of input channels')
parser.add_argument('--resolution-high', type=int, default=64, metavar='', help='image resolution height')
parser.add_argument('--resolution-wide', type=int, default=64, metavar='', help='image resolution width')
parser.add_argument('--ndim', type=int, default=None, metavar='', help='number of feature dimensions')
parser.add_argument('--nunits', type=int, default=None, metavar='', help='number of units in hidden layers')
parser.add_argument('--dropout', type=float, default=None, metavar='', help='dropout parameter')
parser.add_argument('--net-type', type=str, default='dcgan', metavar='', help='type of network')
parser.add_argument('--length-scale', type=float, default=None, metavar='', help='length scale')
parser.add_argument('--tau', type=float, default=None, metavar='', help='Tau')
parser.add_argument('--mini-batch-disc', action='store_true', default=False, help='enable minibatch discrimination in discriminator')

# ======================== Training Settings =======================================
parser.add_argument('--cuda', action='store_true', default=False, help='run on gpu')
parser.add_argument('--ngpu', type=int, default=1, metavar='', help='number of gpus to use')
parser.add_argument('--batch-size', type=int, default=32, metavar='', help='batch size for training')
parser.add_argument('--nepochs', type=int, default=None, metavar='', help='number of epochs to train')
parser.add_argument('--niters', type=int, default=None, metavar='', help='number of iterations at test time')
parser.add_argument('--epoch-number', type=int, default=None, metavar='', help='epoch number')
parser.add_argument('--nthreads', type=int, default=2, metavar='', help='number of threads for data loading')
parser.add_argument('--manual-seed', type=int, default=101, metavar='', help='manual seed for randomness')
parser.add_argument('--port', type=int, default=8097, metavar='', help='port for visualizing training at http://localhost:port')
parser.add_argument('--env', type=str, default='main', help='visdom environment name')
parser.add_argument('--dataset-fraction', type=float, default=1, help='fraction of dataset to train (between 0-1)')
parser.add_argument('--plot-update-interval', type=int, default=30, help='number of iterations per plot update')

# ======================== Hyperparameter Setings ==================================
parser.add_argument('--optim-method', type=str, default='Adam', metavar='', help='the optimization routine ')
parser.add_argument('--learning-rate-vae', type=float, default=1e-4, metavar='', help='learning rate for vae')
parser.add_argument('--learning-rate-dis', type=float, default=5e-5, metavar='', help='learning rate for discriminator')
parser.add_argument('--learning-rate-gen', type=float, default=1e-7, metavar='', help='learning rate for generator')
parser.add_argument('--learning-rate-decay', type=float, default=0.8, metavar='', help='learning rate decay')
parser.add_argument('--momentum', type=float, default=0, metavar='', help='momentum')
parser.add_argument('--weight-decay', type=float, default=0, metavar='', help='weight decay')
parser.add_argument('--stage1-weight-decay', type=float, default=0.5, metavar='', help='stage 1 weight decay for hinge loss')
parser.add_argument('--adam-beta1', type=float, default=0.5, metavar='', help='Beta 1 parameter for Adam')
parser.add_argument('--adam-beta2', type=float, default=0.999, metavar='', help='Beta 2 parameter for Adam')
parser.add_argument('--gp', action='store_true', default=False, help='use gradient penalty')
parser.add_argument('--gp-lambda', type=float, default=10, help="gradient penalty lambda")
parser.add_argument('--scheduler-patience', type=int, default=500, help='patience value for lr scheduler')
parser.add_argument('--scheduler-maxlen', type=int, default=1000, help='patience value for lr scheduler')

# ======================== GoGAN Setings ==================================
parser.add_argument('--stage-epochs', type=int, default=10, help='number of epochs per gogan stage')
parser.add_argument('--num-stages', type=int, default=10, help='number of gogan stages')
parser.add_argument('--margin', type=float, default=2.0, help='initial margin of gogan loss')
parser.add_argument('--weight-gan-final', type=float, default=1.0, help='weight of discriminator loss')
parser.add_argument('--weight-vae-init', type=float, default=1.0, help='weight of mse loss')
parser.add_argument('--ngf', type=int, default=32)
parser.add_argument('--ndf', type=int, default=32)
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--clamp-lower', type=float, default=-0.01, help='WGAN lower weight clip')
parser.add_argument('--clamp-upper', type=float, default=0.01, help='WGAN upper weight clip')
parser.add_argument('--d-iter', type=int, default=5, help='number of discriminator iterations per generation iteration')
parser.add_argument('--g-iter', type=int, default=1, help='number of generator iterations per discriminator iteration')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--netE', default='', help="path to Encoder (to continue training)")
parser.add_argument('--prevD', default='', help="path to prevD (to continue training)")
parser.add_argument('--prevG', default='', help="path to prevG (to continue training)")
parser.add_argument('--vae-loss-type', default='l2', help='type of vae loss l1 or l2')
parser.add_argument('--disc-diff-weight', type=float, default=1.0, help="weightage of discriminator difference loss")
parser.add_argument('--weight-kld', type=float, default=10.0, help='weightage of kl-divergence loss')
parser.add_argument('--start-stage', type=int, default=0, help='starging stage (0/1/2/...)')
parser.add_argument('--normalize', action='store_true', default=False, help='whether to have batch-norm')
parser.add_argument('--gogan-type', type=str, default="vae", help="no_vae/vae_no_gen")
parser.add_argument('--norm-type', type=str, default='batch', help="type of normalization to use in models")
parser.add_argument('--wgan', action='store_true', default=False, help='whether to use wgan loss in first stage of GAN')
parser.add_argument('--extra-D-cap', action='store_true', default=False, help='whether to add extra capacity to the discriminator')
parser.add_argument('--extra-G-cap', action='store_true', default=False, help='whether to add extra capacity to the generator')
parser.add_argument('--correlation-sigma', type=float, default=10.0, help='variance of impulse in correlation loss for VAE')
parser.add_argument('--add-capacity', action='store_true', default=False, help='whether to add extra layer to the discriminator to increase capacity')
parser.add_argument('--add-clamp', action='store_true', default=False, help='whether to change clamping in the discriminator to increase capacity')
parser.add_argument('--disc-optimize', action='store_true', default=False, help='optimize discriminator before training gogan')
parser.add_argument('--gen-gamma', type=float, default=0, help='curriculum learning gamma')
parser.add_argument('--add-noise', action='store_true', default=False, help='adds noise to the discriminator input')
parser.add_argument('--noise-var', type=float, default=0.1, help='std of noise to be added to GAN training')
parser.add_argument('--gp-norm', action='store_true', default=False, help='penalizes sum of gradient squares')
parser.add_argument('--rank-weight', type=float, default=1, help='weight of discriminator ranking loss')
parser.add_argument('--adaptive-iter', action='store_false', default=True, help='enables adaptive iterations for discriminator and generator')
parser.add_argument('--use-upsampling', action='store_true', default=False, help='use upsampling in dcgan')
parser.add_argument('--optimize-mse', action='store_true', default=False, help='wheter to optimize mse during gogan training')
parser.add_argument('--weight-mse', type=float, default=1, help='weight for mse loss during gogan training')
parser.add_argument('--n-extra-layers', type=int, default=0, help='number of extra layers in DCGAN architecture')

# ======================== Image Completion Setings ==================================
parser.add_argument('--disc-loss-weight', type=float, default=0.1, help='weight for discriminator loss in image completion')
parser.add_argument('--ssim-weight', type=float, default=1000, help='weight of ssim loss')
parser.add_argument('--citers', type=int, default=100, help='number of iterations for image completion')
parser.add_argument('--scale', type=float, default=0.2, help='mask scale for image completion')
parser.add_argument('--use-encoder', action='store_true', default=False, help='whether to use encoder for image completion or not')
parser.add_argument('--blend', action='store_true', default=False, help='enable poisson blending')
parser.add_argument('--mask-type', type=str, default='central', help='mask type (central/periocular)')
parser.add_argument('--netG1', default='', help="path to netG1 (to continue training)")
parser.add_argument('--netG2', default='', help="path to netG2 (to continue training)")
parser.add_argument('--netG3', default='', help="path to netG3 (to continue training)")
parser.add_argument('--netG4', default='', help="path to netG4 (to continue training)")
parser.add_argument('--netG5', default='', help="path to netG5 (to continue training)")
parser.add_argument('--start-index', type=int, default=0, help="start index of images")
parser.add_argument('--disc-type', type=str, default='wgan', help='discriminator loss type for image completion')

# ======================== OpenFace Setings ==================================
parser.add_argument('--model', type=str, default='', help="model path")
parser.add_argument('--splits', type=int, default=1, help="number of splits for computing inception score")

# ======================== GMM Setings ==================================
parser.add_argument('--num-gaus', type=int, default=2, help='number of Gaussians in GMM')
parser.add_argument('--gmm-dim', type=int, default=1, help='dimensionality of GMM')
parser.add_argument('--num-samples', type=int, default=10000, help='number of GMM data samples to generate')
parser.add_argument('--gmm-range', type=float, default=3.0, help='range of GMM')
parser.add_argument('--gmm-hidden', type=int, default=8, help='dimensionality of hidden layers in GMM model')
parser.add_argument('--gmm-nlayers', type=int, default=3, help='number of layers in GMM model')

args = parser.parse_args()
Loading

0 comments on commit ff6c653

Please sign in to comment.