Generative models are difficult to train and tune, and there are no exact criteria for success. This project therefore explores a straightforward pipeline for generative modeling, tuning, and benchmarking. Within the purview of this (personal) exploration:
- Build models with different structures for MNIST-style digit image generation.
- Build models of different sizes: a small GAN and a large GAN.
- Build a custom FID benchmark for MNIST-style image generation (see the sketch after this list).
- Do all of this in JAX.
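A custom FID for MNIST works like the standard one, except the Inception network is swapped for features from a small MNIST classifier. Below is a minimal sketch of the distance itself, assuming feature matrices for real and generated images have already been extracted by some classifier (not shown here); the matrix square root runs on the host via SciPy, since it is not a JAX op:

```python
import jax.numpy as jnp
import numpy as np
from scipy import linalg


def activation_stats(feats):
    """Mean and covariance of a (num_samples, feat_dim) feature matrix."""
    mu = jnp.mean(feats, axis=0)
    cov = jnp.cov(feats, rowvar=False)
    return mu, cov


def fid(real_feats, fake_feats):
    """Frechet distance between Gaussians fit to real and generated features."""
    mu1, cov1 = activation_stats(real_feats)
    mu2, cov2 = activation_stats(fake_feats)
    # scipy's sqrtm is not jittable, so finish the computation in NumPy.
    cov1, cov2 = np.asarray(cov1), np.asarray(cov2)
    covmean = linalg.sqrtm(cov1 @ cov2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    diff = np.asarray(mu1 - mu2)
    return float(diff @ diff + np.trace(cov1 + cov2 - 2.0 * covmean))
```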
Generated samples:
Training losses:
Influenced by this implementation of score-based diffusion, with modifications for personal hardware, the score function, hyperparameters, and some minor structural changes.
Quoting:
- Uses the variance-preserving SDE to corrupt the data: $y(0) \sim \mathrm{data}, \quad dy(t) = -\tfrac{1}{2}\beta(t)\,y(t)\,dt + \sqrt{\beta(t)}\,dw(t)$
- Trains a score model $s_\theta$ according to the denoising objective:
$\arg\min_\theta \mathbb{E}_{t \sim \mathrm{Uniform}[0, T]}\,\mathbb{E}_{y(0) \sim \mathrm{data}}\,\mathbb{E}_{(y(t)|y(0)) \sim \mathrm{SDE}}\; \lambda(t) \left\| s_\theta(t, y(t)) - \nabla_y \log p(y(t)|y(0)) \right\|_2^2$
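Under the VP SDE the conditional $p(y(t)|y(0))$ is Gaussian, so the objective above amounts to predicting a rescaling of the noise added to each sample. A minimal sketch of that loss in JAX, assuming a linear $\beta(t)$ schedule and a generic `score_model(params, t, y)` call signature (both illustrative, not necessarily what this implementation uses):

```python
import jax
import jax.numpy as jnp


def int_beta(t, beta_min=0.1, beta_max=20.0):
    """Integral of an assumed linear beta(t) schedule from 0 to t."""
    return beta_min * t + 0.5 * (beta_max - beta_min) * t**2


def dsm_loss(params, score_model, y0, key, t1=1.0):
    """Denoising score-matching loss for the VP SDE, for a single sample y0."""
    t_key, noise_key = jax.random.split(key)
    t = jax.random.uniform(t_key, (), minval=1e-5, maxval=t1)
    # p(y(t) | y(0)) is Gaussian under the VP SDE.
    mean = y0 * jnp.exp(-0.5 * int_beta(t))
    var = jnp.maximum(1.0 - jnp.exp(-int_beta(t)), 1e-5)
    std = jnp.sqrt(var)
    noise = jax.random.normal(noise_key, y0.shape)
    yt = mean + std * noise
    # grad_y log p(y(t) | y(0)) = -(y(t) - mean) / var = -noise / std
    target = -noise / std
    pred = score_model(params, t, yt)
    # lambda(t) = var is a common weighting choice.
    return var * jnp.mean((pred - target) ** 2)
```

`jax.vmap`-ing this over a batch of images and keys, then averaging, gives the per-batch training loss.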
Generated samples:
Training losses:
Inspired by this implementation, but with the structure modified for 28 x 28 MNIST instead of 64 x 64, along with changes to the training process and model structure (e.g. activations).
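Going from 64 x 64 to 28 x 28 mostly changes the shape arithmetic of the generator: start from a 7 x 7 feature map and upsample twice (7 → 14 → 28). A sketch of that layout, assuming a flax.linen-style module; the widths, kernel sizes, and activations here are illustrative rather than the exact ones used:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Generator28(nn.Module):
    """DCGAN-style generator resized for 28 x 28 MNIST (illustrative)."""

    @nn.compact
    def __call__(self, z):
        x = nn.Dense(7 * 7 * 128)(z)
        x = nn.relu(x)
        x = x.reshape((z.shape[0], 7, 7, 128))  # NHWC, start at 7 x 7
        x = nn.ConvTranspose(64, (4, 4), strides=(2, 2), padding="SAME")(x)
        x = nn.relu(x)  # 14 x 14
        x = nn.ConvTranspose(1, (4, 4), strides=(2, 2), padding="SAME")(x)
        return jnp.tanh(x)  # 28 x 28 x 1, scaled to [-1, 1]


# Quick shape check with a batch of 8 latent vectors of size 64.
params = Generator28().init(jax.random.PRNGKey(0), jnp.zeros((8, 64)))
imgs = Generator28().apply(params, jnp.zeros((8, 64)))  # (8, 28, 28, 1)
```

With `padding="SAME"` and stride 2, each `ConvTranspose` doubles the spatial size, which is what turns the original 4 → 8 → 16 → 32 → 64 progression into 7 → 14 → 28.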
Generated samples:
Training losses:
A small MLP-GAN used as a basic control (and to analyze the training pathologies of GANs in a smaller setting). Based on my own PyTorch implementation.
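A minimal sketch of what the alternating update for such an MLP-GAN can look like in pure JAX, using the non-saturating generator loss; the network sizes, plain-SGD update, and parameter layout are illustrative rather than the exact ones used:

```python
import jax
import jax.numpy as jnp


def mlp(params, x):
    """Small MLP; `params` is a list of (weight, bias) pairs."""
    for w, b in params[:-1]:
        x = jax.nn.leaky_relu(x @ w + b, 0.2)
    w, b = params[-1]
    return x @ w + b


def d_loss(d_params, g_params, real, z):
    """Discriminator loss: push logits up on real, down on generated samples."""
    fake = mlp(g_params, z)
    real_logit = mlp(d_params, real)
    fake_logit = mlp(d_params, fake)
    return jnp.mean(jax.nn.softplus(-real_logit) + jax.nn.softplus(fake_logit))


def g_loss(g_params, d_params, z):
    """Non-saturating generator loss: raise the discriminator's logit on fakes."""
    fake = mlp(g_params, z)
    return jnp.mean(jax.nn.softplus(-mlp(d_params, fake)))


@jax.jit
def train_step(d_params, g_params, real, z, lr=2e-4):
    """One simultaneous discriminator/generator update with plain SGD."""
    d_grads = jax.grad(d_loss)(d_params, g_params, real, z)
    g_grads = jax.grad(g_loss)(g_params, d_params, z)
    sgd = lambda p, g: jax.tree_util.tree_map(lambda a, b: a - lr * b, p, g)
    return sgd(d_params, d_grads), sgd(g_params, g_grads)
```

In a small setting like this, pathologies such as mode collapse or a dominating discriminator are easier to read off the per-step `d_loss` and `g_loss` curves.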