forked from neuralchen/SimSwap
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6a2e03d
commit b4ba933
Showing
3 changed files
with
322 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"name": "train.ipynb", | ||
"provenance": [], | ||
"collapsed_sections": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
}, | ||
"accelerator": "GPU" | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"#Training Demo\n", | ||
"This is a simple example for training the SimSwap 224*224 with VGGFace2-224.\n", | ||
"\n", | ||
"Code path: https://github.com/neuralchen/SimSwap\n", | ||
"If you like the SimSwap project, please star it!\n", | ||
"Paper path: https://arxiv.org/pdf/2106.06340v1.pdf or https://dl.acm.org/doi/10.1145/3394171.3413630" | ||
], | ||
"metadata": { | ||
"id": "fC7QoKePuJWu" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!nvidia-smi" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "J8WrNaQHuUGC", | ||
"outputId": "afffa0be-92b5-4133-b6d9-6c3e08c6de64" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Thu Apr 21 16:07:35 2022 \n", | ||
"+-----------------------------------------------------------------------------+\n", | ||
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n", | ||
"|-------------------------------+----------------------+----------------------+\n", | ||
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", | ||
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", | ||
"| | | MIG M. |\n", | ||
"|===============================+======================+======================|\n", | ||
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n", | ||
"| N/A 67C P8 32W / 149W | 0MiB / 11441MiB | 0% Default |\n", | ||
"| | | N/A |\n", | ||
"+-------------------------------+----------------------+----------------------+\n", | ||
" \n", | ||
"+-----------------------------------------------------------------------------+\n", | ||
"| Processes: |\n", | ||
"| GPU GI CI PID Type Process name GPU Memory |\n", | ||
"| ID ID Usage |\n", | ||
"|=============================================================================|\n", | ||
"| No running processes found |\n", | ||
"+-----------------------------------------------------------------------------+\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"Installation\n", | ||
"All file changes made by this notebook are temporary. You can try to mount your own google drive to store files if you want." | ||
], | ||
"metadata": { | ||
"id": "Z6BtQIgWuoqt" | ||
} | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"#Get Scripts" | ||
], | ||
"metadata": { | ||
"id": "wdQJ9d8N8Tnf" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!git clone https://github.com/neuralchen/SimSwap\n", | ||
"!cd SimSwap && git pull" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "9jZWwt97uvIe", | ||
"outputId": "42a1bda8-3ca3-46af-fc82-d1af99ce15e1" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Cloning into 'SimSwap'...\n", | ||
"remote: Enumerating objects: 1017, done.\u001b[K\n", | ||
"remote: Counting objects: 100% (16/16), done.\u001b[K\n", | ||
"remote: Compressing objects: 100% (13/13), done.\u001b[K\n", | ||
"remote: Total 1017 (delta 5), reused 10 (delta 3), pack-reused 1001\u001b[K\n", | ||
"Receiving objects: 100% (1017/1017), 210.79 MiB | 14.80 MiB/s, done.\n", | ||
"Resolving deltas: 100% (510/510), done.\n", | ||
"Already up to date.\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"# Install Blocks" | ||
], | ||
"metadata": { | ||
"id": "ATLrrbso8Y-Y" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"!pip install googledrivedownloader\n", | ||
"!pip install timm\n", | ||
"!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "rwvbPhtOvZAL", | ||
"outputId": "ffa12208-d388-412d-e83b-c54864c4526e" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Requirement already satisfied: googledrivedownloader in /usr/local/lib/python3.7/dist-packages (0.4)\n", | ||
"Requirement already satisfied: imageio==2.4.1 in /usr/local/lib/python3.7/dist-packages (2.4.1)\n", | ||
"Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (7.1.2)\n", | ||
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (1.21.6)\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"#Download the Training Dataset\n", | ||
"We employ the cropped VGGFace2-224 dataset for this toy training demo.\n", | ||
"You can download the dataset from our google driver " | ||
], | ||
"metadata": { | ||
"id": "hleVtHIJ_QUK" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from google_drive_downloader import GoogleDriveDownloader as gdd\n", | ||
"gdd.download_file_from_google_drive(file_id='1iytA1n2z4go3uVCwE__vIKouTKyIDjEq',dest_path='/content/TrainingData/vggface2_crop_arcfacealign_224.tar',showsize=True)\n", | ||
"!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "gMVKEej59LX9", | ||
"outputId": "2e508c44-d006-4183-81d9-f9753d08dea7" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Downloading 1iytA1n2z4go3uVCwE__vIKouTKyIDjEq into /content/TrainingData/mnist.zip... \n", | ||
"0.0 B Done.\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"source": [ | ||
"#Trainig\n", | ||
"Batch size must larger than 1!" | ||
], | ||
"metadata": { | ||
"id": "o5SNDWzA8LjJ" | ||
} | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"%cd /content/SimSwap\n", | ||
"!ls\n", | ||
"!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "XCxHa4oW507s", | ||
"outputId": "c84c52d9-0b36-4932-925d-1ae38a3f7bb0" | ||
}, | ||
"execution_count": null, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"/content/SimSwap\n", | ||
" arcface_model\t predict.py\n", | ||
" cog.yaml\t README.md\n", | ||
" crop_224\t 'SimSwap colab.ipynb'\n", | ||
" data\t\t simswaplogo\n", | ||
" demo_file\t test_one_image.py\n", | ||
" docs\t\t test_video_swapmulti.py\n", | ||
" download-weights.sh test_video_swap_multispecific.py\n", | ||
" insightface_func test_video_swapsingle.py\n", | ||
" LICENSE\t test_video_swapspecific.py\n", | ||
" models\t\t test_wholeimage_swapmulti.py\n", | ||
" MultiSpecific.ipynb test_wholeimage_swap_multispecific.py\n", | ||
" options\t test_wholeimage_swapsingle.py\n", | ||
" output\t\t test_wholeimage_swapspecific.py\n", | ||
" parsing_model\t train.py\n", | ||
" pg_modules\t util\n", | ||
"------------ Options -------------\n", | ||
"Arc_path: arcface_model/arcface_checkpoint.tar\n", | ||
"Gdeep: False\n", | ||
"batchSize: 2\n", | ||
"beta1: 0.0\n", | ||
"checkpoints_dir: ./checkpoints\n", | ||
"continue_train: False\n", | ||
"dataset: /path/to/VGGFace2\n", | ||
"gpu_ids: 0\n", | ||
"isTrain: True\n", | ||
"lambda_feat: 10.0\n", | ||
"lambda_id: 30.0\n", | ||
"lambda_rec: 10.0\n", | ||
"load_pretrain: checkpoints\n", | ||
"log_frep: 200\n", | ||
"lr: 0.0004\n", | ||
"model_freq: 10000\n", | ||
"name: simswap\n", | ||
"niter: 10000\n", | ||
"niter_decay: 10000\n", | ||
"phase: train\n", | ||
"sample_freq: 1000\n", | ||
"tag: simswap\n", | ||
"total_step: 1000000\n", | ||
"train_simswap: True\n", | ||
"use_tensorboard: False\n", | ||
"which_epoch: 800000\n", | ||
"-------------- End ----------------\n", | ||
"GPU used : 0\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.PReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AdaptiveAvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Sigmoid' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm1d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n", | ||
" warnings.warn(msg, SourceChangeWarning)\n", | ||
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth\" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite0-0aa007d2.pth\n", | ||
"processing Swapping dataset images...\n", | ||
"Finished preprocessing the Swapping dataset, total dirs number: 0...\n", | ||
"Traceback (most recent call last):\n", | ||
" File \"train.py\", line 163, in <module>\n", | ||
" train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234)\n", | ||
" File \"/content/SimSwap/data/data_loader_Swapping.py\", line 119, in GetLoader\n", | ||
" drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)\n", | ||
" File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\", line 268, in __init__\n", | ||
" sampler = RandomSampler(dataset, generator=generator)\n", | ||
" File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py\", line 103, in __init__\n", | ||
" \"value, but got num_samples={}\".format(self.num_samples))\n", | ||
"ValueError: num_samples should be a positive integer value, but got num_samples=0\n" | ||
] | ||
} | ||
] | ||
} | ||
] | ||
} |
Oops, something went wrong.