A PyTorch implementation of a diffusion model for generating MNIST handwritten digits using the DeepInv library. This project is currently in development and serves as a learning exercise for understanding diffusion models.
Work in Progress - This project is actively being developed and is not yet complete.
- [x] Basic U-Net architecture using DeepInv's DiffUNet
- [x] MNIST data loading and preprocessing
- [x] Diffusion process setup (forward noising)
- [x] Training loop implementation
- [x] Model saving functionality
- [ ] Image generation from trained model (in progress)
- [ ] Sampling and denoising process
- [ ] Result visualization and evaluation
```
diffusion-mnist-pytorch/
├── belajardiffusion.ipynb   # Jupyter notebook (main development)
├── trainingdiffusion.py     # Marimo app version (unfinished)
├── README.md                # This file
└── data/                    # MNIST dataset (auto-downloaded)
```
```bash
pip install torch torchvision deepinv marimo matplotlib numpy
```
- PyTorch: Deep learning framework
- DeepInv: Computer vision library with pre-built diffusion models
- Marimo: Interactive notebook environment (for .py version)
- Torchvision: For MNIST dataset and transforms
Open and run the notebook for interactive development:
```bash
jupyter notebook belajardiffusion.ipynb
```
Run the Marimo version (incomplete):
```bash
marimo run trainingdiffusion.py
```
- Architecture: DiffUNet from the DeepInv library
- Input/Output: 1 channel (grayscale)
- Image Size: 32x32 (resized from 28x28)
- Batch Size: 48 (notebook) / 64 (marimo)
- Timesteps: 1000
- Beta Schedule: Linear from 1e-4 to 0.02
- Loss Function: MSE between predicted and actual noise
- Optimizer: Adam (lr=1e-4)
- Epochs: 5 (notebook) / 10 (marimo planned)
- Device: CUDA if available, else CPU
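The following is a minimal sketch of how this configuration can be set up in PyTorch. The variable names (`betas`, `sqrt_alphas_cumprod`, etc.) are chosen to match the forward-noising snippet further below; the exact code in the notebook may differ, and the normalization of the images to [-1, 1] is an assumption rather than something stated in this README.

```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"

# MNIST resized from 28x28 to 32x32; scaling to [-1, 1] is assumed
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(root="data", train=True, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=48, shuffle=True)

# Linear beta schedule: 1000 timesteps from 1e-4 to 0.02
T = 1000
betas = torch.linspace(1e-4, 0.02, T, device=device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Precomputed terms used by the forward (noising) process
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
```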
The forward process adds Gaussian noise to MNIST images over 1000 timesteps:
```python
# q(x_t | x_0): mix the clean image with Gaussian noise at timestep t
noisy_imgs = (
    sqrt_alphas_cumprod[t, None, None, None] * imgs +
    sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
)
```
The model learns to predict the noise added at each timestep (a minimal training-loop sketch follows the list below):
- Sample a random timestep `t`
- Add the corresponding noise level to the clean image
- Train U-Net to predict the added noise
- Minimize MSE loss between predicted and actual noise
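A minimal sketch of that loop, reusing the schedule tensors and data loader from the configuration sketch above. It assumes a DeepInv `DiffUNet` instance named `model` that takes the noisy batch and the timestep; check the DeepInv documentation for the exact forward signature, as the call below is an assumption.

```python
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(5):
    for imgs, _ in loader:                        # labels are unused
        imgs = imgs.to(device)
        # 1. Sample a random timestep per image
        t = torch.randint(0, T, (imgs.shape[0],), device=device)
        # 2. Add the corresponding noise level to the clean images
        noise = torch.randn_like(imgs)
        noisy_imgs = (
            sqrt_alphas_cumprod[t, None, None, None] * imgs +
            sqrt_one_minus_alphas_cumprod[t, None, None, None] * noise
        )
        # 3. Predict the added noise with the U-Net (call signature assumed)
        pred_noise = model(noisy_imgs, t)
        # 4. Minimize MSE between predicted and actual noise
        loss = F.mse_loss(pred_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```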
- Complete the reverse diffusion process for image generation
- Implement the sampling algorithm to generate new digits (a DDPM-style sketch appears after this list)
- Add result visualization to see generated samples
- Finish the Marimo app version with complete functionality
- Add classifier-free guidance for conditional generation
- Implement different noise schedules (cosine, etc.)
- Add FID/IS metrics for evaluation
- Experiment with different U-Net architectures
- Add interpolation between digits
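For reference, a hedged sketch of a DDPM ancestral sampling loop that would cover the first two items above. It reuses the schedule tensors defined earlier and assumes the same `model(x, t)` call as in the training sketch; it is not the project's implementation, which does not exist yet.

```python
@torch.no_grad()
def sample(model, n=16):
    # Start from pure Gaussian noise and denoise step by step
    x = torch.randn(n, 1, 32, 32, device=device)
    for i in reversed(range(T)):
        t = torch.full((n,), i, device=device, dtype=torch.long)
        pred_noise = model(x, t)
        alpha = alphas[i]
        alpha_bar = alphas_cumprod[i]
        # DDPM posterior mean: remove the predicted noise contribution
        x = (x - (1 - alpha) / torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha)
        if i > 0:
            # Add fresh noise for every step except the last
            x = x + torch.sqrt(betas[i]) * torch.randn_like(x)
    return x
```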
This project is based on understanding:
- Denoising Diffusion Probabilistic Models (DDPM)
- DeepInv Documentation
- Diffusion model fundamentals and implementation
- No generation capability yet - can only train the model
- Missing reverse process - need to implement sampling
- No evaluation metrics - need to add quality assessment
- Marimo version incomplete - training code is commented out
Using DeepInv's DiffUNet provides:
- Pre-built, tested U-Net architecture
- Proper time embedding handling
- Simplified model setup for learning purposes
This project emphasizes understanding:
- Forward and reverse diffusion processes
- Noise prediction training paradigm
- U-Net architecture for diffusion
- PyTorch implementation details
This is a personal learning project, but suggestions and improvements are welcome!
Educational/Learning project - feel free to use and modify.