Skip to content

Commit

Permalink
training scripts released
Browse files Browse the repository at this point in the history
  • Loading branch information
neuralchen committed Apr 20, 2022
1 parent 9492873 commit f48dc8c
Show file tree
Hide file tree
Showing 16 changed files with 1,688 additions and 3 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,7 @@ checkpoints/
*.zip
*.avi
*.pdf
*.pptx
*.pptx

*.pth
*.onnx
127 changes: 127 additions & 0 deletions data/data_loader_Swapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import os
import glob
import torch
import random
from PIL import Image
from pathlib import Path
from torch.utils import data
from torchvision import transforms as T
# from StyleResize import StyleResize

class data_prefetcher():
    """Wrap a loader and stage the next pair of batches on the GPU ahead of use.

    Each item pulled from ``loader`` is unpacked into exactly two tensors.
    Both are copied to the GPU and normalised in place on a dedicated CUDA
    stream, so the work for the next pair is issued before the consumer asks
    for it. The wrapped loader is restarted transparently when it runs out,
    which means ``next()`` never raises ``StopIteration``.
    """

    def __init__(self, loader):
        """Create the side stream and normalisation constants, then stage the first pair."""
        self.loader = loader
        self.dataiter = iter(loader)
        self.stream = torch.cuda.Stream()
        # Per-channel constants shaped (1, 3, 1, 1); presumably the usual
        # ImageNet statistics, broadcast over (N, 3, H, W) batches -- TODO confirm.
        # No manual half-precision conversion is done here (the original
        # author notes it is unnecessary when running with Amp).
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1,3,1,1)
        self.num_images = len(loader)
        self.preload()

    def _stage(self, batch):
        """Copy ``batch`` to the GPU (non-blocking) and normalise it in place."""
        on_gpu = batch.cuda(non_blocking=True)
        return on_gpu.sub_(self.mean).div_(self.std)

    def preload(self):
        """Fetch the next pair from the loader and issue its GPU staging."""
        try:
            first, second = next(self.dataiter)
        except StopIteration:
            # Loader exhausted: start a fresh pass over it.
            self.dataiter = iter(self.loader)
            first, second = next(self.dataiter)

        # Everything issued inside this context is queued on the side stream.
        with torch.cuda.stream(self.stream):
            self.src_image1 = self._stage(first)
            self.src_image2 = self._stage(second)

    def next(self):
        """Return the staged pair and immediately begin staging the following one."""
        # The consumer's stream must not touch the tensors before the side
        # stream has finished producing them.
        torch.cuda.current_stream().wait_stream(self.stream)
        ready = (self.src_image1, self.src_image2)
        self.preload()
        return ready

    def __len__(self):
        """Return ``len(loader)`` captured at construction (batches for a DataLoader)."""
        return self.num_images

class SwappingDataset(data.Dataset):
    """Dataset of image folders that yields two random images from one folder.

    ``image_dir`` is expected to contain one sub-directory per group of
    images (presumably one per identity -- confirm against the data layout).
    Every sub-directory holding at least one ``*.<subffix>`` file becomes one
    dataset entry; ``dataset[i]`` is the list of image paths in that folder.
    """

    def __init__(self,
                 image_dir,
                 img_transform,
                 subffix='jpg',
                 random_seed=1234):
        """Index the image folders under ``image_dir``.

        Args:
            image_dir: root directory whose immediate sub-directories are scanned.
            img_transform: callable applied to each opened PIL image.
            subffix: file extension to collect, with or without a leading dot.
            random_seed: seed used to shuffle the folder order.
        """
        self.image_dir = image_dir
        self.img_transform = img_transform
        self.subffix = subffix
        self.dataset = []
        self.random_seed = random_seed
        self.preprocess()
        self.num_images = len(self.dataset)

    def preprocess(self):
        """Collect the image paths of every non-empty sub-directory and shuffle them."""
        print("processing Swapping dataset images...")

        # Honour the configured extension instead of a hard-coded '*.jpg'.
        pattern = '*.' + self.subffix.lstrip('.')
        self.dataset = []
        for dir_item in glob.glob(os.path.join(self.image_dir, '*/')):
            print("processing %s"%dir_item,end='\r')
            image_paths = glob.glob(os.path.join(dir_item, pattern))
            # A folder without matching images can never be sampled from
            # (__getitem__ needs at least one path), so leave it out.
            if image_paths:
                self.dataset.append(image_paths)
        # Seeds the global `random` module, exactly as before, so the folder
        # order (and later draws from `random`) stay reproducible.
        random.seed(self.random_seed)
        random.shuffle(self.dataset)
        print('Finished preprocessing the Swapping dataset, total dirs number: %d...'%len(self.dataset))

    def __getitem__(self, index):
        """Return two transformed images drawn at random from folder ``index``.

        The two draws are independent, so both may be the same file.
        """
        folder_images = self.dataset[index]
        filename1 = random.choice(folder_images)
        filename2 = random.choice(folder_images)
        image1 = self.img_transform(Image.open(filename1))
        image2 = self.img_transform(Image.open(filename2))
        return image1, image2

    def __len__(self):
        """Return the number of indexed folders (not individual images)."""
        return self.num_images

