forked from neuralchen/SimSwap
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9492873
commit f48dc8c
Showing
16 changed files
with
1,688 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,4 +135,7 @@ checkpoints/ | |
*.zip | ||
*.avi | ||
*.pptx | ||
*.pptx | ||
|
||
*.pth | ||
*.onnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import os | ||
import glob | ||
import torch | ||
import random | ||
from PIL import Image | ||
from pathlib import Path | ||
from torch.utils import data | ||
from torchvision import transforms as T | ||
# from StyleResize import StyleResize | ||
|
||
class data_prefetcher():
    """Stage the next batch of a DataLoader on the GPU while the current one is in use.

    Each batch produced by ``loader`` must unpack into exactly two image
    tensors.  Both are copied to the current CUDA device on a dedicated side
    stream and normalised in place with the per-channel ``mean``/``std``
    constants below (the values match the widely used ImageNet statistics).

    The prefetcher never raises ``StopIteration`` during iteration: when the
    wrapped loader is exhausted it is simply restarted.

    Args:
        loader: an iterable with ``__len__`` (typically a
            ``torch.utils.data.DataLoader``) yielding ``(image1, image2)``.
    """

    def __init__(self, loader):
        self.loader = loader
        self.dataiter = iter(loader)
        # Side stream used for the host-to-device copies and normalisation.
        self.stream = torch.cuda.Stream()
        # Shaped (1, 3, 1, 1) so they broadcast over (N, 3, H, W) batches.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
        # With Amp, it isn't necessary to manually convert data to half.
        self.num_images = len(loader)
        self.preload()

    def _stage(self, images):
        """Copy ``images`` to the GPU and normalise them in place."""
        images = images.cuda(non_blocking=True)
        return images.sub_(self.mean).div_(self.std)

    def preload(self):
        """Fetch the next batch and start uploading it on the side stream."""
        try:
            first, second = next(self.dataiter)
        except StopIteration:
            # Loader exhausted: start a new pass over the data.
            self.dataiter = iter(self.loader)
            first, second = next(self.dataiter)

        with torch.cuda.stream(self.stream):
            self.src_image1 = self._stage(first)
            self.src_image2 = self._stage(second)

    def next(self):
        """Return the staged ``(src_image1, src_image2)`` and prefetch the next batch."""
        current = torch.cuda.current_stream()
        # Make the consumer stream wait for the pending copy/normalisation.
        current.wait_stream(self.stream)
        src_image1 = self.src_image1
        src_image2 = self.src_image2
        # The tensors were allocated on the side stream but are consumed on
        # the current stream; tell the caching allocator so their memory is
        # not handed out again before the consumer has finished with them.
        src_image1.record_stream(current)
        src_image2.record_stream(current)
        self.preload()
        return src_image1, src_image2

    def __len__(self):
        """Return ``len(loader)`` as measured at construction time."""
        return self.num_images
