PyTorch implementation of the Variational Autoencoder (VAE) model from the Kingma & Welling paper 'Auto-Encoding Variational Bayes', evaluated on MNIST and Fashion MNIST datasets.

Auto-Encoding-Variational-Bayes (VAE)

This project implements the models proposed in the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
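The core ideas of the paper — the reparameterization trick and the (negative) ELBO objective — can be sketched in a few lines of PyTorch. This is a minimal illustration, not the repository's exact code:

```python
import torch
import torch.nn.functional as F

def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

def vae_loss(recon_x, x, mu, logvar):
    # Negative ELBO = reconstruction term + KL(q(z|x) || N(0, I)),
    # with the KL computed in closed form for a diagonal Gaussian
    bce = F.binary_cross_entropy(recon_x, x, reduction="sum")
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
```

Summing (rather than averaging) over the batch matches the common formulation of the ELBO for binarized images; the encoder and decoder networks that produce `mu`, `logvar`, and `recon_x` are omitted here.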

The models were run on the following datasets:

  • MNIST
  • Fashion-MNIST

The results of the models can be found in results.pdf.

Below are samples generated by the trained VAE on the MNIST and Fashion-MNIST datasets:

  • Fashion-MNIST Samples: Images generated after training a VAE on Fashion-MNIST. The model learns to capture and reproduce the structure of clothing items such as shoes, shirts, and bags.

  • MNIST Samples: Digits generated by a VAE trained on MNIST.
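Generating such samples amounts to decoding latent vectors drawn from the prior N(0, I). A minimal sketch follows; the decoder here is a hypothetical stand-in with assumed layer sizes, not the repository's model:

```python
import torch
import torch.nn as nn

latent_dim = 100  # matches the --latent-dim default below

# Hypothetical stand-in decoder: maps z to a flattened 28x28 image
decoder = nn.Sequential(
    nn.Linear(latent_dim, 400),
    nn.ReLU(),
    nn.Linear(400, 28 * 28),
    nn.Sigmoid(),  # pixel intensities in [0, 1]
)

with torch.no_grad():
    z = torch.randn(64, latent_dim)          # 64 draws from the prior N(0, I)
    samples = decoder(z).view(-1, 1, 28, 28)  # reshape to image tensors
```

With a trained decoder, `samples` could be written to disk with `torchvision.utils.save_image`.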



Setup

Prerequisites

  • Python 3.x
  • NumPy
  • torch
  • torchvision
  • matplotlib

You can install the dependencies using:

pip install -r requirements.txt

Running the Code

To train the VAE model on either MNIST or Fashion-MNIST:

python main.py

You can customize the training with the following command-line arguments:

| Argument | Description | Default |
|----------|-------------|---------|
| `--dataset` | Dataset to train on: `mnist` or `fashion-mnist` | `mnist` |
| `--batch_size` | Number of images per mini-batch | `128` |
| `--epochs` | Number of training epochs | `50` |
| `--sample_size` | Number of images to generate during sampling | `64` |
| `--latent-dim` | Dimensionality of the latent space (z vector size) | `100` |
| `--lr` | Initial learning rate for the optimizer | `1e-3` |

Example with Custom Parameters:

python main.py --dataset fashion-mnist --batch_size 64 --epochs 40 --latent-dim 50 --lr 0.0005
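A plausible `argparse` setup matching the flags above would look like the following sketch (defaults taken from the table; this is an assumption about `main.py`, not its actual code):

```python
import argparse

parser = argparse.ArgumentParser(
    description="Train a VAE on MNIST or Fashion-MNIST"
)
parser.add_argument("--dataset", choices=["mnist", "fashion-mnist"],
                    default="mnist")
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--sample_size", type=int, default=64)
# argparse converts the dash, so this flag is read as args.latent_dim
parser.add_argument("--latent-dim", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-3)

# Parsing the custom-parameters example from above:
args = parser.parse_args(
    ["--dataset", "fashion-mnist", "--batch_size", "64",
     "--epochs", "40", "--latent-dim", "50", "--lr", "0.0005"]
)
```

Passing an explicit list to `parse_args` is shown here only for illustration; in a script it would be called with no arguments to read `sys.argv`.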

Output Files

  • Generated Samples: After training, 100 samples are generated and saved in the ./samples/ folder as:

    • fashion-mnist_batch128_mid100_.png
    • mnist_batch128_mid100_.png
  • Loss Tracking: The training and testing loss for each epoch is saved in a pickle file:

    • loss_batches_mnist.pkl or loss_batches_fashion-mnist.pkl
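The loss pickle can be read back with the standard library. The structure stored inside the file is an assumption here (a dict of per-epoch train/test losses is used as a stand-in), so the sketch round-trips its own example data:

```python
import pickle

# Stand-in for the recorded history; the real file's structure may differ
history = {"train": [165.2, 130.4, 112.9], "test": [150.1, 125.7, 110.3]}

# Filename as listed in the README
with open("loss_batches_mnist.pkl", "wb") as f:
    pickle.dump(history, f)

with open("loss_batches_mnist.pkl", "rb") as f:
    loaded = pickle.load(f)
```

From there, the per-epoch curves can be plotted with `matplotlib.pyplot.plot`.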
