andrinr/jax-diffusion

Image diffusion

This notebook implements a denoising diffusion probabilistic model (DDPM) for generating abstract art paintings.

The noise prediction model is an attention UNet trained on 512 images of size 64x64. We split the data into 32 batches of 16 images each and train for 300 epochs using the Adam optimizer with a learning rate of 1e-4. The denoising diffusion process uses 1000 timesteps.
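As a rough illustration of the setup above, the following sketch builds a 1000-step noise schedule, reshapes 512 images into 32 batches of 16, and noises one batch in closed form. The linear beta schedule, its endpoints, the 3-channel layout, and the helper name `q_sample` are assumptions for illustration; the notebook itself may differ, and random placeholder data stands in for the paintings.

```python
import jax
import jax.numpy as jnp

T = 1000  # number of diffusion timesteps (from the text)

# Assumption: linear beta schedule as in the original DDPM paper.
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)  # cumulative product of (1 - beta)


def q_sample(x0, t, noise):
    """Noise clean images x0 to per-sample timesteps t in closed form."""
    a = jnp.sqrt(alpha_bars[t])[:, None, None, None]
    s = jnp.sqrt(1.0 - alpha_bars[t])[:, None, None, None]
    return a * x0 + s * noise


key = jax.random.PRNGKey(0)
k_data, k_perm, k_t, k_noise = jax.random.split(key, 4)

# Placeholder data standing in for the 512 training images
# (3 channels and a [-1, 1] value range are assumed).
images = jax.random.uniform(k_data, (512, 3, 64, 64), minval=-1.0, maxval=1.0)

# Shuffle, then reshape into 32 batches of 16 images each.
perm = jax.random.permutation(k_perm, images.shape[0])
batches = images[perm].reshape(32, 16, 3, 64, 64)

# Noise the first batch at randomly drawn timesteps.
x0 = batches[0]
t = jax.random.randint(k_t, (16,), 0, T)
noise = jax.random.normal(k_noise, x0.shape)
x_t = q_sample(x0, t, noise)
```

The closed form avoids iterating through all earlier timesteps, which is what makes training on randomly drawn timesteps cheap.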

When sampling from random noise (there is no conditioning), the results look as follows:

Random Noise Samples

The .gif depicts 5 different samples stacked horizontally. In the following image we observe how the model learned to predict the noise:

Denoising Process

The noise sampling is shown in the plot below:

Noise Sampling
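For readers unfamiliar with how unconditional samples are drawn, here is a minimal sketch of DDPM ancestral sampling, starting from pure Gaussian noise and denoising over 1000 steps. The linear beta schedule, the choice of variance `sigma_t^2 = beta_t`, and the names `dummy_eps_model`, `p_sample_step`, and `sample` are assumptions; the dummy model is a hypothetical stand-in for the trained attention UNet and does not produce meaningful images.

```python
import jax
import jax.numpy as jnp

T = 1000
betas = jnp.linspace(1e-4, 0.02, T)  # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)


def dummy_eps_model(x, t):
    """Hypothetical stand-in for the trained noise prediction network."""
    return jnp.zeros_like(x)


def p_sample_step(x, t, key, eps_model):
    """One reverse step x_t -> x_{t-1} given the predicted noise."""
    eps = eps_model(x, t)
    coef = betas[t] / jnp.sqrt(1.0 - alpha_bars[t])
    mean = (x - coef * eps) / jnp.sqrt(alphas[t])
    z = jax.random.normal(key, x.shape)
    # No noise is added at the final step (t == 0).
    sigma = jnp.where(t > 0, jnp.sqrt(betas[t]), 0.0)
    return mean + sigma * z


def sample(key, shape, eps_model=dummy_eps_model):
    """Run the full reverse process from t = T - 1 down to t = 0."""
    key, k_init = jax.random.split(key)
    x = jax.random.normal(k_init, shape)  # start from pure noise
    keys = jax.random.split(key, T)

    def body(i, x):
        t = T - 1 - i
        return p_sample_step(x, t, keys[i], eps_model)

    return jax.lax.fori_loop(0, T, body, x)


# Five samples, matching the five panels in the .gif (3 channels assumed).
samples = sample(jax.random.PRNGKey(0), (5, 3, 64, 64))
```

Replacing `dummy_eps_model` with a trained network that takes `(x, t)` would turn this loop into an actual generator.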

Finally, a plot of the training loss over epochs is shown below:

Training Loss
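The quantity tracked in such a plot is typically the simplified DDPM objective: the mean squared error between the true noise and the predicted noise. The sketch below assumes that objective and a linear beta schedule; `toy_eps_model` and `ddpm_loss` are hypothetical names, and the toy affine model only stands in for the UNet so that the example runs on its own.

```python
import jax
import jax.numpy as jnp

T = 1000
# Assumed linear beta schedule.
alpha_bars = jnp.cumprod(1.0 - jnp.linspace(1e-4, 0.02, T))


def toy_eps_model(params, x_t, t):
    """Hypothetical per-pixel affine model standing in for the UNet."""
    return params["w"] * x_t + params["b"]


def ddpm_loss(params, x0, key):
    """MSE between the sampled noise and the model's noise prediction."""
    k_t, k_eps = jax.random.split(key)
    t = jax.random.randint(k_t, (x0.shape[0],), 0, T)
    eps = jax.random.normal(k_eps, x0.shape)
    a = jnp.sqrt(alpha_bars[t])[:, None, None, None]
    s = jnp.sqrt(1.0 - alpha_bars[t])[:, None, None, None]
    x_t = a * x0 + s * eps  # forward-noised batch
    return jnp.mean((toy_eps_model(params, x_t, t) - eps) ** 2)


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
x0 = jnp.zeros((16, 3, 64, 64))  # placeholder batch of 16 images
loss, grads = jax.value_and_grad(ddpm_loss)(params, x0, jax.random.PRNGKey(0))
```

In a full training loop, `grads` would be passed to the optimizer (Adam in this notebook) to update the parameters once per batch.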

The code for the UNet architecture is adapted from here. We use Equinox as the neural network library, Optax for optimization, and JAX for automatic differentiation. Training takes around 10 minutes on a laptop with an NVIDIA RTX 4070 GPU with 8 GB of VRAM.
