This implementation is adapted from the stylegan2 codebase by Matthias Wright.
Specifically, the features we've added allow for better scaling of StyleGAN2 training on TPUs:
- Enable data-parallel training on TPU pods (tested on TPU v2 to v4 generations)
- Google Cloud Storage (GCS) integration/dataset sharding between workers
- Quality-of-life improvements (e.g. improved W&B logging)
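Dataset sharding between workers means that each host reads its own disjoint subset of the TFRecord files. The sketch below shows one common way to do this, round-robin assignment. It is a sketch under stated assumptions, not the repository's implementation. `shard_files` is a hypothetical helper and the bucket paths are placeholders. In a multi-host JAX job, the worker index and count would typically come from `jax.process_index()` and `jax.process_count()`.

```python
from typing import List, Sequence


def shard_files(files: Sequence[str], worker_index: int, num_workers: int) -> List[str]:
    """Return the subset of files assigned to one worker (hypothetical helper).

    Files are sorted first so every worker sees the same global order,
    then assigned round-robin: worker i gets files i, i + n, i + 2n, ...
    """
    if not 0 <= worker_index < num_workers:
        raise ValueError("worker_index must be in [0, num_workers)")
    ordered = sorted(files)
    return ordered[worker_index::num_workers]


if __name__ == "__main__":
    # Placeholder GCS paths. In a real multi-host job, the index and count
    # would typically come from jax.process_index() / jax.process_count().
    files = [f"gs://my-bucket/tfrecords/part-{i:03d}.tfrecords" for i in range(6)]
    for worker in range(4):
        print(worker, shard_files(files, worker, num_workers=4))
```

Sorting before slicing keeps the assignment identical on every host without any communication. If there are fewer files than workers, some workers receive nothing, which is one reason to split a large dataset into several files.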
This research is part of the technology underlying our AI-generated photography platform, Nyx.gallery.
This food does not exist!
v0.2
- Better support for class-conditional training, adding per-class moving average statistics to generator
- Training data can now be split into multiple TFRecord files (located either in `--data_dir` or in a `tfrecords` subdirectory). A `dataset_info.json` file is still required in the `--data_dir` location (containing `width`, `height`, `num_examples`, and a list of `classes` if class-conditional).
- Renaming arg `--load_from_pkl` => `--load_from_ckpt`
- Added `--num_steps` argument to specify a fixed number of steps to run
- Added `--early_stopping_after_steps` argument to stop after n steps of no FID improvement
- Removal of `--bf16` flag and consolidation with `--mixed_precision`
- Allow layer freezing with `--freeze_g` and `--freeze_d` arguments
- Add `--fmap_max` argument, in order to have better control over feature map dimensions
- Allow disabling of generator and discriminator regularization
- Change checkpointing behaviour from saving every 2k steps to saving every 10k steps and keeping the 2 best checkpoints (see `--save_every` and `--keep_n_checkpoints`)
- Add `--metric_cache_location` in order to cache dataset statistics (currently for FID only)
- Log TPU memory usage, shoutout to ayaka14732 for help (see also https://github.com/ayaka14732/jax-smi)
- Visualise model architecture & parameters on startup
- Improve W&B logging (e.g. adding eval snapshots with fixed latents)
- Experimental: Add jax profiling
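The early-stopping and "keep the best checkpoints" behaviour above can be illustrated with a small bookkeeping class. This is a hypothetical sketch, not code from the repository. `FIDTracker` and its fields are invented for illustration. `patience_steps` only mirrors the idea behind `--early_stopping_after_steps`, and `keep_n` only mirrors the idea behind `--keep_n_checkpoints`. Lower FID is treated as better.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class FIDTracker:
    """Hypothetical helper: FID-based early stopping plus best-N checkpoint bookkeeping."""
    patience_steps: int  # stop after this many steps without FID improvement
    keep_n: int = 2  # number of best checkpoints to retain
    best_fid: float = float("inf")
    best_step: int = 0
    kept: List[Tuple[float, int]] = field(default_factory=list)  # (fid, step), best first

    def update(self, step: int, fid: float) -> Optional[int]:
        """Record an evaluation and return the step of a checkpoint to delete, if any.

        The returned step may be the one just evaluated, meaning it is not
        among the best `keep_n` and need not be kept.
        """
        if fid < self.best_fid:
            self.best_fid, self.best_step = fid, step
        self.kept.append((fid, step))
        self.kept.sort()  # ascending FID, so the best checkpoints come first
        if len(self.kept) > self.keep_n:
            _, evicted_step = self.kept.pop()  # drop the worst remaining checkpoint
            return evicted_step
        return None

    def should_stop(self, step: int) -> bool:
        """True once `patience_steps` have passed since the last FID improvement."""
        return step - self.best_step >= self.patience_steps


if __name__ == "__main__":
    tracker = FIDTracker(patience_steps=20_000, keep_n=2)
    for step, fid in [(10_000, 30.0), (20_000, 25.0), (30_000, 27.0), (40_000, 26.0)]:
        evicted = tracker.update(step, fid)
        print(step, fid, "delete:", evicted, "stop:", tracker.should_stop(step))
```

Sorting `(fid, step)` tuples keeps the retained set ordered by quality. This means eviction is just popping the last element.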
v0.1
- Enable training on TPUs
- Google Cloud Storage (GCS) integration
- Several quality-of-life improvements
- Clone the repository:
  ```bash
  git clone https://github.com/nyx-ai/stylegan2-flax-tpu.git
  ```
- Go into the directory:
  ```bash
  cd stylegan2-flax-tpu
  ```
- Install requirements:
  ```bash
  pip install -r requirements.txt
  ```
We released four 256x256 as well as 512x512 models. Download them from the latest release.
```bash
python generate_images.py \
    --checkpoint checkpoints/cookie-256.pkl \
    --seeds 0 42 420 666 \
    --truncation_psi 0.7 \
    --out_path generated_images
```
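`--truncation_psi` refers to the truncation trick used in the StyleGAN papers. Each latent `w` is pulled toward the average latent `w_avg` as `w_avg + psi * (w - w_avg)`. This trades sample diversity for fidelity. The NumPy sketch below shows only that formula, with random stand-ins for the mapping-network outputs and the tracked average. It is not taken from `generate_images.py`.

```python
import numpy as np


def truncate(w: np.ndarray, w_avg: np.ndarray, psi: float) -> np.ndarray:
    """Truncation trick: interpolate latents toward the average latent.

    psi = 1.0 leaves w unchanged; psi = 0.0 collapses every sample to w_avg.
    """
    return w_avg + psi * (w - w_avg)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w_avg = rng.normal(size=512)   # stand-in for the tracked mean of w
    w = rng.normal(size=(4, 512))  # stand-in for mapping-network outputs, one per seed
    w_trunc = truncate(w, w_avg, psi=0.7)
    # Ratio of distances to w_avg after vs. before truncation (equals psi).
    print(np.linalg.norm(w_trunc - w_avg, axis=1) / np.linalg.norm(w - w_avg, axis=1))
```

For class-conditional models, the per-class moving averages mentioned in the v0.2 changelog would presumably take the place of a single `w_avg`. This is an assumption, not confirmed here.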
Check the Colab notebook for more examples.
Add your images to a folder `/path/to/image_dir`:
```
/path/to/image_dir/
    0.jpg
    1.jpg
    2.jpg
    4.jpg
    ...
```
and create a TFRecord dataset:
```bash
python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
```
For more detailed instructions please refer to this README.
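The v0.2 changelog describes the expected layout of `--data_dir`. It should contain a `dataset_info.json` with `width`, `height`, `num_examples`, and optionally `classes`. The TFRecord files sit either in that directory or in a `tfrecords` subdirectory. The sketch below checks such a layout locally. `load_dataset_layout` is a hypothetical helper, and the `*.tfrecord*` file pattern is an assumption. GCS paths would need something like `tf.io.gfile` instead of `os`/`glob`.

```python
import glob
import json
import os
import tempfile
from typing import Any, Dict, List, Tuple


def load_dataset_layout(data_dir: str) -> Tuple[Dict[str, Any], List[str]]:
    """Read dataset_info.json and locate TFRecord files (hypothetical helper).

    Looks for files matching the assumed pattern *.tfrecord* in data_dir
    and in data_dir/tfrecords. Works on local paths only.
    """
    with open(os.path.join(data_dir, "dataset_info.json")) as f:
        info = json.load(f)
    for key in ("width", "height", "num_examples"):
        if key not in info:
            raise KeyError(f"dataset_info.json is missing '{key}'")

    files: List[str] = []
    for directory in (data_dir, os.path.join(data_dir, "tfrecords")):
        files.extend(glob.glob(os.path.join(directory, "*.tfrecord*")))
    if not files:
        raise FileNotFoundError(f"no TFRecord files found under {data_dir}")
    return info, sorted(files)


if __name__ == "__main__":
    # Build a throwaway directory with the assumed layout and inspect it.
    with tempfile.TemporaryDirectory() as tmp:
        os.makedirs(os.path.join(tmp, "tfrecords"))
        with open(os.path.join(tmp, "dataset_info.json"), "w") as f:
            json.dump({"width": 256, "height": 256, "num_examples": 2}, f)
        for i in range(2):
            open(os.path.join(tmp, "tfrecords", f"part-{i}.tfrecords"), "w").close()
        info, files = load_dataset_layout(tmp)
        print(info["width"], [os.path.basename(p) for p in files])
```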
The following command trains at a resolution of 128 with a batch size of 8:
```bash
python main.py --data_dir /path/to/tfrecord
```
Read more about suitable training parameters here.
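Data-parallel training on TPUs splits each global batch across devices and computes gradients on each shard. The per-shard gradients are then averaged. In JAX this is typically done with `jax.pmap` and an all-reduce such as `jax.lax.pmean`. The NumPy sketch below only simulates that idea on a toy least-squares model, with the "devices" as a leading array axis. It shows why averaging per-shard gradients reproduces the full-batch gradient when shards are equal-sized. The model and helper names are illustrative and not taken from `main.py`.

```python
import numpy as np


def mse_grad(w: np.ndarray, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / x.shape[0]


def data_parallel_grad(w: np.ndarray, x: np.ndarray, y: np.ndarray, num_devices: int) -> np.ndarray:
    """Simulate data parallelism: shard the batch, compute per-shard grads, average.

    On TPUs the averaging would be an all-reduce (e.g. jax.lax.pmean);
    here a plain mean over the leading "device" axis stands in for it.
    """
    batch = x.shape[0]
    if batch % num_devices:
        raise ValueError("global batch size must be divisible by num_devices")
    per_device = batch // num_devices
    # Reshape (batch, ...) -> (num_devices, per_device, ...), the layout pmap expects.
    xs = x.reshape(num_devices, per_device, *x.shape[1:])
    ys = y.reshape(num_devices, per_device)
    grads = np.stack([mse_grad(w, xs[d], ys[d]) for d in range(num_devices)])
    return grads.mean(axis=0)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 3))  # global batch of 8
    y = rng.normal(size=8)
    w = rng.normal(size=3)
    print(np.allclose(mse_grad(w, x, y), data_parallel_grad(w, x, y, num_devices=4)))
```

With equal shards of size B/D, each shard gradient carries a factor 2/(B/D). Averaging over D shards therefore gives 2/B times the summed per-shard terms, which is exactly the full-batch gradient. In this sketch, that is also why the global batch must divide evenly across devices.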
Our experiments have been run and tested on TPU VMs (generations v2 to v4). At the time of writing, Colab offers an older generation of TPUs, so training (and especially compilation) may be significantly slower there. If you still wish to train on Colab, the accompanying Colab notebook may get you started.
- This work is based on Matthias Wright's stylegan2 implementation.
- The project received generous support from Google's TPU Research Cloud (TRC).
- The image datasets were built using the LAION-5B index.
- We are grateful to Weights & Biases for preserving our sanity.