Commit
init
neuralchen committed Jun 8, 2021
1 parent a0dab0c commit 01a8d6d
Showing 26 changed files with 3,279 additions and 3 deletions.
17 changes: 16 additions & 1 deletion .gitignore
@@ -20,6 +20,8 @@ parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
@@ -45,6 +47,7 @@ htmlcov/
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

@@ -56,6 +59,7 @@ coverage.xml
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
@@ -80,8 +84,19 @@ ipython_config.py
# pyenv
.python-version

# celery beat schedule file
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py
18 changes: 16 additions & 2 deletions README.md
@@ -1,2 +1,16 @@
# SimSwap
A face swapping framework
# SimSwap: An Efficient Framework For High Fidelity Face Swapping
## Proceedings of the 28th ACM International Conference on Multimedia
## The official PyTorch implementation
[[Conference paper]](https://dl.acm.org/doi/10.1145/3394171.3413630)

![Results1](/doc/img/results1.PNG)
![Results2](/doc/img/results2.PNG)

Use Python 3.5 and PyTorch 1.3.0.


Use this command to test face swapping between two images:

python test_one_image.py --isTrain false --name people --Arc_path models/BEST_checkpoint.tar --pic_a_path crop_224/mars.jpg --pic_b_path crop_224/ds.jpg --output_path output/

--name refers to the checkpoint name.
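
For scripted runs, here is a minimal sketch (not part of the repository) that drives the same test from Python; the flags are exactly those documented above:

import subprocess

# Hedged sketch: invoke the documented face-swap test via subprocess.
# --name selects the checkpoint; --Arc_path points at the identity model checkpoint.
subprocess.run([
    "python", "test_one_image.py",
    "--isTrain", "false",
    "--name", "people",
    "--Arc_path", "models/BEST_checkpoint.tar",
    "--pic_a_path", "crop_224/mars.jpg",
    "--pic_b_path", "crop_224/ds.jpg",
    "--output_path", "output/",
], check=True)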
94 changes: 94 additions & 0 deletions data/CelebA_class.py
@@ -0,0 +1,94 @@
import torch
from torch.utils.data import Dataset
import os
import numpy as np
import random
from torchvision import transforms
from PIL import Image
import cv2

class FaceDataSet(Dataset):
    def __init__(self, dataset_path, batch_size):
        super(FaceDataSet, self).__init__()

        '''picture_dir_list = []
        for i in range(self.people_num):
            picture_dir_list.append('/data/home/renwangchen/vgg_align_224/'+self.people_list[i])
        self.people_pic_list = []
        for i in range(self.people_num):
            pic_list = os.listdir(picture_dir_list[i])
            person_pic_list = []
            for j in range(len(pic_list)):
                pic_dir = os.path.join(picture_dir_list[i], pic_list[j])
                person_pic_list.append(pic_dir)
            self.people_pic_list.append(person_pic_list)'''

        # Aligned 224x224 CelebA crops and their precomputed identity
        # latents (one .npy file per image, same stem as the image).
        pic_dir = '/data/home/renwangchen/CelebA_224/'
        latent_dir = '/data/home/renwangchen/CelebA_latent/'

        tmp_list = os.listdir(pic_dir)
        self.pic_list = []
        self.latent_list = []
        for i in range(len(tmp_list)):
            self.pic_list.append(pic_dir + tmp_list[i])
            self.latent_list.append(latent_dir + tmp_list[i][:-3] + 'npy')

        self.pic_list = self.pic_list[:29984]
        '''for i in range(29984):
            print(self.pic_list[i])'''
        self.latent_list = self.latent_list[:29984]

        self.people_num = len(self.pic_list)

        self.type = 1
        self.bs = batch_size
        self.count = 0

        self.transformer = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __getitem__(self, index):
        p1 = random.randint(0, self.people_num - 1)
        p2 = p1

        if self.type == 0:
            # load pictures from the same folder
            pass
        else:
            # load pictures from different folders
            p2 = p1
            while p2 == p1:
                p2 = random.randint(0, self.people_num - 1)

        pic_id_dir = self.pic_list[p1]
        pic_att_dir = self.pic_list[p2]
        latent_id_dir = self.latent_list[p1]
        latent_att_dir = self.latent_list[p2]

        # Identity image and its L2-normalized identity latent.
        img_id = Image.open(pic_id_dir).convert('RGB')
        img_id = self.transformer(img_id)
        latent_id = np.load(latent_id_dir)
        latent_id = latent_id / np.linalg.norm(latent_id)
        latent_id = torch.from_numpy(latent_id)

        # Attribute image and its L2-normalized identity latent.
        img_att = Image.open(pic_att_dir).convert('RGB')
        img_att = self.transformer(img_att)
        latent_att = np.load(latent_att_dir)
        latent_att = latent_att / np.linalg.norm(latent_att)
        latent_att = torch.from_numpy(latent_att)

        # Toggle the sampling mode after every batch worth of samples.
        self.count += 1
        data_type = self.type
        if self.count == self.bs:
            self.type = 1 - self.type
            self.count = 0

        return img_id, img_att, latent_id, latent_att, data_type

    def __len__(self):
        return len(self.pic_list)
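
A minimal usage sketch, not part of this commit: wrapping FaceDataSet in a standard PyTorch DataLoader. Note that __getitem__ mutates self.type and self.count to toggle the sampling mode once per batch, so the sketch keeps num_workers=0; worker processes would each mutate their own copy of that state.

from torch.utils.data import DataLoader
from data.CelebA_class import FaceDataSet

batch_size = 16
# dataset_path is accepted but unused above; the data directories are hard-coded.
dataset = FaceDataSet('/path/to/CelebA', batch_size)
loader = DataLoader(dataset, batch_size=batch_size, num_workers=0)

for img_id, img_att, latent_id, latent_att, data_type in loader:
    break  # identity/attribute images plus their normalized identity latents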
76 changes: 76 additions & 0 deletions data/aligned_dataset.py
@@ -0,0 +1,76 @@
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image

class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot

        ### input A (label maps)
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
        self.A_paths = sorted(make_dataset(self.dir_A))

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            dir_B = '_B' if self.opt.label_nc == 0 else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
            self.B_paths = sorted(make_dataset(self.dir_B))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        ### input A (label maps)
        A_path = self.A_paths[index]
        A = Image.open(A_path)
        params = get_params(self.opt, A.size)
        if self.opt.label_nc == 0:
            transform_A = get_transform(self.opt, params)
            A_tensor = transform_A(A.convert('RGB'))
        else:
            transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
            A_tensor = transform_A(A) * 255.0

        B_tensor = inst_tensor = feat_tensor = 0
        ### input B (real images)
        if self.opt.isTrain or self.opt.use_encoded_image:
            B_path = self.B_paths[index]
            B = Image.open(B_path).convert('RGB')
            transform_B = get_transform(self.opt, params)
            B_tensor = transform_B(B)

        ### if using instance maps
        if not self.opt.no_instance:
            inst_path = self.inst_paths[index]
            inst = Image.open(inst_path)
            inst_tensor = transform_A(inst)

            if self.opt.load_features:
                feat_path = self.feat_paths[index]
                feat = Image.open(feat_path).convert('RGB')
                norm = normalize()
                feat_tensor = norm(transform_A(feat))

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                      'feat': feat_tensor, 'path': A_path}

        return input_dict

    def __len__(self):
        return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize

    def name(self):
        return 'AlignedDataset'
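
For orientation, a hedged sketch of constructing the dataset: AlignedDataset follows the pix2pixHD convention of configuring via an options object rather than constructor arguments. Every field below mirrors an attribute the class actually reads; the values themselves are illustrative assumptions.

from types import SimpleNamespace
from data.aligned_dataset import AlignedDataset

# Stand-in for the training options object; values are illustrative only.
opt = SimpleNamespace(
    dataroot='./datasets/faces', phase='train', label_nc=0,
    isTrain=True, use_encoded_image=False, no_instance=True,
    load_features=False, batchSize=4,
    resize_or_crop='resize_and_crop', loadSize=286, fineSize=256,
    no_flip=False, n_downsample_global=4, netG='global', n_local_enhancers=1,
)

dataset = AlignedDataset()
dataset.initialize(opt)
sample = dataset[0]  # dict with 'label', 'inst', 'image', 'feat', 'path'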
90 changes: 90 additions & 0 deletions data/base_dataset.py
@@ -0,0 +1,90 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        new_h = new_w = opt.loadSize
    elif opt.resize_or_crop == 'scale_width_and_crop':
        new_w = opt.loadSize
        new_h = opt.loadSize * h // w

    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))

    flip = random.random() > 0.5
    return {'crop_pos': (x, y), 'flip': flip}

def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def normalize():
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)

def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img

def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img
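
The split between get_params and get_transform is what keeps paired inputs aligned: the random crop position and flip decision are drawn once, then reused to build identical transforms for both images. A hedged sketch, reusing an opt object like the one shown above (file paths are placeholders):

from PIL import Image
from data.base_dataset import get_params, get_transform

A = Image.open('datasets/faces/train_A/0001.png')  # placeholder label map
B = Image.open('datasets/faces/train_B/0001.jpg')  # placeholder real image

params = get_params(opt, A.size)  # one crop/flip draw, shared below
A_tensor = get_transform(opt, params)(A.convert('RGB'))
B_tensor = get_transform(opt, params)(B.convert('RGB'))  # identical crop and flip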
7 changes: 7 additions & 0 deletions data/data_loader.py
@@ -0,0 +1,7 @@

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
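
CreateDataLoader imports CustomDatasetDataLoader inside the function body, deferring the import until a loader is actually built. That module is not part of this commit view, so the load_data() call in the sketch below is an assumption carried over from the pix2pixHD codebase this factory follows:

from data.data_loader import CreateDataLoader

data_loader = CreateDataLoader(opt)  # prints the loader's name, then initializes it
dataset = data_loader.load_data()    # assumed pix2pixHD-style API; not shown in this commit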
Binary file added doc/img/results1.PNG
Binary file added doc/img/results2.PNG
4 changes: 4 additions & 0 deletions models/__init__.py
@@ -0,0 +1,4 @@
from .models import ArcMarginModel
from .models import ResNet
from .models import IRBlock
from .models import SEBlock
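
These re-exports flatten the package namespace so the ArcFace building blocks can be imported directly from models, for example:

from models import ArcMarginModel, ResNet, IRBlock, SEBlock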
