Merge pull request #158 from Dsantra92/fix/dcgan
ToucheSir authored Jan 3, 2023
2 parents d6cbc78 + c33713d commit 3469eb6
Showing 1 changed file with 17 additions and 23 deletions.
40 changes: 17 additions & 23 deletions tutorialposts/2021-10-08-dcgan-mnist.md
@@ -11,7 +11,7 @@ This is a beginner level tutorial for generating images of handwritten digits us

A GAN is composed of two sub-models - the **generator** and the **discriminator** acting against one another. The generator can be considered as an artist who draws (generates) new images that look real, whereas the discriminator is a critic who learns to tell real images apart from fakes.

![](../../assets/tutorialposts/2021-10-8-dcgan-mnist/cat_gan.png)
![](../../assets/tutorialposts/2021-10-08-dcgan-mnist/cat_gan.png)

The GAN starts with a generator and discriminator which have very little or no idea about the underlying data. During training, the generator progressively becomes better at creating images that look real, while the discriminator becomes better at telling them apart. The process reaches equilibrium when the discriminator can no longer distinguish real images from fakes.

@@ -24,9 +24,9 @@ This tutorial demonstrates the process of training a DC-GAN on the [MNIST datase

~~~
<br><br>
<p align="center">
<img src="../../assets/tutorialposts/2021-10-8-dcgan-mnist/output.gif" align="middle" width="200">
</p>
<div style="text-align:center">
<img src="../../assets/tutorialposts/2021-10-08-dcgan-mnist/output.gif" width="200">
</div>
~~~

## Setup
@@ -43,7 +43,7 @@ Pkg.add(["Images", "Flux", "MLDatasets", "CUDA", "Parameters"])
```
*Note: Depending on your internet speed, it may take a few minutes for the packages to install.*

<br>
\
After installing the libraries, load the required packages and functions:
```julia
using Base.Iterators: partition
@@ -59,7 +59,7 @@ using Flux.Losses: logitbinarycrossentropy
using MLDatasets: MNIST
using CUDA
```
<br>
\
Now we set default values for the learning rates, batch size, epochs, the usage of a GPU (if available) and other hyperparameters for our model.

```julia
@@ -116,7 +116,6 @@ We will also apply the weight initialization method mentioned in the original DC
# sampled from a Gaussian distribution with μ=0 and σ=0.02
dcgan_init(shape...) = randn(Float32, shape) * 0.02f0
```
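As a quick sanity check (not part of the original tutorial), samples drawn from this initializer should have mean ≈ 0 and standard deviation ≈ 0.02:

```julia
using Statistics: mean, std

# DCGAN weight init: Gaussian with μ = 0 and σ = 0.02, as above
dcgan_init(shape...) = randn(Float32, shape) * 0.02f0

W = dcgan_init(200, 200)            # 40_000 samples
@assert abs(mean(W)) < 0.001        # sample mean close to 0
@assert abs(std(W) - 0.02f0) < 0.002
```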
<br>

```julia
function Generator(latent_dim)
@@ -137,7 +136,7 @@ function Generator(latent_dim)
)
end
```
<br>
\
Time for a small test! We create a dummy generator and feed a random vector as a seed to it. If our generator is initialized correctly, it will return an array of size (28, 28, 1, `batch_size`). The `@assert` macro in Julia raises an exception if the output size is wrong.

```julia
@@ -150,10 +149,7 @@ gen_image = generator(noise)
@assert size(gen_image) == (28, 28, 1, 3)
```

<br>
Our generator model has yet to learn the correct weights, so it does not produce a recognizable image for now. To train our poor generator, we need its equal rival, the *discriminator*.
<br>
<br>

### Discriminator

@@ -187,15 +183,14 @@ discriminator = Discriminator()
logits = discriminator(gen_image)
@assert size(logits) == (1, 3)
```
<br>

Just like our dummy generator, the untrained discriminator has no idea what a real or fake image looks like. It needs to be trained alongside the generator to output positive values for real images and negative values for fake ones.

## Loss functions for GAN

In a GAN problem, there are only two labels involved: fake and real. So binary cross-entropy is an easy choice for a preliminary loss function.

Flux's `binarycrossentropy` would do the job, but for numerical stability it is preferable to compute cross-entropy on the logits. Flux provides [logitbinarycrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitbinarycrossentropy) for exactly this purpose; mathematically it is equivalent to `binarycrossentropy(σ(ŷ), y, kwargs...)`.
<br>
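To see why the logit form is preferred, here is a pure-Julia sketch (independent of Flux's actual implementation, using the standard stable rewriting) comparing naive sigmoid-then-BCE against the logit formulation:

```julia
σ(x) = 1 / (1 + exp(-x))

# Naive: apply sigmoid, then binary cross-entropy — log(1 - σ(x))
# underflows to log(0) for large logits
naive_bce(x, y) = -y * log(σ(x)) - (1 - y) * log(1 - σ(x))

# Stable logit form: algebraically identical, but never takes log of 0
logit_bce(x, y) = max(x, 0) - x * y + log(1 + exp(-abs(x)))

# The two agree where the naive form is well-behaved...
@assert isapprox(naive_bce(2.0, 1.0), logit_bce(2.0, 1.0); atol=1e-8)

# ...but the naive form blows up for extreme logits
@assert isinf(naive_bce(100.0, 0.0))     # σ(100) rounds to 1.0, log(0) = -Inf
@assert isfinite(logit_bce(100.0, 0.0))  # ≈ 100.0, computed stably
```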

### Discriminator Loss

@@ -213,15 +208,14 @@ function discriminator_loss(real_output, fake_output)
return real_loss + fake_loss
end
```
<br>
### Generator Loss

The generator's loss quantifies how well it was able to trick the discriminator. Intuitively, if the generator is performing well, the discriminator will classify the fake images as real (or 1).

```julia
generator_loss(fake_output) = logitbinarycrossentropy(fake_output, 1)
```
<br>
\
We also need optimizers for our networks. Why, you may ask? Read more [here](https://towardsdatascience.com/overview-of-various-optimizers-in-neural-networks-17c1be2df6d5). For both the generator and the discriminator, we will use the [ADAM optimizer](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.ADAM).
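Setting them up might look like this (a sketch, not the tutorial's exact code; the variable names are illustrative, with lr = 0.0002 and β₁ = 0.5 following the DCGAN paper's recommendation):

```julia
using Flux.Optimise: ADAM

# One optimizer per network, so their momentum states stay independent
opt_gen  = ADAM(0.0002, (0.5, 0.999))
opt_disc = ADAM(0.0002, (0.5, 0.999))
```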

## Utility functions
@@ -254,7 +248,7 @@ function train_discriminator!(gen, disc, real_img, fake_img, opt, ps, hparams)
return disc_loss
end
```
<br>
\
We define a similar function for the generator.

```julia
@@ -268,8 +262,7 @@ function train_generator!(gen, disc, fake_img, opt, ps, hparams)
return gen_loss
end
```
<br>

\
Now that we have defined every function we need, we integrate everything into a single `train` function where we first set up all the models and optimizers and then train the GAN for a specified number of epochs.

```julia
@@ -337,7 +330,7 @@ function train(hparams)
return nothing
end
```
<br>

Now we finally get to train the GAN:

```julia
@@ -359,10 +352,11 @@ images = load.(img_paths)
gif_mat = cat(images..., dims=3)
save("./output.gif", gif_mat)
```
<br>
<p align="center">
<img src="../../assets/tutorialposts/2021-10-8-dcgan-mnist/output.gif" align="middle" width="200">
</p>
~~~
<div style="text-align:center">
<img src="../../assets/tutorialposts/2021-10-08-dcgan-mnist/output.gif" width="200">
</div>
~~~

## Resources & References
- [The DCGAN implementation in Model Zoo.](https://github.com/FluxML/model-zoo/blob/master/vision/dcgan_mnist/dcgan_mnist.jl)