def GetLoader(dataset_roots,
              batch_size=16,
              dataloader_workers=8,
              random_seed=1234):
    """Build the swapping dataset pipeline and return it wrapped in a prefetcher.

    Args:
        dataset_roots: root directory handed to ``SwappingDataset``.
        batch_size: number of image pairs per batch.
        dataloader_workers: worker process count for the ``DataLoader``.
        random_seed: seed forwarded to ``SwappingDataset``.

    Returns:
        A ``data_prefetcher`` wrapping a shuffled ``DataLoader`` that drops the
        last incomplete batch and uses pinned memory.
    """
    # The only preprocessing step is the PIL -> tensor conversion.
    to_tensor = T.Compose([T.ToTensor()])

    swap_dataset = SwappingDataset(dataset_roots, to_tensor, "jpg", random_seed)
    loader = data.DataLoader(
        dataset=swap_dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
        num_workers=dataloader_workers,
        pin_memory=True,
    )
    return data_prefetcher(loader)

def denorm(x):
    """Rescale ``x`` with ``(x + 1) / 2`` and clamp the result to [0, 1].

    The input tensor itself is left untouched; the clamp is applied in place
    on the freshly computed result.
    """
    rescaled = (x + 1) / 2
    rescaled.clamp_(0, 1)
    return rescaled
54 changes: 54 additions & 0 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,19 @@ def get_current_errors(self):

def save(self, label):
    """Do nothing; placeholder save hook (presumably overridden by subclasses)."""

# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids=None):
    """Write ``network``'s weights to ``<save_dir>/<epoch_label>_net_<network_label>.pth``.

    The checkpoint always contains CPU tensors. The live module is not moved:
    it stays on whatever device it was on before the call, instead of being
    shuttled to the CPU and then onto the current default CUDA device.

    Args:
        network: module exposing ``state_dict()``.
        network_label: short name of the network, used in the file name.
        epoch_label: epoch tag (or e.g. ``'latest'``), used in the file name.
        gpu_ids: unused; kept for signature compatibility with callers.
    """
    save_filename = '{}_net_{}.pth'.format(epoch_label, network_label)
    save_path = os.path.join(self.save_dir, save_filename)

    state = network.state_dict()
    # Copy tensors to the CPU entry by entry; non-tensor entries pass through.
    cpu_state = type(state)(
        (key, value.cpu() if torch.is_tensor(value) else value)
        for key, value in state.items()
    )
    # Carry over the versioning metadata that state_dict() attaches, if any.
    metadata = getattr(state, '_metadata', None)
    if metadata is not None:
        cpu_state._metadata = metadata
    torch.save(cpu_state, save_path)

def save_optim(self, network, network_label, epoch_label, gpu_ids=None):
    """Write ``network.state_dict()`` to ``<save_dir>/<epoch_label>_optim_<network_label>.pth``.

    Despite its name, ``network`` is presumably an optimizer here (the file
    name carries ``optim``); any object exposing ``state_dict()`` works.
    ``gpu_ids`` is accepted but not used.
    """
    target = os.path.join(
        self.save_dir,
        '{}_optim_{}.pth'.format(epoch_label, network_label),
    )
    torch.save(network.state_dict(), target)

# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
Expand All @@ -63,6 +76,47 @@ def load_network(self, network, network_label, epoch_label, save_dir=''):
except:
pretrained_dict = torch.load(save_path)
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
if self.opt.verbose:
print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v

if sys.version_info >= (3,0):
not_initialized = set()
else:
from sets import Set
not_initialized = Set()

for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
not_initialized.add(k.split('.')[0])

print(sorted(not_initialized))
network.load_state_dict(model_dict)

# helper loading function that can be used by subclasses
def load_optim(self, network, network_label, epoch_label, save_dir=''):
save_filename = '%s_optim_%s.pth' % (epoch_label, network_label)
if not save_dir:
save_dir = self.save_dir
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
print('%s not exists yet!' % save_path)
if network_label == 'G':
raise('Generator must exist!')
else:
#network.load_state_dict(torch.load(save_path))
try:
network.load_state_dict(torch.load(save_path, map_location=torch.device("cpu")))
except:
pretrained_dict = torch.load(save_path, map_location=torch.device("cpu"))
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
Expand Down
169 changes: 169 additions & 0 deletions models/fs_networks_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn


class InstanceNorm(nn.Module):
    """Parameter-free normalisation over dimensions 2 and 3 of the input."""

    def __init__(self, epsilon=1e-8):
        """Store ``epsilon``, the constant added to the variance before ``rsqrt``.

        @notice: avoid in-place ops.
        https://discuss.pytorch.org/t/encounter-the-runtimeerror-one-of-the-variables-needed-for-gradient-computation-has-been-modified-by-an-inplace-operation/836/3
        """
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Subtract the mean over dims (2, 3) and scale by the inverse std over the same dims."""
        reduce_dims = (2, 3)
        centered = x - torch.mean(x, reduce_dims, True)
        variance = torch.mean(torch.mul(centered, centered), reduce_dims, True)
        inv_std = torch.rsqrt(variance + self.epsilon)
        return centered * inv_std

class ApplyStyle(nn.Module):
    """Modulate feature maps with a per-channel scale and shift predicted from a latent.

    @ref: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
    """

    def __init__(self, latent_size, channels):
        """Create the linear layer mapping ``latent_size`` -> ``2 * channels``."""
        super().__init__()
        self.linear = nn.Linear(latent_size, channels * 2)

    def forward(self, x, latent):
        """Return ``x * (scale + 1) + shift`` with scale/shift computed from ``latent``."""
        # [batch_size, n_channels*2] -> [batch_size, 2, n_channels, 1, 1]
        n_channels = x.size(1)
        style = self.linear(latent).view([-1, 2, n_channels, 1, 1])
        # The explicit "* 1" factors are kept exactly as in the original
        # formulation (the plain form was deliberately commented out there).
        scale = style[:, 0] * 1 + 1.
        shift = style[:, 1] * 1
        return x * scale + shift

class ResnetBlock_Adain(nn.Module):
    """Residual block of two 3x3 convolutions, each followed by latent-driven styling.

    Layout: conv1 -> style1 -> activation -> conv2 -> style2, added to the input.
    """

    def __init__(self, dim, latent_size, padding_type, activation=nn.ReLU(True)):
        """Build both conv/norm stacks and their style layers.

        Args:
            dim: channel count of the input and of every layer in the block.
            latent_size: size of the latent vector fed to the style layers.
            padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
            activation: module applied between the two halves. The default
                instance is created once at definition time and is therefore
                shared by every block that does not pass its own.

        Raises:
            NotImplementedError: for any other ``padding_type``.
        """
        super().__init__()

        # Registration order (conv1, style1, act1, conv2, style2) is kept so
        # parameter ordering and initialisation order are unchanged.
        self.conv1 = self._padded_conv(dim, padding_type)
        self.style1 = ApplyStyle(latent_size, dim)
        self.act1 = activation

        self.conv2 = self._padded_conv(dim, padding_type)
        self.style2 = ApplyStyle(latent_size, dim)

    @staticmethod
    def _padded_conv(dim, padding_type):
        """Return ``[explicit pad layer] + Conv2d(3x3) + InstanceNorm`` as a Sequential."""
        layers = []
        conv_padding = 0
        if padding_type == 'reflect':
            layers.append(nn.ReflectionPad2d(1))
        elif padding_type == 'replicate':
            layers.append(nn.ReplicationPad2d(1))
        elif padding_type == 'zero':
            # Zero padding is done by the convolution itself.
            conv_padding = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_padding))
        layers.append(InstanceNorm())
        return nn.Sequential(*layers)

    def forward(self, x, dlatents_in_slice):
        """Run the styled residual branch and add it to ``x``."""
        residual = self.conv1(x)
        residual = self.style1(residual, dlatents_in_slice)
        residual = self.act1(residual)
        residual = self.conv2(residual)
        residual = self.style2(residual, dlatents_in_slice)
        return x + residual



class Generator_Adain_Upsample(nn.Module):
    """Encoder / styled-bottleneck / decoder generator.

    The encoder is a 7x7 stem followed by three stride-2 convolutions (four
    when ``deep`` is set). The bottleneck is ``n_blocks`` ``ResnetBlock_Adain``
    blocks, each conditioned on the same latent. The decoder mirrors the
    encoder with bilinear x2 upsampling + 3x3 convolutions and ends in a 7x7
    convolution with no activation.
    """

    def __init__(self, input_nc, output_nc, latent_size, n_blocks=6, deep=False,
                 norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        """Build all layers.

        Args:
            input_nc: number of input channels.
            output_nc: number of output channels.
            latent_size: size of the latent vector fed to the bottleneck blocks.
            n_blocks: number of bottleneck blocks (must be >= 0).
            deep: add a fourth down/up-sampling stage.
            norm_layer: normalisation constructor for the stem and the
                down-sampling stages. The up-sampling stages always use
                ``nn.BatchNorm2d``, as in the original code.
            padding_type: padding mode forwarded to the bottleneck blocks.
        """
        assert (n_blocks >= 0)
        super(Generator_Adain_Upsample, self).__init__()
        # One in-place ReLU instance is shared by every stage.
        activation = nn.ReLU(True)
        self.deep = deep

        def down(in_ch, out_ch):
            # Stride-2 3x3 convolution: halves the spatial size.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                norm_layer(out_ch), activation)

        def up(in_ch, out_ch):
            # Bilinear x2 upsampling followed by a 3x3 convolution.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False),
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_ch), activation)

        self.first_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
                                         norm_layer(64), activation)
        ### downsample
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        if self.deep:
            self.down4 = down(512, 512)

        ### resnet blocks
        # Held in a Sequential only as a container: the blocks take two
        # arguments, so forward() calls them one by one.
        self.BottleNeck = nn.Sequential(*[
            ResnetBlock_Adain(512, latent_size=latent_size, padding_type=padding_type, activation=activation)
            for _ in range(n_blocks)])

        ### upsample
        if self.deep:
            self.up4 = up(512, 512)
        self.up3 = up(512, 256)
        self.up2 = up(256, 128)
        self.up1 = up(128, 64)
        self.last_layer = nn.Sequential(nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, kernel_size=7, padding=0))

    def forward(self, input, dlatents):
        """Generate an image from ``input`` conditioned on ``dlatents``.

        Args:
            input: image batch (original comment suggests 3x224x224 per
                sample -- TODO confirm).
            dlatents: latent passed unchanged to every bottleneck block.

        Returns:
            The raw output of the final convolution.
        """
        # Intermediate activations are not collected: nothing downstream used
        # them, and keeping them only pinned memory until return.
        x = self.first_layer(input)
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        if self.deep:
            x = self.down4(x)

        for block in self.BottleNeck:
            x = block(x, dlatents)

        if self.deep:
            x = self.up4(x)
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        return self.last_layer(x)
Loading

0 comments on commit f48dc8c

Please sign in to comment.