|
||
class SwappingDataset(data.Dataset):
    """Dataset that yields two images drawn from the same sub-directory.

    ``image_dir`` is expected to contain one level of sub-directories, each
    holding image files (presumably one sub-directory per identity -- confirm
    against the training data layout).  One dataset item corresponds to one
    sub-directory; every access picks two files from it at random (with
    replacement, so both may be the same file).

    Args:
        image_dir: root directory containing the sub-directories.
        img_transform: callable applied to each loaded PIL image.
        subffix: file extension to collect, with or without a leading dot.
        random_seed: seed passed to ``random.seed`` before shuffling the
            directory list.
    """

    def __init__(self,
                 image_dir,
                 img_transform,
                 subffix='jpg',
                 random_seed=1234):
        """Initialize and preprocess the Swapping dataset."""
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Collect the image files of every sub-directory into ``self.dataset``.

        Sub-directories without any matching file are skipped, because an
        empty entry could never be sampled from in ``__getitem__``.
        """
        print("processing Swapping dataset images...")

        # Honour the configured extension instead of a hard-coded '*.jpg'.
        pattern = '*.' + self.subffix.lstrip('.')
        self.dataset = []
        # glob order depends on the filesystem; sort so that the seeded
        # shuffle below yields the same order on every machine.
        for dir_item in sorted(glob.glob(os.path.join(self.image_dir, '*/'))):
            print("processing %s"%dir_item,end='\r')
            files = sorted(glob.glob(os.path.join(dir_item, pattern)))
            if files:
                self.dataset.append(files)
        # NOTE: this seeds the global ``random`` module state.
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the Swapping dataset, total dirs number: %d...'%len(self.dataset))

    def _load(self, filename):
        """Open ``filename``, force 3-channel RGB and apply the transform."""
        with Image.open(filename) as img:
            # Grayscale/CMYK/RGBA files would otherwise produce tensors
            # whose channel count differs from the rest of the batch.
            return self.img_transform(img.convert('RGB'))

    def __getitem__(self, index):
        """Return two transformed images picked at random from directory ``index``."""
        candidates = self.dataset[index]
        last = len(candidates) - 1

        filename1 = candidates[random.randint(0, last)]
        filename2 = candidates[random.randint(0, last)]
        image1 = self._load(filename1)
        image2 = self._load(filename2)
        return image1, image2

    def __len__(self):
        """Return the number of non-empty sub-directories."""
        return self.num_images
|
||
def GetLoader(  dataset_roots,
                batch_size=16,
                dataloader_workers=8,
                random_seed = 1234
                ):
    """Build the swapping dataset pipeline and return it wrapped in a prefetcher.

    Args:
        dataset_roots: root directory handed to ``SwappingDataset``.
        batch_size: number of items per batch.
        dataloader_workers: number of DataLoader worker processes.
        random_seed: seed forwarded to ``SwappingDataset``.

    Returns:
        A ``data_prefetcher`` wrapping a shuffled DataLoader that uses pinned
        memory and drops the last incomplete batch.
    """
    # The only per-image transform is the PIL -> tensor conversion.
    to_tensor = T.Compose([T.ToTensor()])

    swap_dataset = SwappingDataset(dataset_roots, to_tensor, "jpg", random_seed)
    loader = data.DataLoader(
        dataset=swap_dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=dataloader_workers,
        pin_memory=True,
    )
    return data_prefetcher(loader)
|
||
def denorm(x):
    """Map values from [-1, 1] to [0, 1] and clamp anything outside that range.

    The input is left untouched; the clamp is applied in place on the newly
    created result tensor only.
    """
    shifted = x + 1
    halved = shifted / 2
    return halved.clamp_(min=0, max=1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
""" | ||
Copyright (C) 2019 NVIDIA Corporation. All rights reserved. | ||
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class InstanceNorm(nn.Module):
    """Parameter-free normalisation over dimensions 2 and 3 of the input.

    Every (sample, channel) slice is shifted to zero mean and scaled by the
    reciprocal square root of its mean squared deviation plus ``epsilon``.
    """

    def __init__(self, epsilon=1e-8):
        """
            @notice: avoid in-place ops.
            https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Return ``x`` normalised over dims (2, 3); no in-place ops are used."""
        centered = x - x.mean(dim=(2, 3), keepdim=True)
        variance = (centered * centered).mean(dim=(2, 3), keepdim=True)
        inv_std = (variance + self.epsilon).rsqrt()
        return centered * inv_std
|
||
class ApplyStyle(nn.Module):
    """Modulate a feature map with a per-channel scale and shift derived from a latent.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        super().__init__()
        # One linear layer produces both the scale and the shift vector.
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        """Return ``x * (scale + 1) + shift`` with scale/shift predicted from ``latent``."""
        n_channels = x.size(1)
        # [batch_size, n_channels*2] -> [batch_size, 2, n_channels, 1, 1]
        style = self.linear(latent).view([-1, 2, n_channels, 1, 1])
        # The factors of 1 are numerical no-ops retained from the original
        # formulation (equivalent to x * (style[:, 0] + 1.) + style[:, 1]).
        scale = style[:, 0] * 1 + 1.
        shift = style[:, 1] * 1
        return x * scale + shift
|
||
class ResnetBlock_Adain(nn.Module):
    """Residual block whose two conv stages are each followed by latent-driven modulation.

    Each stage is padding -> 3x3 conv -> InstanceNorm -> ApplyStyle; the
    activation is applied after the first stage only, and the block input is
    added to the result of the second stage.

    Args:
        dim: number of input and output channels.
        latent_size: size of the latent vector fed to both ApplyStyle layers.
        padding_type: 'reflect', 'replicate' or 'zero'.
        activation: module applied between the two stages.

    Raises:
        NotImplementedError: if ``padding_type`` is not one of the above.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        super().__init__()

        # Creation order is kept (conv1, style1, conv2, style2) so that
        # parameter initialisation consumes the RNG in the same sequence.
        self.conv1 = self._conv_norm(dim, padding_type)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        self.conv2 = self._conv_norm(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    @staticmethod
    def _conv_norm(dim, padding_type):
        """Build the padding + 3x3 conv + InstanceNorm stack for one stage."""
        if padding_type == 'zero':
            # Zero padding is done by the convolution itself.
            layers, conv_padding = [], 1
        elif padding_type == 'reflect':
            layers, conv_padding = [nn.ReflectionPad2d(1)], 0
        elif padding_type == 'replicate':
            layers, conv_padding = [nn.ReplicationPad2d(1)], 0
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding))
        layers.append(InstanceNorm())
        return nn.Sequential(*layers)

    def forward(self, x, dlatents_in_slice):
        """Return ``x`` plus the modulated two-stage convolution of ``x``."""
        residual = self.style1(self.conv1(x), dlatents_in_slice)
        residual = self.act1(residual)
        residual = self.style2(self.conv2(residual), dlatents_in_slice)
        return x + residual
