
Generative-Adversarial-Networks

paper

Author: Lorna

Email: shiyipaisizuo@gmail.com

Chinese version

Requirements

  • GPU: NVIDIA TITAN V or later.
  • Disk: 128 GB SSD.
  • Python: 3.5 or later.
  • CUDA: 10.
  • cuDNN: 7.4.5 or later.
  • tensorflow-gpu: 2.0.0-alpha0.

Run this command to install the dependencies:

pip install -r requirements.txt

What are GANs?

Generative Adversarial Networks (GANs) are one of the most interesting ideas in computer science today. Two models are trained simultaneously by an adversarial process. A generator ("the artist") learns to create images that look real, while a discriminator ("the art critic") learns to tell real images apart from fakes.

During training, the generator progressively becomes better at creating images that look real, while the discriminator becomes better at telling them apart. The process reaches equilibrium when the discriminator can no longer distinguish real images from fakes.

The following animation shows a series of images produced by the generator as it was trained for 50 epochs. The images begin as random noise and increasingly resemble handwritten digits over time.

1. Introduction

1.1 Theory

This is the flow chart of a GAN.

GAN's main source of inspiration is the zero-sum game from game theory. Applied to deep neural networks, it sets a generator network G against a discriminator network D in a continuous game, through which G learns the data distribution. Once training on images is complete, G can generate lifelike images from random noise. The main functions of G and D are:

  • G is the generator network: it receives random noise z (a random vector) and generates an image from that noise.
  • D is the discriminator network: it judges whether an image is "real". Its input x is an image, and its output D(x) is the probability that x is a real image: 1 means certainly real, and 0 means certainly fake.

In the process of training, the goal of the generator G is to produce images real enough to fool the discriminator D, while the goal of D is to distinguish the fake images generated by G from the real ones. In this way, G and D form a dynamic "game process", and the final equilibrium point is a Nash equilibrium.
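This game is usually written as the minimax objective from the original GAN paper, in which D maximizes and G minimizes the value function V(D, G):

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\left[\log D(x)\right] +
  \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

D is trained to push D(x) toward 1 on real samples and D(G(z)) toward 0 on fakes, while G is trained to push D(G(z)) toward 1.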

1.2 Architecture

By optimizing this target, we can adjust the parameters of the generative probability model so that the generated distribution is as close as possible to the real data distribution.

So how do we define an appropriate optimization target, or loss? Traditional generative models generally adopt the likelihood of the data as the optimization target, but GAN innovatively uses a different one:

  • Firstly, it introduces a discriminative model (common ones include support vector machines and multi-layer neural networks).
  • Secondly, its optimization process is to find a Nash equilibrium between the generative model and the discriminative model.

The learning framework GAN establishes is in fact an imitation game between the generative model and the discriminative model. The purpose of the generative model is to imitate, model, and learn the distribution of the real data as well as possible, while the discriminative model judges whether an input it receives comes from the real data distribution or from the generative model. Through the continuous competition between these two models, both the generating and the discriminating abilities improve.

When the discriminative model is very strong, yet the generated data can still confuse it so that it cannot judge correctly, we consider that the generative model has actually learned the distribution of the real data.

1.3 GAN characteristics

characteristics:

  • Compared with traditional models, GAN has two different networks rather than a single network, and it uses an adversarial training method and objective.

  • The gradient update information for G comes from the discriminator D, rather than from the data samples.

advantages:

  • GAN is a generative model that, compared with other generative models (Boltzmann machines and GSNs), needs only back-propagation, without a complicated Markov chain.

  • Compared with all other models, GAN can produce clearer, more realistic samples.

  • GAN is a kind of unsupervised learning and can be widely used in semi-supervised and unsupervised learning.

  • Compared with variational autoencoders, GANs do not introduce any deterministic bias. Variational methods do, because they optimize a lower bound of the log-likelihood rather than the likelihood itself, which appears to make instances generated by VAEs blurrier than those from GANs.

  • Compared with VAEs, GANs have no variational lower bound: if the discriminator is trained well, the generator can learn the training sample distribution perfectly. In other words, GANs are asymptotically consistent, while VAEs are biased.

  • When GAN is applied to scenes such as image style transfer, super resolution, image completion, and denoising, it avoids the difficulty of designing a loss function: regardless of anything else, as long as there is a benchmark, simply hand the job to the discriminator and leave the rest to adversarial training.

disadvantages:

  • Training a GAN requires reaching a Nash equilibrium, which gradient descent can sometimes achieve and sometimes cannot. We have not yet found a reliable method to reach the Nash equilibrium, so training a GAN is unstable compared with a VAE or PixelRNN, though in practice it is arguably more stable than training a Boltzmann machine.

  • GAN is not suitable for processing discrete data, such as text.

  • GAN suffers from unstable training, vanishing gradients, and mode collapse.

2. Implementation

2.1 Load and prepare the dataset

You will use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data.

import tensorflow as tf


def load_dataset(mnist_size, mnist_batch_size, cifar_size, cifar_batch_size):
  """ Load the mnist and cifar10 datasets, then shuffle and batch them.

  Args:
    mnist_size: mnist dataset size (shuffle buffer size).
    mnist_batch_size: batch size for the mnist training set.
    cifar_size: cifar10 dataset size (shuffle buffer size).
    cifar_batch_size: batch size for the cifar10 training set.

  Returns:
    mnist dataset, cifar10 dataset

  """
  # load mnist data
  (mnist_train_images, mnist_train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

  # load cifar10 data
  (cifar_train_images, cifar_train_labels), (_, _) = tf.keras.datasets.cifar10.load_data()

  mnist_train_images = mnist_train_images.reshape(mnist_train_images.shape[0], 28, 28, 1).astype('float32')
  mnist_train_images = (mnist_train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

  cifar_train_images = cifar_train_images.reshape(cifar_train_images.shape[0], 32, 32, 3).astype('float32')
  cifar_train_images = (cifar_train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

  # Batch and shuffle the data
  mnist_train_dataset = tf.data.Dataset.from_tensor_slices(mnist_train_images)
  mnist_train_dataset = mnist_train_dataset.shuffle(mnist_size).batch(mnist_batch_size)

  cifar_train_dataset = tf.data.Dataset.from_tensor_slices(cifar_train_images)
  cifar_train_dataset = cifar_train_dataset.shuffle(cifar_size).batch(cifar_batch_size)

  return mnist_train_dataset, cifar_train_dataset
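The normalization step above is easy to verify in isolation. A minimal sketch with hypothetical pixel values (numpy only, no dataset download needed):

```python
import numpy as np

# Hypothetical 8-bit pixel values, as they come out of mnist/cifar10.
pixels = np.array([0, 64, 127.5, 255], dtype=np.float32)

# The same normalization as in load_dataset: map [0, 255] to [-1, 1].
normalized = (pixels - 127.5) / 127.5

print(normalized.min(), normalized.max())  # -1.0 1.0
```

The range [-1, 1] matches the tanh activation on the generator's output layer, so real and generated images live in the same value range.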

2.2 Create the models

Both the generator and discriminator are defined using the Keras Sequential API.

2.2.1 Make Generator model

Only the most basic fully connected architecture is used here. Except for the first layer, which does not use normalization, every layer follows the linear structure fully connected -> batch normalization -> LeakyReLU; the specific parameters are explained in the code below.

import tensorflow as tf
from tensorflow.keras import layers


def make_generator_model(dataset='mnist'):
  """ Build the generator model.

  Args:
    dataset: which dataset to generate for (default='mnist'). choice{'mnist', 'cifar'}.

  Returns:
    model.

  """
  model = tf.keras.models.Sequential()
  model.add(layers.Dense(256, input_dim=100))
  model.add(layers.LeakyReLU(alpha=0.2))

  model.add(layers.Dense(512))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU(alpha=0.2))

  model.add(layers.Dense(1024))
  model.add(layers.BatchNormalization())
  model.add(layers.LeakyReLU(alpha=0.2))

  if dataset == 'mnist':
    model.add(layers.Dense(28 * 28 * 1, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))
  elif dataset == 'cifar':
    model.add(layers.Dense(32 * 32 * 3, activation='tanh'))
    model.add(layers.Reshape((32, 32, 3)))

  return model
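To make the tensor shapes concrete, here is a numpy-only sketch of the generator's final step for the mnist branch. The weights are random stand-ins, not a trained model; only the shapes mirror make_generator_model:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, latent_dim = 16, 100

# Random noise input, as fed to the generator.
noise = rng.normal(size=(batch, latent_dim))

# Final layer of the mnist branch: dense to 28*28*1 units with tanh,
# then a reshape into image form. `w` is a random stand-in weight matrix.
w = rng.normal(size=(latent_dim, 28 * 28 * 1))
images = np.tanh(noise @ w).reshape(batch, 28, 28, 1)

print(images.shape)  # (16, 28, 28, 1)
```

The tanh keeps every pixel in [-1, 1], the same range the real images are normalized to.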

2.2.2 Make Discriminator model

The discriminator is a simple fully connected image classifier.

import tensorflow as tf
from tensorflow.keras import layers


def make_discriminator_model(dataset='mnist'):
  """ Build the discriminator model.

  Args:
    dataset: which dataset to discriminate on (default='mnist'). choice{'mnist', 'cifar'}.

  Returns:
    model.

  """
  model = tf.keras.models.Sequential()
  if dataset == 'mnist':
    model.add(layers.Flatten(input_shape=[28, 28, 1]))
  elif dataset == 'cifar':
    model.add(layers.Flatten(input_shape=[32, 32, 3]))

  model.add(layers.Dense(1024))
  model.add(layers.LeakyReLU(alpha=0.2))

  model.add(layers.Dense(512))
  model.add(layers.LeakyReLU(alpha=0.2))

  model.add(layers.Dense(256))
  model.add(layers.LeakyReLU(alpha=0.2))

  model.add(layers.Dense(1, activation='sigmoid'))

  return model

2.3 Define the loss and optimizers

2.3.1 Define loss functions and optimizers for both models.

# The discriminator ends in a sigmoid layer, so it outputs probabilities
# rather than logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)

2.3.2 Discriminator loss

This method quantifies how well the discriminator is able to distinguish real images from fakes. It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions on fake (generated) images to an array of 0s.

def discriminator_loss(real_output, fake_output):
  """ This method quantifies how well the discriminator is able to distinguish real images from fakes.
      It compares the discriminator's predictions on real images to an array of 1s, and the discriminator's predictions
      on fake (generated) images to an array of 0s.

  Args:
    real_output: origin pic.
    fake_output: generate pic.

  Returns:
    real loss + fake loss

  """
  real_loss = cross_entropy(tf.ones_like(real_output), real_output)
  fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
  total_loss = real_loss + fake_loss

  return total_loss
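The arithmetic behind discriminator_loss can be sketched without TensorFlow; `bce` below is a hypothetical numpy stand-in for BinaryCrossentropy on probabilities:

```python
import numpy as np

def bce(labels, probs, eps=1e-7):
  """Binary cross-entropy on probabilities (post-sigmoid outputs), batch mean."""
  probs = np.clip(probs, eps, 1 - eps)
  return float(np.mean(-labels * np.log(probs) - (1 - labels) * np.log(1 - probs)))

# Hypothetical discriminator outputs for a batch of two images each.
real_output = np.array([0.9, 0.8])   # real images should score near 1
fake_output = np.array([0.1, 0.3])   # fakes should score near 0

real_loss = bce(np.ones_like(real_output), real_output)
fake_loss = bce(np.zeros_like(fake_output), fake_output)
total_loss = real_loss + fake_loss

# A confident, correct discriminator has a lower loss than a confused one.
confused = bce(np.ones(2), np.array([0.5, 0.5])) + bce(np.zeros(2), np.array([0.5, 0.5]))
print(total_loss < confused)  # True
```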

2.3.3 Generator loss

The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1). Here, we will compare the discriminator's decisions on the generated images to an array of 1s.

def generator_loss(fake_output):
  """ The generator's loss quantifies how well it was able to trick the discriminator.
      Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1).
      Here, we will compare the discriminator's decisions on the generated images to an array of 1s.

  Args:
    fake_output: generate pic.

  Returns:
    loss

  """
  return cross_entropy(tf.ones_like(fake_output), fake_output)
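In the same numpy sketch, the generator's loss falls as the discriminator's scores on the fakes move toward 1 (`bce` is again a hypothetical stand-in for BinaryCrossentropy on probabilities):

```python
import numpy as np

def bce(labels, probs, eps=1e-7):
  """Binary cross-entropy on probabilities, batch mean."""
  probs = np.clip(probs, eps, 1 - eps)
  return float(np.mean(-labels * np.log(probs) - (1 - labels) * np.log(1 - probs)))

# generator_loss compares D's scores on fakes against all-ones labels,
# so the loss shrinks as the generator fools the discriminator.
early_training = bce(np.ones(2), np.array([0.1, 0.2]))  # D easily spots the fakes
late_training = bce(np.ones(2), np.array([0.8, 0.9]))   # D is mostly fooled

print(late_training < early_training)  # True
```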

2.3.4 optimizer

The discriminator and the generator optimizers are different since we train the two networks separately.

import tensorflow as tf


def generator_optimizer():
  """ Optimizer used to train the generator.

  Returns:
    Adam optimizer.

  """
  return tf.keras.optimizers.Adam(learning_rate=1e-4)


def discriminator_optimizer():
  """ Optimizer used to train the discriminator.

  Returns:
    Adam optimizer.

  """
  return tf.keras.optimizers.Adam(learning_rate=1e-4)

2.4 Save checkpoints

This notebook also demonstrates how to save and restore models, which can be helpful in case a long running training task is interrupted.

import os
import tensorflow as tf


def save_checkpoints(generator, discriminator, generator_optimizer, discriminator_optimizer, save_path):
  """ save gan model

  Args:
    generator: generate model.
    discriminator: discriminate model.
    generator_optimizer: generate optimizer func.
    discriminator_optimizer: discriminator optimizer func.
    save_path: save gan model dir path.

  Returns:
    checkpoint directory, checkpoint object, checkpoint file prefix

  """
  checkpoint_dir = save_path
  checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
  checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                   discriminator_optimizer=discriminator_optimizer,
                                   generator=generator,
                                   discriminator=discriminator)

  return checkpoint_dir, checkpoint, checkpoint_prefix

2.5 Train

2.5.1 Define the training loop

The training loop begins with the generator receiving a random seed as input, which it uses to produce an image. The discriminator then classifies real images (drawn from the training set) and fake images (produced by the generator). The loss is calculated for each model, and the gradients are used to update the generator and discriminator.

from dataset.load_dataset import load_dataset
from network.generator import make_generator_model
from network.discriminator import make_discriminator_model
from util.loss_and_optim import generator_loss, generator_optimizer
from util.loss_and_optim import discriminator_loss, discriminator_optimizer
from util.save_checkpoints import save_checkpoints
from util.generate_and_save_images import generate_and_save_images

import tensorflow as tf
import time
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist', type=str,
                    help='use dataset {mnist or cifar}.')
parser.add_argument('--epochs', default=50, type=int,
                    help='Epochs for training.')
args = parser.parse_args()
print(args)

# define model save path
save_path = 'training_checkpoint'

# create dir
if not os.path.exists(save_path):
  os.makedirs(save_path)

# define fixed random noise, re-used each epoch to visualize training progress
noise = tf.random.normal([16, 100])

# load dataset
mnist_train_dataset, cifar_train_dataset = load_dataset(60000, 128, 50000, 64)

# load network and optim paras
generator = make_generator_model(args.dataset)
generator_optimizer = generator_optimizer()

discriminator = make_discriminator_model(args.dataset)
discriminator_optimizer = discriminator_optimizer()

checkpoint_dir, checkpoint, checkpoint_prefix = save_checkpoints(generator,
                                                                 discriminator,
                                                                 generator_optimizer,
                                                                 discriminator_optimizer,
                                                                 save_path)


# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
  """ One training step for both networks.

  Args:
    images: a batch of real input images.

  """
  # Sample fresh noise for every batch; the fixed global `noise` is kept
  # for visualizing progress on the same inputs each epoch.
  batch_noise = tf.random.normal([tf.shape(images)[0], 100])
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    generated_images = generator(batch_noise, training=True)

    real_output = discriminator(images, training=True)
    fake_output = discriminator(generated_images, training=True)

    gen_loss = generator_loss(fake_output)
    disc_loss = discriminator_loss(real_output, fake_output)

  gradients_of_generator = gen_tape.gradient(gen_loss,
                                             generator.trainable_variables)
  gradients_of_discriminator = disc_tape.gradient(disc_loss,
                                                  discriminator.trainable_variables)

  generator_optimizer.apply_gradients(
    zip(gradients_of_generator, generator.trainable_variables))
  discriminator_optimizer.apply_gradients(
    zip(gradients_of_discriminator, discriminator.trainable_variables))


def train(dataset, epochs):
  """ train op

  Args:
    dataset: mnist dataset or cifar10 dataset.
    epochs: number of iterative training.

  """
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    generate_and_save_images(generator,
                             epoch + 1,
                             noise,
                             save_path)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

    print(f'Time for epoch {epoch+1} is {time.time()-start:.3f} sec.')

  # Generate after the final epoch
  generate_and_save_images(generator,
                           epochs,
                           noise,
                           save_path)


if __name__ == '__main__':
  if args.dataset == 'mnist':
    train(mnist_train_dataset, args.epochs)
  else:
    train(cifar_train_dataset, args.epochs)

2.6 Generate and save images

import os

from matplotlib import pyplot as plt


def generate_and_save_images(model, epoch, test_input, save_path):
  """ Generate a 4x4 grid of images from `test_input` and save it under `save_path`. """
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i + 1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig(os.path.join(save_path, 'image_at_epoch_{:04d}.png'.format(epoch)))
  plt.close(fig)

3. Common problems

3.1 Why GAN optimizers rarely use SGD

  • SGD oscillates easily, which makes GAN training unstable.

  • The purpose of GAN training is to find a Nash equilibrium in a high-dimensional non-convex parameter space. The Nash equilibrium of a GAN is a saddle point, but SGD only finds local minima, because SGD solves a minimization problem while GAN training is a game problem.

3.2 Why GAN is not suitable for processing text data

  • Text data is discrete, unlike image data. For text, we usually map each word to a high-dimensional vector, and the final predicted output is a one-hot vector. If the softmax output is (0.2, 0.3, 0.1, 0.2, 0.15, 0.05), the one-hot becomes (0, 1, 0, 0, 0, 0); if the softmax output is (0.2, 0.25, 0.2, 0.1, 0.15, 0.1), the one-hot is still (0, 1, 0, 0, 0, 0). Therefore the generator G outputs different results, but the discriminator D gives the same judgment and cannot pass useful gradient update information back to G, so D's final output is meaningless.

  • In addition, the loss function of GAN is the JS divergence, which is not suitable for measuring the distance between distributions that do not overlap.
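The one-hot argument above is easy to verify with the softmax vectors from the text:

```python
import numpy as np

def to_one_hot(probs):
  """Hard argmax -> one-hot, as in the final step of a text generator."""
  one_hot = np.zeros_like(probs)
  one_hot[np.argmax(probs)] = 1.0
  return one_hot

a = np.array([0.2, 0.3, 0.1, 0.2, 0.15, 0.05])
b = np.array([0.2, 0.25, 0.2, 0.1, 0.15, 0.1])

# Different softmax outputs, identical one-hot results: D sees no difference,
# so no useful gradient signal flows back to G.
print(to_one_hot(a))  # [0. 1. 0. 0. 0. 0.]
print(to_one_hot(b))  # [0. 1. 0. 0. 0. 0.]
```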

3.3 Some tricks for training GAN

  • Normalize the input to (-1, 1), and use tanh as the activation of the last layer (BEGAN is an exception).

  • Use the Wasserstein GAN loss function.

  • If you have labeled data, try to use the labels. Some people suggest that inverted labels work well, as does label smoothing, whether one-sided or two-sided.

  • Use mini-batch norm; if you do not use batch norm, you can use instance norm or weight norm.

  • Avoid ReLU and pooling layers to reduce the possibility of sparse gradients; the LeakyReLU activation function can be used instead.

  • Prefer the Adam optimizer, and keep the learning rate small; an initial 1e-4 is a reasonable reference. The learning rate can also be decayed continuously as training goes on.

  • Adding Gaussian noise to the layers of D is equivalent to a kind of regularization.
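Two of the tricks above, LeakyReLU and one-sided label smoothing, are small enough to sketch directly (numpy stand-ins, not the Keras implementations):

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
  """Unlike ReLU, negative inputs keep a small slope, so gradients stay non-sparse."""
  return np.where(x > 0, x, alpha * x)

def smooth_real_labels(n, smoothing=0.1):
  """One-sided label smoothing: train D against 0.9 for real samples instead of 1.0."""
  return np.full(n, 1.0 - smoothing)

print(leaky_relu(np.array([-1.0, 2.0])))  # [-0.2  2. ]
print(smooth_real_labels(3))              # [0.9 0.9 0.9]
```

Smoothing only the real labels keeps the discriminator from becoming over-confident without rewarding the generator for producing fakes.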

3.4 Mode collapse reasons

In general, GAN training is unstable and the results can be very poor; even extending the training time may not improve them much.

A specific explanation lies in the adversarial training method GAN uses: G's gradient updates come from D, so whatever D approves of, G generates. Concretely, G produces a sample and gives it to D for evaluation; D outputs the probability (0-1) that the generated sample is real, which is equivalent to telling G how realistic its samples are. G then improves itself according to this feedback, trying to raise D's output. But if G's samples are not actually realistic and D still gives a positive evaluation, or if G has discovered certain features that reliably win D's approval, then G concludes its output is right and keeps producing it, and D keeps scoring it highly. The two networks deceive each other in this way, and the generated results end up lacking some information and features.

4. Applications of GANs

  • GAN is itself a generative model, so data generation is the most common application, above all image generation, commonly with DCGAN, WGAN, and BEGAN; personally I feel BEGAN works best and is the simplest.

  • GAN is also an unsupervised learning model, so it is widely used in unsupervised and semi-supervised learning.

  • GAN plays a role not only in generation but also in classification. Simply put, the discriminator is replaced with a classifier to do multi-class tasks, while the generator still does generation and assists the classifier's training.

  • GAN can be combined with reinforcement learning; a good example is SeqGAN.

  • At present, GAN has interesting applications in image style transfer, image denoising and restoration, and image super resolution, all with good results.

TODO

  • Write metrics code.
  • Create GIF.

