forked from neuralchen/SimSwap
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a0dab0c
commit 01a8d6d
Showing
26 changed files
with
3,279 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,16 @@ | ||
# SimSwap | ||
A face swapping framework | ||
# SimSwap: An Efficient Framework For High Fidelity Face Swapping | ||
## Proceedings of the 28th ACM International Conference on Multimedia | ||
## The official repository with Pytorch | ||
[[Conference paper]](https://dl.acm.org/doi/10.1145/3394171.3413630) | ||
|
||
![Results1](/doc/img/results1.PNG) | ||
![Results2](/doc/img/results2.PNG) | ||
|
||
Use python3.5, pytorch1.3.0 | ||
|
||
|
||
Use this command to test the face swapping between two images: | ||
|
||
python test_one_image.py --isTrain false --name people --Arc_path models/BEST_checkpoint.tar --pic_a_path crop_224/mars.jpg --pic_b_path crop_224/ds.jpg --output_path output/ | ||
|
||
--name refers to the checkpoint name. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
import os | ||
import numpy as np | ||
import random | ||
from torchvision import transforms | ||
from PIL import Image | ||
import cv2 | ||
|
||
class FaceDataSet(Dataset): | ||
def __init__(self, dataset_path, batch_size): | ||
super(FaceDataSet, self).__init__() | ||
|
||
|
||
|
||
'''picture_dir_list = [] | ||
for i in range(self.people_num): | ||
picture_dir_list.append('/data/home/renwangchen/vgg_align_224/'+self.people_list[i]) | ||
self.people_pic_list = [] | ||
for i in range(self.people_num): | ||
pic_list = os.listdir(picture_dir_list[i]) | ||
person_pic_list = [] | ||
for j in range(len(pic_list)): | ||
pic_dir = os.path.join(picture_dir_list[i], pic_list[j]) | ||
person_pic_list.append(pic_dir) | ||
self.people_pic_list.append(person_pic_list)''' | ||
|
||
pic_dir = '/data/home/renwangchen/CelebA_224/' | ||
latent_dir = '/data/home/renwangchen/CelebA_latent/' | ||
|
||
tmp_list = os.listdir(pic_dir) | ||
self.pic_list = [] | ||
self.latent_list = [] | ||
for i in range(len(tmp_list)): | ||
self.pic_list.append(pic_dir + tmp_list[i]) | ||
self.latent_list.append(latent_dir + tmp_list[i][:-3] + 'npy') | ||
|
||
self.pic_list = self.pic_list[:29984] | ||
'''for i in range(29984): | ||
print(self.pic_list[i])''' | ||
self.latent_list = self.latent_list[:29984] | ||
|
||
self.people_num = len(self.pic_list) | ||
|
||
self.type = 1 | ||
self.bs = batch_size | ||
self.count = 0 | ||
|
||
self.transformer = transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
|
||
def __getitem__(self, index): | ||
p1 = random.randint(0, self.people_num - 1) | ||
p2 = p1 | ||
|
||
if self.type == 0: | ||
# load pictures from the same folder | ||
pass | ||
else: | ||
# load pictures from different folders | ||
p2 = p1 | ||
while p2 == p1: | ||
p2 = random.randint(0, self.people_num - 1) | ||
|
||
pic_id_dir = self.pic_list[p1] | ||
pic_att_dir = self.pic_list[p2] | ||
latent_id_dir = self.latent_list[p1] | ||
latent_att_dir = self.latent_list[p2] | ||
|
||
img_id = Image.open(pic_id_dir).convert('RGB') | ||
img_id = self.transformer(img_id) | ||
latent_id = np.load(latent_id_dir) | ||
latent_id = latent_id / np.linalg.norm(latent_id) | ||
latent_id = torch.from_numpy(latent_id) | ||
|
||
img_att = Image.open(pic_att_dir).convert('RGB') | ||
img_att = self.transformer(img_att) | ||
latent_att = np.load(latent_att_dir) | ||
latent_att = latent_att / np.linalg.norm(latent_att) | ||
latent_att = torch.from_numpy(latent_att) | ||
|
||
self.count += 1 | ||
data_type = self.type | ||
if self.count == self.bs: | ||
self.type = 1 - self.type | ||
self.count = 0 | ||
|
||
return img_id, img_att, latent_id, latent_att, data_type | ||
|
||
def __len__(self): | ||
return len(self.pic_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import os.path | ||
from data.base_dataset import BaseDataset, get_params, get_transform, normalize | ||
from data.image_folder import make_dataset | ||
from PIL import Image | ||
|
||
class AlignedDataset(BaseDataset): | ||
def initialize(self, opt): | ||
self.opt = opt | ||
self.root = opt.dataroot | ||
|
||
### input A (label maps) | ||
dir_A = '_A' if self.opt.label_nc == 0 else '_label' | ||
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A) | ||
self.A_paths = sorted(make_dataset(self.dir_A)) | ||
|
||
### input B (real images) | ||
if opt.isTrain or opt.use_encoded_image: | ||
dir_B = '_B' if self.opt.label_nc == 0 else '_img' | ||
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B) | ||
self.B_paths = sorted(make_dataset(self.dir_B)) | ||
|
||
### instance maps | ||
if not opt.no_instance: | ||
self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst') | ||
self.inst_paths = sorted(make_dataset(self.dir_inst)) | ||
|
||
### load precomputed instance-wise encoded features | ||
if opt.load_features: | ||
self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat') | ||
print('----------- loading features from %s ----------' % self.dir_feat) | ||
self.feat_paths = sorted(make_dataset(self.dir_feat)) | ||
|
||
self.dataset_size = len(self.A_paths) | ||
|
||
def __getitem__(self, index): | ||
### input A (label maps) | ||
A_path = self.A_paths[index] | ||
A = Image.open(A_path) | ||
params = get_params(self.opt, A.size) | ||
if self.opt.label_nc == 0: | ||
transform_A = get_transform(self.opt, params) | ||
A_tensor = transform_A(A.convert('RGB')) | ||
else: | ||
transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) | ||
A_tensor = transform_A(A) * 255.0 | ||
|
||
B_tensor = inst_tensor = feat_tensor = 0 | ||
### input B (real images) | ||
if self.opt.isTrain or self.opt.use_encoded_image: | ||
B_path = self.B_paths[index] | ||
B = Image.open(B_path).convert('RGB') | ||
transform_B = get_transform(self.opt, params) | ||
B_tensor = transform_B(B) | ||
|
||
### if using instance maps | ||
if not self.opt.no_instance: | ||
inst_path = self.inst_paths[index] | ||
inst = Image.open(inst_path) | ||
inst_tensor = transform_A(inst) | ||
|
||
if self.opt.load_features: | ||
feat_path = self.feat_paths[index] | ||
feat = Image.open(feat_path).convert('RGB') | ||
norm = normalize() | ||
feat_tensor = norm(transform_A(feat)) | ||
|
||
input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, | ||
'feat': feat_tensor, 'path': A_path} | ||
|
||
return input_dict | ||
|
||
def __len__(self): | ||
return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize | ||
|
||
def name(self): | ||
return 'AlignedDataset' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import torch.utils.data as data | ||
from PIL import Image | ||
import torchvision.transforms as transforms | ||
import numpy as np | ||
import random | ||
|
||
class BaseDataset(data.Dataset): | ||
def __init__(self): | ||
super(BaseDataset, self).__init__() | ||
|
||
def name(self): | ||
return 'BaseDataset' | ||
|
||
def initialize(self, opt): | ||
pass | ||
|
||
def get_params(opt, size): | ||
w, h = size | ||
new_h = h | ||
new_w = w | ||
if opt.resize_or_crop == 'resize_and_crop': | ||
new_h = new_w = opt.loadSize | ||
elif opt.resize_or_crop == 'scale_width_and_crop': | ||
new_w = opt.loadSize | ||
new_h = opt.loadSize * h // w | ||
|
||
x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) | ||
y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) | ||
|
||
flip = random.random() > 0.5 | ||
return {'crop_pos': (x, y), 'flip': flip} | ||
|
||
def get_transform(opt, params, method=Image.BICUBIC, normalize=True): | ||
transform_list = [] | ||
if 'resize' in opt.resize_or_crop: | ||
osize = [opt.loadSize, opt.loadSize] | ||
transform_list.append(transforms.Scale(osize, method)) | ||
elif 'scale_width' in opt.resize_or_crop: | ||
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) | ||
|
||
if 'crop' in opt.resize_or_crop: | ||
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) | ||
|
||
if opt.resize_or_crop == 'none': | ||
base = float(2 ** opt.n_downsample_global) | ||
if opt.netG == 'local': | ||
base *= (2 ** opt.n_local_enhancers) | ||
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) | ||
|
||
if opt.isTrain and not opt.no_flip: | ||
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) | ||
|
||
transform_list += [transforms.ToTensor()] | ||
|
||
if normalize: | ||
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), | ||
(0.5, 0.5, 0.5))] | ||
return transforms.Compose(transform_list) | ||
|
||
def normalize(): | ||
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) | ||
|
||
def __make_power_2(img, base, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
h = int(round(oh / base) * base) | ||
w = int(round(ow / base) * base) | ||
if (h == oh) and (w == ow): | ||
return img | ||
return img.resize((w, h), method) | ||
|
||
def __scale_width(img, target_width, method=Image.BICUBIC): | ||
ow, oh = img.size | ||
if (ow == target_width): | ||
return img | ||
w = target_width | ||
h = int(target_width * oh / ow) | ||
return img.resize((w, h), method) | ||
|
||
def __crop(img, pos, size): | ||
ow, oh = img.size | ||
x1, y1 = pos | ||
tw = th = size | ||
if (ow > tw or oh > th): | ||
return img.crop((x1, y1, x1 + tw, y1 + th)) | ||
return img | ||
|
||
def __flip(img, flip): | ||
if flip: | ||
return img.transpose(Image.FLIP_LEFT_RIGHT) | ||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
|
||
def CreateDataLoader(opt): | ||
from data.custom_dataset_data_loader import CustomDatasetDataLoader | ||
data_loader = CustomDatasetDataLoader() | ||
print(data_loader.name()) | ||
data_loader.initialize(opt) | ||
return data_loader |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .models import ArcMarginModel | ||
from .models import ResNet | ||
from .models import IRBlock | ||
from .models import SEBlock |
Oops, something went wrong.