Diffusion transformers with Geometric Transform Attention (GTA). This codebase is built on fast-DiT and DiT; we thank the authors of both for their open-source contributions.
First, download and set up the repo:

```bash
git clone git@github.com:autonomousvision/gta.git
cd gta
git checkout DiT
```
We provide an `environment.yml` file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
```bash
conda env create -f environment.yml
conda activate DiT
```
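To confirm the environment works (and, on a CPU-only install, that the absence of CUDA is expected), a quick sanity check helps; this is a minimal sketch, not part of the repo:

```python
# Minimal environment sanity check (not part of the repo).
import torch

print("PyTorch:", torch.__version__)
# False is expected on a CPU-only setup.
print("CUDA available:", torch.cuda.is_available())
```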
Pre-trained DiT + GTA checkpoints can be used with the sampling and evaluation scripts below (pass them via `--ckpt`).
To extract ImageNet features with 1 GPU on one node:

```bash
torchrun --nnodes=1 --nproc_per_node=1 extract_features.py --model DiT-B/2 --data-path /path/to/imagenet/train --features-path /path/to/store/features
```
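If you want to inspect the extracted features, the sketch below assumes fast-DiT's convention of one `.npy` file per sample under `imagenet256_features/` and `imagenet256_labels/`; verify the exact layout against `extract_features.py`:

```python
# Sketch for inspecting extracted features; assumes fast-DiT's layout of
# one .npy file per sample (verify against extract_features.py).
import numpy as np

feat = np.load("/path/to/store/features/imagenet256_features/0.npy")
label = np.load("/path/to/store/features/imagenet256_labels/0.npy")
# VAE latents for 256x256 images are typically shaped (1, 4, 32, 32).
print(feat.shape, label.shape)
```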
We provide a training script for DiT in `train.py`. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning.
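As one concrete example of "other types of conditioning", the sketch below swaps the class-label embedder for a small MLP over a continuous vector. `VectorEmbedder`, the 16-dim input, and the 768 hidden size are hypothetical names and values, not code from this repo:

```python
# Hypothetical continuous-conditioning embedder; a sketch of how the
# class-label path in DiT could be replaced, not code from this repo.
import torch
import torch.nn as nn

class VectorEmbedder(nn.Module):
    """Embeds a continuous conditioning vector in place of a class label."""
    def __init__(self, cond_dim: int, hidden_size: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(cond_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        return self.mlp(cond)

# DiT adds the conditioning embedding to the timestep embedding before
# the transformer blocks (c = t_emb + y_emb); a swapped-in embedder
# would slot into that same sum.
embedder = VectorEmbedder(cond_dim=16, hidden_size=768)
c = embedder(torch.randn(4, 16))
print(c.shape)  # torch.Size([4, 768])
```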
To launch DiT-B/2 (256x256) training with `N` GPUs on one node:
```bash
# DiT + GTA
accelerate launch train.py --multi_gpu --num_processes N --model DiT-B/2 --features-path /path/to/store/features --posenc=gta --image-size=256 --results-dir=outputs/GTA --epochs=500 --ckpt-every=250000
# DiT + RoPE
accelerate launch train.py --multi_gpu --num_processes N --model DiT-B/2 --features-path /path/to/store/features --posenc=rope --image-size=256 --results-dir=outputs/RoPE --epochs=500 --ckpt-every=250000
```
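The `--posenc` flag selects the positional mechanism. As a rough single-head intuition (a minimal sketch, not the repo's implementation, which uses 2D image coordinates and per-head parameterizations; see the GTA paper for the exact definition): RoPE applies position-dependent rotations to queries and keys so that attention logits depend only on relative positions, while GTA additionally applies the transforms to the values and inverts them on the output, so the value pathway is transformed relatively as well:

```python
# Minimal single-head sketch contrasting RoPE- and GTA-style attention
# on a 1D token sequence. A simplification for intuition only.
import torch

def rotation(pos, dim, base=10000.0):
    """RoPE-style per-token rotation angles for consecutive feature pairs."""
    freqs = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    ang = pos[:, None] * freqs[None, :]  # (N, dim/2)
    return torch.cos(ang), torch.sin(ang)

def rotate(x, cos, sin):
    """Rotate consecutive feature pairs of x (N, dim) by the given angles."""
    x1, x2 = x[:, 0::2], x[:, 1::2]
    return torch.stack((x1 * cos - x2 * sin,
                        x1 * sin + x2 * cos), dim=-1).flatten(1)

def attention(q, k, v):
    att = torch.softmax(q @ k.T / q.shape[-1] ** 0.5, dim=-1)
    return att @ v

def rope_attention(q, k, v, cos, sin):
    # RoPE: rotate queries and keys only, so logits depend on relative
    # positions; values are left untouched.
    return attention(rotate(q, cos, sin), rotate(k, cos, sin), v)

def gta_attention(q, k, v, cos, sin):
    # GTA (simplified): also rotate the values, then rotate the output
    # back (rotation by -theta inverts), so aggregated values are
    # expressed relative to each query token's frame.
    out = attention(rotate(q, cos, sin), rotate(k, cos, sin),
                    rotate(v, cos, sin))
    return rotate(out, cos, -sin)

N, d = 16, 64
q, k, v = (torch.randn(N, d) for _ in range(3))
cos, sin = rotation(torch.arange(N, dtype=torch.float32), d)
print(rope_attention(q, k, v, cos, sin).shape)  # torch.Size([16, 64])
print(gta_attention(q, k, v, cos, sin).shape)   # torch.Size([16, 64])
```

In this sketch both mechanisms share the same rotations; the only difference is whether the values and outputs are transformed too.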
| DiT Model | Train Steps | FID-50K |
|---|---|---|
| B/2 (the original DiT) | 200K | 31.93 |
| B/2 + RoPE | 200K | 25.71 |
| B/2 + GTA | 200K | 25.15 (2.2% relative improvement over RoPE) |
| B/2 (the original DiT) | 2.5M | 7.03 |
| B/2 + RoPE | 2.5M | 6.26 |
| B/2 + GTA | 2.5M | 5.87 (6.2% relative improvement over RoPE) |
These models were trained at 256x256 resolution; we used 4x A100s to train B/2. Here, FID is computed with 250 DDPM sampling steps, with the `ema` VAE decoder, and with guidance (`cfg-scale=1.5`).
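For reference, `cfg-scale=1.5` refers to standard classifier-free guidance, which mixes the model's conditional and unconditional predictions at each sampling step. A minimal sketch of the standard combination rule (the sampler handles this internally):

```python
# Standard classifier-free guidance combination (sketch).
import torch

def apply_cfg(eps_cond: torch.Tensor, eps_uncond: torch.Tensor,
              cfg_scale: float = 1.5) -> torch.Tensor:
    """Push the prediction past the conditional output;
    cfg_scale=1.0 recovers the conditional model exactly."""
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)

# Toy usage with dummy noise predictions over 4x32x32 latents.
eps_c, eps_u = torch.randn(2, 4, 32, 32), torch.randn(2, 4, 32, 32)
print(apply_cfg(eps_c, eps_u).shape)  # torch.Size([2, 4, 32, 32])
```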
We include a `sample_ddp.py` script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a `.npz` file which can be used directly with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. For example, to sample 50K images from our pre-trained DiT-B/2 model over `N` GPUs, run:
```bash
# DiT + GTA
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-B/2 --num-fid-samples 50000 --ckpt=/path/to/checkpoint --posenc=gta
# DiT + RoPE
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py --model DiT-B/2 --num-fid-samples 50000 --ckpt=/path/to/checkpoint --posenc=rope
```
There are several additional options; see `sample_ddp.py` for details.
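Before handing the `.npz` to ADM's evaluator, a quick shape check can save a failed run. The sketch below assumes the ADM convention of a single uint8 image array stored under the `arr_0` key, and the file path is hypothetical; verify both against `sample_ddp.py`:

```python
# Sketch: check a sample .npz against what ADM's evaluation suite
# expects (uint8 images under "arr_0"); the path below is hypothetical.
import numpy as np

samples = np.load("samples/DiT-B-2-samples.npz")["arr_0"]
print(samples.shape, samples.dtype)  # expect (50000, 256, 256, 3) uint8
```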
To sample a small batch of images from a single class with `sample.py`:

```bash
export class=-1 # -1 indicates classes will be randomly sampled. Replace this with your desired class ID.
torchrun --nnodes=1 --nproc_per_node=1 sample.py --model DiT-B/2 --ckpt=/path/to/checkpoint --posenc=gta --num_samples=16 --sample-class=$class
```