This repository contains code for our SIGGRAPH'22 paper "StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets"
by Axel Sauer, Katja Schwarz, and Andreas Geiger.
If you find our code or paper useful, please cite:
@InProceedings{Sauer2021ARXIV,
author = {Axel Sauer and Katja Schwarz and Andreas Geiger},
title = {StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets},
journal = {arXiv.org},
volume = {abs/2201.00273},
year = {2022},
url = {https://arxiv.org/abs/2201.00273},
}
Related projects:
- Projected GANs Converge Faster (NeurIPS'21): Official Repo
- StyleGAN-XL + CLIP (implemented by CasualGANPapers)
- StyleGAN-XL + CLIP (modified by Katherine Crowson to optimize in W+ space)
Requirements:
- 64-bit Python 3.8 and PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
- CUDA toolkit 11.1 or later.
- GCC 7 or later compilers. The recommended GCC version depends on your CUDA version; see, for example, the CUDA 11.4 system requirements.
- If you run into problems when setting up the custom CUDA kernels, please refer to the Troubleshooting docs of the original StyleGAN3 repo and to issue autonomousvision#23.
- Windows users struggling to install the environment might find autonomousvision#10 helpful.
- Use the following commands with Miniconda3 to create and activate your StyleGAN-XL Python environment (`sgxl`):
conda env create -f environment.yml
conda activate sgxl
For a quick start, you can download the few-shot datasets provided by the authors of FastGAN here. To prepare a dataset at the desired resolution, run
python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
--resolution=256x256 --transform=center-crop
To get the best results, you need to follow our progressive-growing scheme, so you should prepare a separate zip for each training resolution. You can get the datasets we used in our paper from their respective websites (FFHQ, ImageNet).
For progressive growing, we first train a stem at a low resolution, e.g., 16² pixels. When the stem is finished, i.e., its FID saturates, you can start training the upper stages; we refer to these as superresolution stages.
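The sketch below shows what "one zip per training resolution" can look like in practice: it lists the `dataset_tool.py` call for every stage from the stem up to the final resolution, reusing the flags from the Pokemon example above. The helper name `progressive_schedule` and the `./data/<name><res>.zip` naming are assumptions modeled on this README's example paths, not part of the repo.

```python
from typing import List, Tuple


def progressive_schedule(name: str, stem_res: int, final_res: int) -> List[Tuple[int, str]]:
    """Hypothetical helper: one (resolution, dataset_tool command) pair per stage,
    doubling the resolution from the stem up to the final stage."""
    if stem_res <= 0 or final_res < stem_res:
        raise ValueError("need 0 < stem_res <= final_res")
    stages = []
    res = stem_res
    while res <= final_res:
        # Assumed naming scheme: ./data/<name><res>.zip, as in ./data/pokemon256.zip
        cmd = (
            f"python dataset_tool.py --source=./data/{name} "
            f"--dest=./data/{name}{res}.zip "
            f"--resolution={res}x{res} --transform=center-crop"
        )
        stages.append((res, cmd))
        res *= 2  # each stage doubles the resolution
    if stages[-1][0] != final_res:
        raise ValueError("final_res must be stem_res times a power of two")
    return stages


for res, cmd in progressive_schedule("pokemon", 16, 256):
    print(cmd)
```

Each printed command produces one zip; the stem trains on the smallest one and each superresolution stage on the next.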
Training StyleGAN-XL on Pokemon using 8 GPUs:
python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon16.zip \
--gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10
`--batch` specifies the overall batch size and `--batch-gpu` specifies the batch size per GPU. If you use fewer GPUs, the training loop automatically accumulates gradients until the overall batch size is reached.
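The gradient-accumulation arithmetic can be sketched as follows. The helper name is hypothetical, and the requirement that `--batch` be a multiple of `--gpus` times `--batch-gpu` is an assumption carried over from StyleGAN3's `train.py`.

```python
def accumulation_rounds(batch: int, gpus: int, batch_gpu: int) -> int:
    """Hypothetical helper: forward/backward passes per optimizer step so that
    all GPUs together process `batch` images, accumulating gradients."""
    per_pass = gpus * batch_gpu  # images processed across all GPUs in one pass
    # Assumption (as in StyleGAN3): --batch must divide evenly into passes.
    if batch % per_pass != 0:
        raise ValueError("--batch must be a multiple of --gpus * --batch-gpu")
    return batch // per_pass


print(accumulation_rounds(64, 8, 8))  # 1: 8 GPUs cover the batch in one pass
print(accumulation_rounds(64, 2, 8))  # 4: 2 GPUs accumulate over four passes
```

So the Pokemon command above keeps the same effective batch of 64 on 2 GPUs, at roughly four times the wall-clock cost per step.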
Samples and metrics are saved in `outdir`. If you don't want to track metrics, set `--metrics=none`. You can inspect `fid50k_full.json` or run TensorBoard in `training-runs/` to monitor the training progress.
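To find the best snapshot in a metrics log, a small parser is enough. This is a sketch under an explicit assumption: it expects one JSON object per line with `results` and `snapshot_pkl` keys, the layout StyleGAN3 uses for its `metric-*.jsonl` logs; check the actual `fid50k_full.json` in your run directory before relying on it. The log values in the example are made up.

```python
import json
from typing import Optional, Tuple


def best_snapshot(log_text: str, metric: str = "fid50k_full") -> Optional[Tuple[float, str]]:
    """Return (lowest value, snapshot name) over JSON-lines metric records.

    Assumed record shape (one JSON object per line):
    {"results": {"fid50k_full": 12.3}, "snapshot_pkl": "network-snapshot-000100.pkl"}
    """
    best = None
    for line in log_text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        value = record.get("results", {}).get(metric)
        if value is None:
            continue  # record for a different metric
        if best is None or value < best[0]:  # lower FID is better
            best = (value, record.get("snapshot_pkl", "?"))
    return best


# Made-up values, purely to show the call:
example_log = "\n".join([
    '{"results": {"fid50k_full": 40.1}, "snapshot_pkl": "network-snapshot-000010.pkl"}',
    '{"results": {"fid50k_full": 31.5}, "snapshot_pkl": "network-snapshot-000020.pkl"}',
    '{"results": {"fid50k_full": 33.0}, "snapshot_pkl": "network-snapshot-000030.pkl"}',
])
print(best_snapshot(example_log))  # (31.5, 'network-snapshot-000020.pkl')
```

For a real run, pass the file contents, e.g. `best_snapshot(open(path).read())`.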
For a class-conditional dataset (ImageNet, CIFAR-10), add the flag `--cond True`. The dataset needs to contain the class labels; see the StyleGAN2-ADA repo for how to prepare class-conditional datasets.
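In the StyleGAN2-ADA/StyleGAN3 tooling, to our understanding, class labels come from a `dataset.json` file in the source folder with the shape `{"labels": [["<relative path>", <int label>], ...]}`. The sketch below writes such a file for a hypothetical one-folder-per-class layout, assigning class indices in sorted folder order; both the layout and the helper name are assumptions.

```python
import json
import os
from typing import Dict, List


def write_dataset_json(root: str) -> Dict[str, List[list]]:
    """Write <root>/dataset.json, labeling each image by its class sub-folder.

    Assumes a hypothetical layout root/<class_name>/<image>; class indices are
    assigned in sorted order of the folder names.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    labels = []
    for idx, cls in enumerate(class_names):
        for fname in sorted(os.listdir(os.path.join(root, cls))):
            if fname.lower().endswith((".png", ".jpg", ".jpeg")):
                # Paths are relative to root and use forward slashes.
                labels.append([f"{cls}/{fname}", idx])
    meta = {"labels": labels}
    with open(os.path.join(root, "dataset.json"), "w") as f:
        json.dump(meta, f)
    return meta


# Example: write_dataset_json("./data/my_dataset"), then run dataset_tool.py
# with --source=./data/my_dataset as shown above.
```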
Continuing with a pretrained stem:
python train.py --outdir=./training-runs/pokemon --cfg=stylegan3-t --data=./data/pokemon32.zip \
--gpus=8 --batch=64 --mirror=1 --snap 10 --batch-gpu 8 --kimg 10000 --syn_layers 10 \
--superres --up_factor 2 --head_layers 7 \
--path_stem training-runs/pokemon/00000-stylegan3-t-pokemon16-gpus8-batch64/best_model.pkl
`--up_factor` allows you to train several stages at once; e.g., with `--up_factor=4` and a 16² stem, you can directly train at resolution 64².
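The relation between stem resolution, `--up_factor`, and output resolution is a simple product, sketched below in both directions. The helper names are hypothetical, and the sketch does not check which factors the training code actually supports.

```python
def target_resolution(stem_res: int, up_factor: int) -> int:
    """Output resolution of a superresolution run on top of a stem."""
    return stem_res * up_factor


def required_up_factor(stem_res: int, final_res: int) -> int:
    """--up_factor needed to go from the stem straight to final_res."""
    if final_res % stem_res != 0:
        raise ValueError("final_res must be a multiple of stem_res")
    return final_res // stem_res


print(target_resolution(16, 4))     # 64, the example from the text
print(required_up_factor(32, 256))  # 8
```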
If you have enough compute, a good tactic is to train several stages in parallel and then restart the superresolution stage training once in a while. The current stage will then reload its previous stem's `best_model.pkl`. Performance can sometimes drop at first because of domain shift, but the superresolution stage quickly recovers and improves further.
The default settings are tuned for ImageNet. For smaller datasets (<50k images) or well-curated datasets (FFHQ), you can significantly decrease the model size, enabling much faster training. Recommended settings are `--cbase 16384 --cmax 256 --syn_layers 7` and, for superresolution stages, `--head_layers 4`.
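The recommendation above can be encoded as a small flag selector. This is a sketch: the helper name and the `well_curated` switch are hypothetical, and "use the defaults" simply means passing no extra flags.

```python
from typing import List


def size_flags(num_images: int, well_curated: bool, superres: bool) -> List[str]:
    """Hypothetical helper: reduced-model flags for small or well-curated
    datasets; an empty list keeps the ImageNet-tuned defaults."""
    if num_images >= 50_000 and not well_curated:
        return []  # large, diverse dataset: keep the defaults
    flags = ["--cbase", "16384", "--cmax", "256", "--syn_layers", "7"]
    if superres:
        flags += ["--head_layers", "4"]  # only for superresolution stages
    return flags


print(" ".join(size_flags(1_000, well_curated=False, superres=True)))
```

The resulting flags would replace the `--syn_layers 10` and `--head_layers 7` values used in the Pokemon commands above.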
If you want to train as few stages as possible, we recommend training a 32x32 or 64x64 stem and then directly scaling to the final resolution (as described above, you must adjust `--up_factor` accordingly). However, progressive growing generally yields better results faster, as the throughput is much higher at lower resolutions. This can be seen in this figure by Karras et al., 2017:
To generate samples and interpolation videos, run
python gen_images.py --outdir=out --trunc=0.7 --seeds=10-15 --batch-sz 1 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
and
python gen_video.py --output=lerp.mp4 --trunc=0.7 --seeds=0-31 --grid=4x2 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl
For class-conditional models, you can pass the class index via `--class`; an index-to-label dictionary for ImageNet can be found here. For interpolation between classes, provide, e.g., `--cls=0-31` to `gen_video.py`. The list of classes has to be the same length as `--seeds`.
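Ranges such as `0-31` are inclusive, so `--seeds=0-31` and `--cls=0-31` both expand to 32 entries. The sketch below mirrors the `parse_range` helper found in StyleGAN3's generation scripts; that `gen_video.py` here parses `--seeds` and `--cls` the same way is an assumption.

```python
import re
from typing import List


def parse_range(s: str) -> List[int]:
    """Expand '1,2,5-10' into [1, 2, 5, 6, 7, 8, 9, 10]; ranges are inclusive."""
    out = []
    for part in s.split(","):
        m = re.fullmatch(r"(\d+)-(\d+)", part.strip())
        if m:
            out.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            out.append(int(part))
    return out


seeds = parse_range("0-31")
classes = parse_range("0-31")
# The requirement from the text: one class per seed.
assert len(classes) == len(seeds), "--cls must list as many classes as --seeds"
print(len(seeds))  # 32
```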
To generate a conditional sample sheet, run
python gen_class_samplesheet.py --outdir=sample_sheets --trunc=1.0 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
--samples-per-class 4 --classes 0-32 --grid-width 32
For ImageNet models, we enable multi-modal truncation (proposed by Self-Distilled GAN). We generated 600k samples and used k-means to find 10k cluster centroids. For a given sample, multi-modal truncation finds the closest centroid and interpolates towards it. To switch from uni-modal to multi-modal truncation, pass
--centroids-path=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet_centroids.npy
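The core of multi-modal truncation is the usual truncation trick applied toward the nearest centroid instead of the global mean: `w' = c + psi * (w - c)`. The sketch below uses plain lists and Euclidean distance; the distance measure and the function name are assumptions, not taken from this repo's code. Uni-modal truncation is the special case with a single centroid, the average latent.

```python
import math
from typing import List, Sequence


def multimodal_truncate(w: Sequence[float],
                        centroids: Sequence[Sequence[float]],
                        psi: float) -> List[float]:
    """Pull latent w toward its nearest centroid: w' = c + psi * (w - c).

    Assumption: "nearest" is measured with Euclidean distance.
    """
    nearest = min(centroids, key=lambda c: math.dist(w, c))
    return [ci + psi * (wi - ci) for wi, ci in zip(w, nearest)]


centroids = [[0.0, 0.0], [10.0, 10.0]]
print(multimodal_truncate([8.0, 9.0], centroids, psi=0.5))     # [9.0, 9.5]
# Uni-modal truncation: a single centroid (the mean latent).
print(multimodal_truncate([8.0, 9.0], [[0.0, 0.0]], psi=0.5))  # [4.0, 4.5]
```

With many centroids, a sample is only pulled toward the mode it already belongs to, which is why diversity survives truncation better than with a single mean.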
[Figure: No Truncation | Uni-Modal Truncation | Multi-Modal Truncation]
To invert a given image via latent optimization, and optionally use our reimplementation of Pivotal Tuning Inversion, run
python run_inversion.py --outdir=inversion_out \
--target media/jay.png \
--inv-steps 1000 --run-pti --pti-steps 350 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl
Provide an image via `--target`; it is automatically resized and center-cropped to match the generator network. You do not need to provide a class for ImageNet models; we infer the class of a given sample via a pretrained classifier.
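The center crop can be pictured as taking the largest centered square and then resizing it to the generator's resolution. The sketch below computes only the crop box, in the `(left, top, right, bottom)` order used by PIL's `Image.crop`; that `run_inversion.py` crops exactly this square is an assumption, and the resize filter it uses is not shown here.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of the largest centered square.

    Assumption: the inversion script crops this square before resizing.
    """
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


print(center_crop_box(640, 480))  # (80, 0, 560, 480)
```

With PIL, the box would be applied as `img.crop(box)` followed by a resize to, e.g., 512x512 for the `imagenet512.pkl` model.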
To use our reimplementation of StyleMC and generate the example above, run
python run_stylemc.py --outdir=stylemc_out \
--text-prompt "a chimpanzee | laughter | happyness| happy chimpanzee | happy monkey | smile | grin" \
--seeds 0-256 --class-idx 367 --layers 10-30 --edit-strength 0.75 --init-seed 49 \
--network=https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl \
--bigger-network https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl
Recommended workflow:
- Sample images via `gen_images.py`.
- Pick a sample and use it as the initial image for `run_stylemc.py` by providing `--init-seed` and `--class-idx`.
- Find a direction in style space via `--text-prompt`.
- Fine-tune `--edit-strength`, `--layers`, and the number of `--seeds`.
- Once you have found a good setting, provide a larger model via `--bigger-network`. The script still optimizes the direction for the smaller model but uses the bigger model for the final output.
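StyleMC edits an image by moving its style vectors along a direction found from the text prompt. The sketch below shows only that final editing step, under loudly stated assumptions: styles are plain per-layer lists of floats, a layer spec such as `10-30` is read as an inclusive range of layer indices, and `apply_style_edit` is a hypothetical name, not a function from `run_stylemc.py`.

```python
from typing import List, Sequence


def apply_style_edit(styles: Sequence[Sequence[float]],
                     direction: Sequence[Sequence[float]],
                     strength: float,
                     layers: range) -> List[List[float]]:
    """Hypothetical: s' = s + strength * d on the chosen layers only;
    all other layers keep their original style vectors."""
    edited = []
    for i, (s, d) in enumerate(zip(styles, direction)):
        if i in layers:
            edited.append([si + strength * di for si, di in zip(s, d)])
        else:
            edited.append(list(s))
    return edited


# Toy example: 4 layers, edit layers 1-2 (inclusive) with strength 0.75.
base = [[0.0, 0.0] for _ in range(4)]
dirn = [[1.0, 2.0] for _ in range(4)]
print(apply_style_edit(base, dirn, 0.75, range(1, 3)))
```

Restricting the edit to a layer range is what `--layers` tunes: coarse layers tend to change pose and shape, finer ones color and texture.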
We provide the following pretrained models (pass the URL as `PATH_TO_NETWORK_PKL`):
Dataset | Res | FID | PATH |
---|---|---|---|
ImageNet | 16² | 0.73 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet16.pkl |
ImageNet | 32² | 1.11 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet32.pkl |
ImageNet | 64² | 1.52 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet64.pkl |
ImageNet | 128² | 1.77 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet128.pkl |
ImageNet | 256² | 2.26 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl |
ImageNet | 512² | 2.42 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet512.pkl |
ImageNet | 1024² | 2.51 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet1024.pkl |
CIFAR10 | 32² | 1.85 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/cifar10.pkl |
FFHQ | 256² | 2.19 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq256.pkl |
FFHQ | 512² | 2.23 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq512.pkl |
FFHQ | 1024² | 2.02 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/ffhq1024.pkl |
Pokemon | 256² | 23.97 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon256.pkl |
Pokemon | 512² | 23.82 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon512.pkl |
Pokemon | 1024² | 25.47 | https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/pokemon1024.pkl |
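The model URLs in the table above follow a regular `<dataset><resolution>.pkl` pattern, with CIFAR10 as the one exception. The hypothetical helper below builds a URL from that pattern and rejects combinations that the table does not list; it only formats strings and does not download anything.

```python
BASE_URL = "https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/"

# Resolutions listed in the pretrained-model table.
AVAILABLE = {
    "imagenet": {16, 32, 64, 128, 256, 512, 1024},
    "ffhq": {256, 512, 1024},
    "pokemon": {256, 512, 1024},
}


def model_url(dataset: str, res: int) -> str:
    """Hypothetical helper: URL of a pretrained model from the table."""
    dataset = dataset.lower()
    if dataset == "cifar10":
        if res != 32:
            raise ValueError("CIFAR10 is only available at 32x32")
        return BASE_URL + "cifar10.pkl"  # the one name without a resolution
    if res not in AVAILABLE.get(dataset, set()):
        raise ValueError(f"no pretrained {dataset} model at {res}x{res}")
    return f"{BASE_URL}{dataset}{res}.pkl"


print(model_url("ImageNet", 512))
```

The returned string can be passed directly as `--network` to the scripts above.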
By default, `train.py` tracks FID50k during training. To calculate metrics for a specific network snapshot, run
python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL
To see the available metrics, run
python calc_metrics.py --help
We provide precomputed FID statistics for all pretrained models:
wget https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/gan-metrics.zip
unzip gan-metrics.zip -d dnnlib/
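For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features of real (r) and generated (g) images. The precomputed statistics are, presumably, the real-data mean and covariance that enter this standard formula, so they do not need to be recomputed for every evaluation:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```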
This repo builds on the codebase of StyleGAN3 and our previous project Projected GANs Converge Faster.