This repository contains PyTorch implementations of different types of autoencoders.
Try the interactive demo: https://julienexr-autoencooder-mnist-app-sehuso.streamlit.app/
- `modules/autoencoder/ae.py` -- AE model (encoder + decoder).
- `modules/autoencoder/training.py` -- AE training loop.
- `modules/vae/vae.py` -- VAE model (probabilistic encoder producing μ and logvar + decoder; the reparameterization step is sketched below). Two encoder/decoder variants exist (`default` and `pp`).
- `modules/vae/training.py` -- VAE training loop.
- `modules/cvae/cvae.py` -- CVAE model (VAE conditioned on class labels).
- `modules/cvae/training.py` -- CVAE training loop.
- `modules/vq_vae/vq_vae.py` -- VQ-VAE model (encoder + decoder + vector quantizer).
- `modules/vq_vae/tranformer_prior.py` -- Transformer prior over VQ codebook indices.
- `modules/vq_vae/training.py` -- VQ-VAE + Transformer prior training loops.
- `main.py` -- example entry points. By default it runs AE training then VAE training (see note below).
- `src/visualization.py` -- `Visualizer` helper used by `training.py` to save reconstructions, PCA plots, interpolations and noise samples into `visu/`.
- `src/data.py` -- MNIST dataloader helpers.
- `modules/autoencoder/` -- AE module README and usage notes.
- `modules/vae/` -- VAE module README and usage notes.
- `modules/cvae/` -- CVAE module README and usage notes.
- `modules/vq_vae/` -- VQ-VAE + Transformer prior README and usage notes.
- `models/AE/`, `models/VAE/`, `models/CVAE/`, `models/VQ-VAE/` -- expected checkpoints are saved here (encoder/decoder/codebook state dicts).
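For reference, the μ/logvar output mentioned above feeds the usual VAE reparameterization trick. The sketch below is illustrative only; the function and variable names are not taken from `modules/vae/vae.py`:

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    Illustrative sketch only: the real encoder/decoder classes live in
    modules/vae/vae.py and may name things differently.
    """
    std = torch.exp(0.5 * logvar)   # logvar -> standard deviation
    eps = torch.randn_like(std)     # noise with the same shape as std
    return mu + eps * std
```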
Each model has its own README with full explanations and figures:
- Autoencoder (AE): modules/autoencoder/README.md
- Variational Autoencoder (VAE): modules/vae/README.md
- Conditional VAE (CVAE): modules/cvae/README.md
- VQ-VAE + Transformer Prior: modules/vq_vae/README.md
- Clone the repository
git clone git@github.com:JulienExr/Autoencoder-MNIST.git
(HTTPS: git clone https://github.com/JulienExr/Autoencoder-MNIST.git)
cd Autoencoder-MNIST
- Create and activate a virtual environment:
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
- Prepare data
MNIST will be downloaded automatically by torchvision into ./data when you run training.
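The dataloader helpers in `src/data.py` presumably wrap the standard torchvision dataset call; the sketch below shows what that automatic download looks like (the exact transforms and batch size used in the repo may differ):

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Downloads MNIST into ./data on first use; later runs reuse the cached files.
# datasets.FashionMNIST works the same way for --dataset fashion_mnist.
train_set = datasets.MNIST(
    root="./data", train=True, download=True,
    transform=transforms.ToTensor(),
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
```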
- Run training (CLI)
Via main.py you can choose the model, dataset, and latent dimension:
You can visualize training outputs by adding the --visualize flag.
python main.py --model AE --dataset mnist --latent_dim 256
python main.py --model VAE --dataset mnist --latent_dim 32

Dataset options:
- `mnist` (default)
- `fashion_mnist` (more challenging, grayscale clothing items)
Example with Fashion-MNIST:
python main.py --model VAE --dataset fashion_mnist --latent_dim 128
- Outputs
- Model checkpoints are saved under `models/AE/`, `models/VAE/`, `models/CVAE/`, and `models/VQ-VAE/`.
- Visual outputs are saved under `visu/<dataset>_<model>/` with subfolders `recon`, `pca`, `umap`, `interp`, and `noise`.
- If you want to try the app with your own model saved in `models/*`, use:
streamlit run app.py
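If you want to reuse a checkpoint outside the Streamlit app, the usual PyTorch state-dict pattern applies. The snippet below is a hedged sketch: the checkpoint filename, import path, class name, and constructor arguments are assumptions, so check `modules/autoencoder/ae.py` and the contents of `models/AE/` for the real names.

```python
import torch

# Checkpoint filenames under models/* are illustrative; the repo saves separate
# encoder/decoder (and codebook) state dicts, so pick the one you need.
state_dict = torch.load("models/AE/encoder.pt", map_location="cpu")

# Rebuild the matching module and load the weights. The import path, class name,
# and constructor arguments below are assumptions; see modules/autoencoder/ae.py.
from modules.autoencoder.ae import Encoder  # hypothetical name
encoder = Encoder(latent_dim=256)           # must match the latent_dim used at training time
encoder.load_state_dict(state_dict)
encoder.eval()
```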