Ever wanted fast diffusion on device? Struggled with compatibility and libraries? Worry no more: favicon diffusor is here! Built on WebGPU, it's supported on almost any device that can run Chrome, and it can diffuse hippos anywhere (even as a favicon)!
A quick weekend project where I wrote a bunch of WebGPU kernels from scratch and tried to optimize them. Building on my last "from scratch DiT" post, this one starts at the kernel level and rewrites the diffusion transformer in WGSL. Sub-second 32-step inference makes for an awesome demo: actually diffusing a website's favicon in real time, in ~0.7 s, with an ~11M-parameter model.
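To make the "32-step inference" concrete, here is a minimal sketch of what a sampling loop like this looks like. Everything here is illustrative: `denoise` stands in for the real WGSL-backed transformer, and the schedule is a toy one, not what the repo actually uses.

```javascript
// Hypothetical sketch of a 32-step denoising loop (NOT the repo's real code).
// `denoise(x, t)` is a placeholder for the WGSL-backed model's noise prediction.
function sampleFavicon(denoise, { steps = 32, size = 16 * 16 * 4 } = {}) {
  // Start from random noise (uniform here for brevity, not Gaussian).
  let x = Float32Array.from({ length: size }, () => Math.random() * 2 - 1);
  for (let t = steps - 1; t >= 0; t--) {
    const eps = denoise(x, t);        // predicted noise at timestep t
    const alpha = 1 - t / steps;      // toy linear schedule, for illustration only
    x = x.map((v, i) => v - (1 - alpha) * eps[i]); // step toward the clean image
  }
  return x; // flat RGBA pixel buffer for the favicon canvas
}
```

At 32 steps, the whole loop is just 32 model calls, which is why per-kernel latency dominates the total inference time.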
https://notebook.neelr.dev/stories/in-browser-favicon-diffusion-scratch-dit-pt-2
Of course, benchmarks! Here are the approximate numbers on an M1 Pro. It's currently faster than TensorFlow.js and (of course) baseline JS; transformers.js doesn't support custom layer building, so I didn't include it.
| Implementation | Time (s) | vs Baseline | vs TensorFlow.js |
|---|---|---|---|
| Favicon Diffusor | 0.86 | 88.6% faster | 45.2% faster |
| TensorFlow.js | 1.57 | 79.3% faster | baseline |
| Baseline JS | 7.57 | baseline | 382% slower |
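For context, numbers like these are typically collected with a small timing harness. The sketch below is illustrative only (the repo's actual benchmark code lives in its testing files); the warmup/run counts are arbitrary choices, not the ones used for the table.

```javascript
// Minimal timing harness sketch for async inference calls (illustrative).
// Warmup runs let JIT compilation and GPU pipeline creation settle before measuring.
async function timeInference(runOnce, { warmup = 2, runs = 5 } = {}) {
  for (let i = 0; i < warmup; i++) await runOnce();
  const times = [];
  for (let i = 0; i < runs; i++) {
    const t0 = performance.now();
    await runOnce();
    times.push(performance.now() - t0);
  }
  return times.reduce((a, b) => a + b, 0) / runs; // mean ms per run
}
```

Averaging over several runs matters especially for WebGPU, where the first dispatch pays one-time shader compilation costs.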
The implementation includes several key optimizations:
- Custom WGSL shaders for core operations
- Efficient memory management and tensor operations
- Optimized attention mechanisms
- Streamlined data pipelining
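To give a feel for what "custom WGSL shaders for core operations" means, here is a naive matmul kernel as it would be embedded in JavaScript. This is a simplified sketch, not one of the repo's actual shaders (those are optimized, e.g. tiled); only the overall shape is representative.

```javascript
// Illustrative WGSL compute kernel: naive matmul C = A × B with A (M×K), B (K×N).
// The real shaders in shaders/ are optimized; this shows the basic structure only.
const matmulWGSL = /* wgsl */ `
@group(0) @binding(0) var<storage, read> A : array<f32>;
@group(0) @binding(1) var<storage, read> B : array<f32>;
@group(0) @binding(2) var<storage, read_write> C : array<f32>;
@group(0) @binding(3) var<uniform> dims : vec3<u32>; // M, K, N

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let M = dims.x; let K = dims.y; let N = dims.z;
  if (gid.x >= N || gid.y >= M) { return; } // guard against over-dispatch
  var acc = 0.0;
  for (var k = 0u; k < K; k = k + 1u) {
    acc = acc + A[gid.y * K + k] * B[k * N + gid.x];
  }
  C[gid.y * N + gid.x] = acc;
}`;
```

Each invocation computes one output element; a tiled version would stage blocks of A and B in workgroup memory to cut global-memory traffic, which is where most of the speedup over naive code comes from.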
- You need a browser with WebGPU support (Chrome Canary, or another modern browser with WebGPU flags enabled)
- Clone the repository:

```bash
git clone https://github.com/neelr/favicon-diffusor.git
cd favicon-diffusor
```

- Run a development server:

```bash
npx http-server
```
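If you're unsure whether your browser qualifies, a quick check in the devtools console tells you. This is a generic WebGPU feature probe, not code from the repo:

```javascript
// Probe for usable WebGPU: the API must be exposed AND an adapter must be available.
// `navigator.gpu` is the standard WebGPU entry point in browsers.
async function webgpuAvailable(nav = globalThis.navigator) {
  if (!nav || !nav.gpu) return false;              // API not exposed at all
  const adapter = await nav.gpu.requestAdapter();  // resolves to null if no usable GPU
  return adapter !== null;
}
```

Note that `navigator.gpu` can exist while `requestAdapter()` still returns `null` (e.g. blocklisted drivers), so checking both is worthwhile.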
The project structure includes:

- `dit.py` - PyTorch reference implementation
- `dit.js` - JavaScript implementation
- `shaders/` - WebGPU shader implementations
- `train.py` - Training scripts
- `compile.sh` - Compiles the shaders into a single file
- Various utility and testing files
Open to contributions! `dit.js` and `shaders/shaders.js` are the only files you really need for the demo; the rest are just for training and testing. Those two combined are only ~2k lines of code.
- implement patchify and unpatchify as shaders
- modularize all shaders into separate files
- create benchmarks against other relevant implementations
- add transpose matmul optimization
- implement FlashAttention from scratch
- implement multi-head attention
- try implementing a "next scale prediction" VAR https://arxiv.org/abs/2404.02905
- port over a full stable diffusion checkpoint
- add text latents + possibly conditioning?
- create an easy porting script