This repo is the official PyTorch implementation of the paper Pseudo Label-Guided Model Inversion Attack via Conditional Generative Adversarial Network (AAAI 2023 Oral).
Set up the environment as follows:
# create conda environment
conda create -n PLG_MI python=3.9
conda activate PLG_MI
# install pytorch
conda install pytorch==1.10.0 torchvision==0.10.0 cudatoolkit=11.3 -c pytorch -c conda-forge
# install other dependencies
pip install -r requirements.txt
CelebA, FFHQ, and FaceScrub are used for the experiments (we use this script to download FaceScrub; note that some of its links are no longer available).
We follow KED-MI in dividing CelebA into private data and public data. The private data of CelebA can be found at: https://drive.google.com/drive/folders/1uxSsbNwCKZcy3MQ4mA9rpwiJRhtpTas6?usp=sharing
You should put them as follows:
datasets
├── celeba
│   └── img_align_celeba
├── facescrub
│   └── faceScrub
├── ffhq
│   └── thumbnails128x128
└── celeba_private_domain
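Before training, it can help to confirm that the folders above are in place. The following is a minimal sketch using only the standard library; `missing_dataset_dirs` and `EXPECTED_DIRS` are hypothetical names, not part of this repo, and the check only tests that each directory exists, not what it contains.

```python
from pathlib import Path

# Expected sub-directories under the dataset root, copied from the layout above.
EXPECTED_DIRS = [
    "celeba/img_align_celeba",
    "facescrub/faceScrub",
    "ffhq/thumbnails128x128",
    "celeba_private_domain",
]


def missing_dataset_dirs(root="./datasets"):
    """Return the expected sub-directories that do not exist under `root`.

    Hypothetical helper for illustration; it is not provided by the repo.
    """
    root_path = Path(root)
    return [d for d in EXPECTED_DIRS if not (root_path / d).is_dir()]


if __name__ == "__main__":
    missing = missing_dataset_dirs()
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Dataset layout looks complete.")
```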
You can train the target models following KED-MI or directly download the provided checkpoints at: https://drive.google.com/drive/folders/1Cf2O2MVvveXrBcdBEWDi-cMGzk0y_AsT?usp=sharing and put them in the folder ./checkpoints.
To calculate KNN_dist, we precompute the features of the private data on the evaluation model. You can download them at: https://drive.google.com/drive/folders/1Aj9glrxLoVlfrehCX2L9weFBx5PK6z-x?usp=sharing and put them in the folder ./celeba_private_feats.
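KNN_dist is commonly understood as the shortest feature-space distance between a reconstructed image and the private images of its target class, averaged over all reconstructions; the precomputed private features exist so that this nearest-neighbour search does not need to re-encode the private set. A minimal sketch, assuming squared Euclidean distance and plain Python lists in place of the feature tensors; `knn_dist` is a hypothetical name, not the repo's evaluation function.

```python
def knn_dist(recon_feats, recon_labels, private_feats, private_labels):
    """Mean over reconstructions of the squared Euclidean distance to the
    nearest private feature of the same class.

    Features are lists of floats standing in for evaluation-model features
    (e.g. those stored in ./celeba_private_feats); this is an illustrative sketch.
    """
    total = 0.0
    for feat, label in zip(recon_feats, recon_labels):
        # Only private samples of the reconstructed class are candidates.
        same_class = [p for p, y in zip(private_feats, private_labels) if y == label]
        if not same_class:
            raise ValueError(f"no private features for class {label}")
        total += min(sum((a - b) ** 2 for a, b in zip(feat, p)) for p in same_class)
    return total / len(recon_feats)
```

For example, a reconstruction at `[0, 0]` whose nearest same-class private feature is `[1, 0]` contributes a distance of 1 to the average.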
To obtain the pseudo-labeled public data using the top-n selection strategy, please run top_n_selection.py as follows:
python top_n_selection.py --model=VGG16 --data_name=ffhq --top_n=30 --save_root=reclassified_public_data
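The top-n idea can be pictured as follows: the target model scores every public image, and for each private class the n images with the highest confidence on that class are kept as pseudo-labeled samples of it (so `--top_n=30` presumably yields up to 30 images per class). A minimal sketch on a precomputed probability matrix; `top_n_selection` here is a hypothetical function, and unlike the real script it does not load images or save them under `--save_root`. In this simplified version the same image may be selected for more than one class.

```python
def top_n_selection(probs, top_n):
    """Pseudo-label public images with a per-class top-n strategy.

    probs[i][k] is the target model's softmax confidence that public image i
    belongs to private class k. For each class k, the top_n images with the
    highest confidence on k are selected and labelled k.
    Returns {class_index: [image_index, ...]}, most confident first.
    """
    num_classes = len(probs[0])
    selected = {}
    for k in range(num_classes):
        ranked = sorted(range(len(probs)), key=lambda i: probs[i][k], reverse=True)
        selected[k] = ranked[:top_n]
    return selected
```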
To train the conditional GAN in stage 1, please run train_cgan.py as follows:
python train_cgan.py \
--data_name=ffhq \
--target_model=VGG16 \
--calc_FID \
--inv_loss_type=margin \
--max_iteration=30000 \
--alpha=0.2 \
--private_data_root=./datasets/celeba_private_domain \
--data_root=./reclassified_public_data/ffhq/VGG16_top30 \
--results_root=PLG_MI_Results
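With `--inv_loss_type=margin`, the inversion term is a max-margin loss on the target model's logits: for a sample conditioned on class y it is -l_y + max over j != y of l_j, which is minimized by pushing the target logit above the strongest competing class. A minimal sketch in plain Python, assuming this form of the loss; the function names are hypothetical, and we assume (without having checked train_cgan.py) that `--alpha` weights this term against the adversarial loss.

```python
def max_margin_loss(logits, target):
    """Max-margin inversion loss for one sample:
    -logits[target] + max(logits[j] for j != target).
    Negative values mean the target class already wins by that margin.
    """
    others = [value for j, value in enumerate(logits) if j != target]
    return -logits[target] + max(others)


def batch_margin_loss(batch_logits, targets):
    """Mean max-margin loss over a batch of (logits, target) pairs."""
    losses = [max_margin_loss(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)
```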
The checkpoints can be found at: https://drive.google.com/drive/folders/1qDvl7i6_U7xoaduUbeEzTSuYXWpxvvXt?usp=sharing
(All checkpoints of PLG-MI can be found at: https://drive.google.com/drive/folders/1AVdJ0ZrrW9iutCh-zrCKVkLuzD6OZGB6?usp=sharing)
To reconstruct the private images of a specified class using the trained generator, please run reconstruct.py as follows:
python reconstruct.py \
--model=VGG16 \
--inv_loss_type=margin \
--lr=0.1 \
--iter_times=600 \
--path_G=./PLG_MI_Results/ffhq/VGG16/gen_latest.pth.tar \
--save_dir=PLG_MI_Inversion
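Reconstruction can be viewed as a latent-space search: the generator and target model stay frozen, and a latent code for the chosen class is optimized to minimize the inversion loss, with `--lr` and `--iter_times` presumably setting the step size and the number of steps. The toy sketch below replaces "generator followed by target model" with a hypothetical linear map `weights` so that the gradient of the max-margin loss can be written by hand, and uses plain gradient descent, which may differ from the optimizer the script actually uses. Because the toy map is linear, the loss is unbounded below, so the loop simply runs a fixed number of iterations.

```python
def reconstruct_latent(weights, target, z0, lr=0.1, iter_times=600):
    """Toy latent search: gradient descent on z minimizing the max-margin
    loss of logits = weights @ z.

    `weights` (num_classes x dim, list of lists) is a hypothetical linear
    stand-in for the frozen generator plus target model.
    """
    z = list(z0)
    for _ in range(iter_times):
        logits = [sum(w * x for w, x in zip(row, z)) for row in weights]
        # Strongest competing class j != target.
        rival = max((j for j in range(len(logits)) if j != target),
                    key=lambda j: logits[j])
        # Gradient of -logits[target] + logits[rival] is W[rival] - W[target].
        grad = [wr - wt for wt, wr in zip(weights[target], weights[rival])]
        z = [x - lr * g for x, g in zip(z, grad)]
    return z
```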
If you find this repository useful for your work, please consider citing it as follows:
@article{yuan2023pseudo,
title={Pseudo Label-Guided Model Inversion Attack via Conditional Generative Adversarial Network},
author={Yuan, Xiaojian and Chen, Kejiang and Zhang, Jie and Zhang, Weiming and Yu, Nenghai and Zhang, Yang},
journal={Proceedings of the AAAI Conference on Artificial Intelligence},
volume={37},
number={3},
pages={3349-3357},
year={2023}
}