|
||
|
||
|
||
class Generator_Adain_Upsample(nn.Module):
    """Encoder / latent-modulated bottleneck / decoder image generator.

    The encoder is a 7x7 stem followed by three stride-2 convolutions (four
    when ``deep`` is set).  The bottleneck is ``n_blocks`` ResnetBlock_Adain
    modules with 512 channels, each receiving the same latent.  The decoder
    mirrors the encoder with bilinear x2 upsampling + 3x3 convolutions and a
    final 7x7 convolution; no activation is applied to the output.

    Args:
        input_nc: number of input image channels.
        output_nc: number of output image channels.
        latent_size: size of the latent vector passed to every bottleneck block.
        n_blocks: number of bottleneck blocks (must be >= 0).
        deep: add a fourth down/up-sampling stage.
        norm_layer: normalisation constructor used in the stem and encoder.
            NOTE(review): the decoder always uses ``nn.BatchNorm2d``
            regardless of this argument; left unchanged because existing
            checkpoints may depend on it.
        padding_type: padding mode forwarded to the bottleneck blocks.
    """

    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(Generator_Adain_Upsample, self).__init__()
        # A single in-place ReLU instance is shared by every stage.
        activation = nn.ReLU(True)
        self.deep = deep

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), activation)
        ### downsample
        self.down1 = self._down_block(64, 128, norm_layer, activation)
        self.down2 = self._down_block(128, 256, norm_layer, activation)
        self.down3 = self._down_block(256, 512, norm_layer, activation)
        if self.deep:
            self.down4 = self._down_block(512, 512, norm_layer, activation)

        ### resnet blocks
        # Stored in a Sequential for registration only; the blocks take two
        # arguments, so forward() iterates over them explicitly.
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)
        ])

        ### upsample
        if self.deep:
            self.up4 = self._up_block(512, 512, activation)
        self.up3 = self._up_block(512, 256, activation)
        self.up2 = self._up_block(256, 128, activation)
        self.up1 = self._up_block(128, 64, activation)
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))

    @staticmethod
    def _down_block(in_channels, out_channels, norm_layer, activation):
        """Stride-2 3x3 convolution + normalisation + activation."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            norm_layer(out_channels), activation
        )

    @staticmethod
    def _up_block(in_channels, out_channels, activation):
        """Bilinear x2 upsampling + 3x3 convolution + BatchNorm + activation."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels), activation
        )

    def forward(self, input, dlatents):
        """Generate an image from ``input`` conditioned on ``dlatents``.

        Args:
            input: image batch with ``input_nc`` channels (the original
                author's note gives 3*224*224 per image).  The output has the
                same spatial size only when height and width are divisible by
                8 (16 when ``deep``), since each stride-2 stage rounds up and
                each upsampling stage exactly doubles.
            dlatents: latent batch passed unchanged to every bottleneck block.

        Returns:
            The raw output of the last convolution, with ``output_nc`` channels.
        """
        x = self.first_layer(input)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        if self.deep:
            x = self.down4(x)

        # Intermediate activations are deliberately not collected: nothing
        # consumes them, and holding them would only keep memory alive.
        for block in self.BottleNeck:
            x = block(x, dlatents)

        if self.deep:
            x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        return self.last_layer(x)
Oops, something went wrong.