Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from neuralchen:main #6

Merged
merged 5 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,5 @@ checkpoints/
*.pptx

*.pth
*.onnx
*.onnx
wandb/
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ If you find this project useful, please star it. It is the greatest appreciation

## Top News <img width=8% src="./docs/img/new.gif"/>

**`2022-04-21`**: For resource limited users, we provide the cropped VGGFace2-224 dataset [VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing).
**`2022-04-21`**: For resource limited users, we provide the cropped VGGFace2-224 dataset [[Google Driver] VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing) [[Baidu Driver] ](https://pan.baidu.com/s/1OiwLJHVBSYB4AY2vEcfN0A) [Password: lrod].

**`2022-04-20`**: Training scripts are now available. We highly recommend that you guys train the simswap model with our released high quality dataset [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).

Expand Down Expand Up @@ -69,7 +69,7 @@ In order to ensure the normal training, the batch size must be greater than 1.

Friendly reminder, due to the difference in training settings, the user-trained model will have subtle differences in visual effects from the pre-trained model we provide.

- Train 224 models with VGGFace2 224*224 [VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing)
- Train 224 models with VGGFace2 224*224 [[Google Driver] VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing) [[Baidu Driver] ](https://pan.baidu.com/s/1OiwLJHVBSYB4AY2vEcfN0A) [Password: lrod]
```
python train.py --name simswap224_test --batchSize 4 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep False
```
Expand Down
185 changes: 112 additions & 73 deletions train.ipynb
Original file line number Diff line number Diff line change
@@ -1,102 +1,126 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "train.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "fC7QoKePuJWu"
},
"source": [
"#Training Demo\n",
"This is a simple example for training the SimSwap 224*224 with VGGFace2-224.\n",
"\n",
"Code path: https://github.com/neuralchen/SimSwap\n",
"If you like the SimSwap project, please star it!\n",
"Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630"
],
"metadata": {
"id": "fC7QoKePuJWu"
}
]
},
{
"cell_type": "code",
"source": [
"!nvidia-smi"
],
"execution_count": 1,
"metadata": {
"id": "J8WrNaQHuUGC"
},
"execution_count": null,
"outputs": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Fri Apr 22 12:19:42 2022 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 456.71 Driver Version: 456.71 CUDA Version: 11.1 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name TCC/WDDM | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"|===============================+======================+======================|\n",
"| 0 TITAN Xp WDDM | 00000000:01:00.0 On | N/A |\n",
"| 23% 36C P8 15W / 250W | 1135MiB / 12288MiB | 4% Default |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| 0 N/A N/A 1232 C+G Insufficient Permissions N/A |\n",
"| 0 N/A N/A 1240 C+G Insufficient Permissions N/A |\n",
"| 0 N/A N/A 1528 C+G ...y\\ShellExperienceHost.exe N/A |\n",
"| 0 N/A N/A 7296 C+G Insufficient Permissions N/A |\n",
"| 0 N/A N/A 8280 C+G C:\\Windows\\explorer.exe N/A |\n",
"| 0 N/A N/A 9532 C+G ...artMenuExperienceHost.exe N/A |\n",
"| 0 N/A N/A 9896 C+G ...5n1h2txyewy\\SearchApp.exe N/A |\n",
"| 0 N/A N/A 11040 C+G ...2txyewy\\TextInputHost.exe N/A |\n",
"| 0 N/A N/A 11424 C+G Insufficient Permissions N/A |\n",
"| 0 N/A N/A 13112 C+G ...icrosoft VS Code\\Code.exe N/A |\n",
"| 0 N/A N/A 18720 C+G ...-2.9.15\\GitHubDesktop.exe N/A |\n",
"| 0 N/A N/A 22996 C+G ...bbwe\\Microsoft.Photos.exe N/A |\n",
"| 0 N/A N/A 23512 C+G ...me\\Application\\chrome.exe N/A |\n",
"| 0 N/A N/A 25892 C+G Insufficient Permissions N/A |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Z6BtQIgWuoqt"
},
"source": [
"Installation\n",
"All file changes made by this notebook are temporary. You can try to mount your own google drive to store files if you want."
],
"metadata": {
"id": "Z6BtQIgWuoqt"
}
]
},
{
"cell_type": "markdown",
"source": [
"#Get Scripts"
],
"metadata": {
"id": "wdQJ9d8N8Tnf"
}
},
"source": [
"#Get Scripts"
]
},
{
"cell_type": "code",
"source": [
"!git clone https://github.com/neuralchen/SimSwap\n",
"!cd SimSwap && git pull"
],
"execution_count": null,
"metadata": {
"id": "9jZWwt97uvIe"
},
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"!git clone https://github.com/neuralchen/SimSwap\n",
"!cd SimSwap && git pull"
]
},
{
"cell_type": "markdown",
"source": [
"# Install Blocks"
],
"metadata": {
"id": "ATLrrbso8Y-Y"
}
},
"source": [
"# Install Blocks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rwvbPhtOvZAL"
},
"outputs": [],
"source": [
"!pip install googledrivedownloader\n",
"!pip install timm\n",
"!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar"
],
"metadata": {
"id": "rwvbPhtOvZAL"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hleVtHIJ_QUK"
},
"source": [
"#Download the Training Dataset\n",
"We employ the cropped VGGFace2-224 dataset for this toy training demo.\n",
Expand All @@ -106,47 +130,62 @@
"***Please check the dataset in dir /content/TrainingData***\n",
"\n",
"***If dataset already exists in /content/TrainingData, please do not run blow scripts!***\n"
],
"metadata": {
"id": "hleVtHIJ_QUK"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h2tyjBl0Llxp"
},
"outputs": [],
"source": [
"!wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc\" -O /content/TrainingData/vggface2_crop_arcfacealign_224.tar && rm -rf /tmp/cookies.txt\n",
"%%cd /content/\n",
"%cd /content/\n",
"!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar\n",
"!rm /content/TrainingData/vggface2_crop_arcfacealign_224.tar"
],
"metadata": {
"id": "h2tyjBl0Llxp"
},
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o5SNDWzA8LjJ"
},
"source": [
"#Trainig\n",
"Batch size must larger than 1!"
],
"metadata": {
"id": "o5SNDWzA8LjJ"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XCxHa4oW507s"
},
"outputs": [],
"source": [
"%cd /content/SimSwap\n",
"!ls\n",
"!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False"
],
"metadata": {
"id": "XCxHa4oW507s"
},
"execution_count": null,
"outputs": []
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "train.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.5"
}
]
}
},
"nbformat": 4,
"nbformat_minor": 0
}
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Created Date: Monday December 27th 2021
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Friday, 22nd April 2022 12:34:40 am
# Last Modified: Friday, 22nd April 2022 10:49:26 am
# Modified By: Chen Xuanhong
# Copyright (c) 2021 Shanghai Jiao Tong University
#############################################################
Expand Down Expand Up @@ -49,9 +49,9 @@ def initialize(self):

# for training
self.parser.add_argument('--dataset', type=str, default="/path/to/VGGFace2", help='path to the face swapping dataset')
self.parser.add_argument('--continue_train', type=str2bool, default='True', help='continue training: load the latest model')
self.parser.add_argument('--load_pretrain', type=str, default='checkpoints', help='load the pretrained model from the specified location')
self.parser.add_argument('--which_epoch', type=str, default='320', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--continue_train', type=str2bool, default='False', help='continue training: load the latest model')
self.parser.add_argument('--load_pretrain', type=str, default='./checkpoints/simswap224_test', help='load the pretrained model from the specified location')
self.parser.add_argument('--which_epoch', type=str, default='10000', help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--niter', type=int, default=10000, help='# of iter at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=10000, help='# of iter to linearly decay learning rate to zero')
Expand Down