# Latent Diffusion MNIST Experiment

Latent Diffusion MNIST Experiment is a Python project exploring the implementation of a latent diffusion model using a variational autoencoder (VAE) and a conditional U-Net. It currently targets the MNIST dataset, but the architecture and training pipeline are intentionally dataset-agnostic and can scale to larger and more complex image domains.

## Features
- **Modular Architecture** – separate training routines for the VAE and U-Net components, built on top of Hugging Face Diffusers.
- **Dataset-Agnostic Pipeline** – although the repository demonstrates MNIST, the data loaders and model design can be extended to other datasets with minimal changes, retaining most of the training logic from this experiment.
- **Configurable Training** – hyperparameters (batch size, learning rates, epochs, etc.) are managed via `config.yaml` for reproducible experiments.
- **Efficient & Stable Training** – leverages `Accelerate` for device management and distributed training, along with cosine learning rate schedules, EMA weight tracking, and gradient clipping.
- **Visualization Utilities** – automatic saving of reconstruction and generation plots for monitoring model performance.
- **Gradio App** – ready-to-deploy web app for interactive predictions. Hosted on [Hugging Face Spaces](https://huggingface.co/spaces/codinglabsong/aging-gan).
- **Developer Tools & CI** – linting with ruff and black, unit tests with pytest, and end-to-end smoke tests in GitHub Actions.
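
The EMA tracking mentioned above keeps a lagged, smoothed copy of the model weights, which is what later gets used for sampling. The project applies this to PyTorch tensors; purely as an illustration of the update rule, here is a minimal framework-free sketch (class and parameter names are hypothetical, not the repository's API):

```python
class EMATracker:
    """Exponential moving average of model parameters (illustrative sketch)."""

    def __init__(self, params: dict[str, float], decay: float = 0.999):
        self.decay = decay
        # Shadow copy that lags behind the live parameters.
        self.shadow = dict(params)

    def update(self, params: dict[str, float]) -> None:
        # shadow <- decay * shadow + (1 - decay) * param
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value


ema = EMATracker({"w": 0.0}, decay=0.9)
ema.update({"w": 1.0})  # shadow: 0.9 * 0.0 + 0.1 * 1.0 = 0.1
ema.update({"w": 1.0})  # shadow: 0.9 * 0.1 + 0.1 * 1.0 = 0.19
```

A high decay (e.g. `0.999`) means the shadow weights change slowly, averaging out the noise of individual gradient steps.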

## Installation

1. Clone this repository and install the core dependencies:

   ```bash
   pip install -r requirements.txt
   ```

2. (Optional) Install development tools for linting and testing:

   ```bash
   pip install -r requirements-dev.txt
   pre-commit install
   ```

3. Install the package itself:

   ```bash
   pip install -e .
   ```

## Training

Use the provided helper script to train both the VAE and U-Net components:

```bash
bash scripts/run_train.sh
```

Configuration values can be adjusted in `config.yaml`.
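
The cosine learning-rate schedule used during training decays the learning rate smoothly from its peak toward zero over the run. The project relies on scheduler utilities from its training stack; for intuition only, a standalone sketch of that decay curve (the warmup option and all names here are assumptions, not the repository's API):

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float, warmup_steps: int = 0) -> float:
    """Linear warmup followed by cosine decay toward zero (illustrative)."""
    if warmup_steps and step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


print(cosine_lr(0, 1000, 1e-4))     # peak learning rate at the start
print(cosine_lr(1000, 1000, 1e-4))  # decays to ~0 by the final step
```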

## Inference

Generate samples with the trained models:

```bash
bash scripts/run_inference.sh
```

The script loads the VAE and EMA-smoothed U-Net weights and writes images to `plots/unet/`. The images saved for the final epoch are the generated samples for inference.

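Under the hood, inference repeatedly denoises a latent using the U-Net's noise prediction; the project delegates this loop to a Diffusers scheduler. Purely for intuition, here is a NumPy sketch of a single DDPM-style denoising step with a simplified variance choice (the toy latent shape and all names are illustrative, not the repository's code):

```python
import numpy as np


def ddpm_step(x_t, eps_pred, alpha_t, alpha_bar_t, rng):
    """One reverse-diffusion step: estimate x_{t-1} from x_t and predicted noise (sketch)."""
    # Posterior mean: 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps)
    mean = (x_t - (1 - alpha_t) / np.sqrt(1 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_t)
    sigma = np.sqrt(1 - alpha_t)  # simplified variance choice for illustration
    return mean + sigma * rng.standard_normal(x_t.shape)


rng = np.random.default_rng(0)
latent = rng.standard_normal((1, 4, 8, 8))  # toy latent; shape is illustrative
eps = rng.standard_normal(latent.shape)     # stand-in for the U-Net's noise prediction
latent = ddpm_step(latent, eps, alpha_t=0.98, alpha_bar_t=0.5, rng=rng)
print(latent.shape)  # the latent keeps its shape at every step
```

After the final step, the VAE decoder maps the denoised latent back to image space.
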
## Data Preprocessing

- Convert images to tensors and normalize to `[-1, 1]`.
- Resize images to `config.img_size` (default `32x32`).
- Additional preprocessing steps can be added in `ldm/data.py` when targeting other datasets.

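The steps above can be sketched without the training stack. MNIST digits are 28x28, so reaching the default 32x32 requires resizing; this NumPy sketch substitutes simple zero-padding (with the background value `-1`) for a real resize, so treat it as illustration only, not the pipeline's actual transforms:

```python
import numpy as np


def preprocess(img: np.ndarray, target: int = 32) -> np.ndarray:
    """Normalize a uint8 grayscale image to [-1, 1] and pad it to target x target."""
    x = img.astype(np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    pad = target - x.shape[0]
    before, after = pad // 2, pad - pad // 2
    return np.pad(x, ((before, after), (before, after)), constant_values=-1.0)


digit = np.random.default_rng(0).integers(0, 256, size=(28, 28), dtype=np.uint8)
out = preprocess(digit)
print(out.shape)  # padded up to the model's expected input size
```
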
## Results
### Example Outputs

*Placeholder for generated image examples.*

### Considerations for Improvements

- Train on higher-resolution or more diverse datasets.
- Experiment with deeper VAE/U-Net architectures or additional conditioning signals.
- Integrate more advanced schedulers or guidance techniques.

## Running the Gradio Inference App

This project includes an interactive Gradio app for making predictions with the trained model.

1. **Obtain the Trained Models:**
   - Ensure that the trained weight files (`models/vae.pth` and `models/ema-unet.pth`) are available in the project root.
   - If you trained the models yourself, they are saved automatically in the project root.
   - Otherwise, you can download them from [Releases](https://github.com/codinglabsong/aging-gan/releases/tag/v1.0.0) and place them in the project root.

2. **Run the App Locally:**
   ```bash
   python app.py
   ```
   - Visit the printed URL (e.g., `http://127.0.0.1:7860`) to interact with the model.

> You can also access the hosted demo on [Hugging Face Spaces](https://huggingface.co/spaces/codinglabsong/aging-gan).

## Testing

Run the unit test suite with:

```bash
pytest
```

## Repository Structure

- `src/ldm/` – source code for models, training, inference, and utilities.
- `scripts/` – shell scripts for training and inference.
- `tests/` – unit tests verifying data loaders, models, and utilities.
- `config.yaml` – experiment configuration.
- `ETHICS.md`, `MODEL_CARD.md` – documentation about ethical considerations and model details.

## Requirements

- Python >= 3.10
- PyTorch, diffusers, and other dependencies specified in `requirements.txt`

## Contributing

Contributions are welcome! Please open an issue or submit a pull request.

## Acknowledgements

- [Hugging Face Diffusers](https://github.com/huggingface/diffusers)
- [MNIST Dataset](http://yann.lecun.com/exdb/mnist/)
- [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) (research paper)

## License

This project is licensed under the [MIT License](LICENSE).