This repository provides the official PyTorch implementation for the following paper:
TransEditor: Transformer-Based Dual-Space GAN for Highly Controllable Facial Editing
Yanbo Xu*, Yueqin Yin*, Liming Jiang, Qianyi Wu, Chengyao Zheng, Chen Change Loy, Bo Dai, Wayne Wu
In CVPR 2022. (* denotes equal contribution)
Project Page | Paper
Abstract: Recent advances like StyleGAN have promoted the growth of controllable facial editing. To address its core challenge of attribute decoupling in a single latent space, attempts have been made to adopt dual-space GAN for better disentanglement of style and content representations. Nonetheless, these methods are still incompetent to obtain plausible editing results with high controllability, especially for complicated attributes. In this study, we highlight the importance of interaction in a dual-space GAN for more controllable editing. We propose TransEditor, a novel Transformer-based framework to enhance such interaction. Besides, we develop a new dual-space editing and inversion strategy to provide additional editing flexibility. Extensive experiments demonstrate the superiority of the proposed framework in image quality and editing capability, suggesting the effectiveness of TransEditor for highly controllable facial editing.
A suitable Anaconda environment named transeditor
can be created and activated with:
conda env create -f environment.yaml
conda activate transeditor
Datasets | CelebA-HQ | Flickr-Faces-HQ (FFHQ) |
---|
- You can use download.sh in StyleMapGAN to download the CelebA-HQ dataset raw images and create the LMDB dataset format, similar for the FFHQ dataset.
- The pretrained models can be downloaded from TransEditor Pretrained Models.
- The age classifier and gender classifier for the FFHQ dataset can be found at pytorch-DEX.
- The
out/
folder andpsp_out/
folder should be put under theTransEditor/
root folder, thepth/
folder should be put under theTransEditor/our_interfaceGAN/ffhq_utils/dex
folder.
To train the TransEditor network, run
python train_spatial_query.py $DATA_DIR --exp_name $EXP_NAME --batch 16 --n_sample 64 --num_region 1 --num_trans 8
For the multi-gpu distributed training, run
python -m torch.distributed.launch --nproc_per_node=$GPU_NUM --master_port $PORT_NUM train_spatial_query.py $DATA_DIR --exp_name $EXP_NAME --batch 16 --n_sample 64 --num_region 1 --num_trans 8
To train the encoder-based inversion network, run
# FFHQ
python psp_spatial_train.py $FFHQ_DATA_DIR --test_path $FFHQ_TEST_DIR --ckpt .out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --start_from_latent_avg --exp_dir $INVERSION_EXP_NAME --from_plus_space
# CelebA-HQ
python psp_spatial_train.py $CELEBA_DATA_DIR --test_path $CELEBA_TEST_DIR --ckpt ./out/transeditor_celeba/checkpoint/370000.pt --num_region 1 --num_trans 8 --start_from_latent_avg --exp_dir $INVERSION_EXP_NAME --from_plus_space
# sampled image generation
python test_spatial_query.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --sample
# interpolation
python test_spatial_query.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --dat_interp
We provide two kinds of inversion methods.
Encoder-based inversion
# FFHQ
python dual_space_encoder_test.py --checkpoint_path ./psp_out/transeditor_inversion_ffhq/checkpoints/best_model.pt --output_dir ./projection --num_region 1 --num_trans 8 --start_from_latent_avg --from_plus_space --dataset_type ffhq_encode --dataset_dir /dataset/ffhq/test/images
# CelebA-HQ
python dual_space_encoder_test.py --checkpoint_path ./psp_out/transeditor_inversion_celeba/checkpoints/best_model.pt --output_dir ./projection --num_region 1 --num_trans 8 --start_from_latent_avg --from_plus_space --dataset_type celebahq_encode --dataset_dir /dataset/celeba_hq/test/images
Optimization-based inversion
# FFHQ
python projector_optimization.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --dataset_dir /dataset/ffhq/test/images --step 10000
# CelebA-HQ
python projector_optimization.py --ckpt ./out/transeditor_celeba/checkpoint/370000.pt --num_region 1 --num_trans 8 --dataset_dir /dataset/celeba_hq/test/images --step 10000
- The attribute classifiers for CelebA-HQ datasets can be found in celebahq-classifiers.
- Rename the folder as
pth_celeba
and put it under theour_interfaceGAN/celeba_utils/
folder.
CelebA_Attributes | attribute_index |
---|---|
Male | 0 |
Smiling | 1 |
Wavy hair | 3 |
Bald | 8 |
Bangs | 9 |
Black hair | 12 |
Blond hair | 13 |
For sampled image editing, run
# FFHQ
python our_interfaceGAN/edit_all_noinversion_ffhq.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --attribute_name pose --num_sample 150000 # pose
python our_interfaceGAN/edit_all_noinversion_ffhq.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --attribute_name gender --num_sample 150000 # gender
# CelebA-HQ
python our_interfaceGAN/edit_all_noinversion_celebahq.py --ckpt ./out/transeditor_celeba/checkpoint/370000.pt --attribute_index 0 --num_sample 150000 # Male
python our_interfaceGAN/edit_all_noinversion_celebahq.py --ckpt ./out/transeditor_celeba/checkpoint/370000.pt --attribute_index 3 --num_sample 150000 # wavy hair
python our_interfaceGAN/edit_all_noinversion_celebahq.py --ckpt ./out/transeditor_celeba/checkpoint/370000.pt --attribute_name pose --num_sample 150000 # pose
For real image editing, run
# FFHQ
python our_interfaceGAN/edit_all_inversion_ffhq.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --attribute_name pose --z_latent ./projection/encoder_inversion/ffhq_encode/encoded_z.npy --p_latent ./projection/encoder_inversion/ffhq_encode/encoded_p.npy # pose
python our_interfaceGAN/edit_all_inversion_ffhq.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --attribute_name gender --z_latent ./projection/encoder_inversion/ffhq_encode/encoded_z.npy --p_latent ./projection/encoder_inversion/ffhq_encode/encoded_p.npy # gender
# CelebA-HQ
python our_interfaceGAN/edit_all_inversion_celebahq.py --ckpt ./out/transeditor_celeba/checkpoint/370000.pt --attribute_index 0 --z_latent ./projection/encoder_inversion/celebahq_encode/encoded_z.npy --p_latent ./projection/encoder_inversion/celebahq_encode/encoded_p.npy # Male
# calculate fid, lpips, ppl
python metrics/evaluate_query.py --ckpt ./out/transeditor_ffhq/checkpoint/790000.pt --num_region 1 --num_trans 8 --batch 64 --inception metrics/inception_ffhq.pkl --truncation 1 --ppl --lpips --fid
interp_p_celeba
interp_z_celeba
edit_pose_ffhq
edit_gender_ffhq
edit_smile_celebahq
edit_blackhair_celebahq
If you find this work useful for your research, please cite our paper:
@inproceedings{xu2022transeditor,
title={{TransEditor}: Transformer-Based Dual-Space {GAN} for Highly Controllable Facial Editing},
author={Xu, Yanbo and Yin, Yueqin and Jiang, Liming and Wu, Qianyi and Zheng, Chengyao and Loy, Chen Change and Dai, Bo and Wu, Wayne},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2022}
}
The code is developed based on TransStyleGAN. We appreciate the nice PyTorch implementation.