pp_gen.py
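# Module summary (a brief, hedged description of the pipeline implemented below):
"""Generate the dataset used to train HairFast's post-processing (blending) step.

HairFast is patched so that each swap returns the blended image *before* the
final post-processing, those pre-PP results are rendered for random
(face, shape, color) triplets from FFHQ, and the batches are packed, together
with masks produced by scripts.pp_train.Trainer, into ``pp_part_*.dataset``
files under --output.
"""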
import argparse
import os
import sys
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.utils import save_image
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from hair_swap import get_parser, HairFast
from scripts.pp_train import Trainer
from utils.train import seed_everything
from utils.image_utils import list_image_files
class ImageException(Exception):
    """Carrier exception used to short-circuit HairFast and return the pre-PP image."""

    def __init__(self, image, message="Return image before PP"):
        self.image = image
        self.message = message
        super().__init__(self.message)
def hairfast_wo_pp(hair_fast):
    """Patch a HairFast instance so that swaps return the image before post-processing."""

    class RaiseDownsample(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, image):
            # Un-normalize from [-1, 1] to [0, 1] and abort the pipeline,
            # smuggling the pre-PP image out through the exception.
            image = ((image[0] + 1) / 2).clip(0, 1)
            raise ImageException(image)

    def blend_images(func):
        def wrapper(*args, **kwargs):
            try:
                func(*args, **kwargs)
            except ImageException as e:
                # Catch the carrier exception and return its image as the result.
                return e.image

        return wrapper

    hair_fast.blend.downsample_256 = RaiseDownsample()
    hair_fast.blend.blend_images = blend_images(hair_fast.blend.blend_images)
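# Usage sketch (an assumption drawn from main() below, not separate API documentation):
# once hairfast_wo_pp(hair_fast) has been applied, a call such as
#     blended = hair_fast(face_path, shape_path, color_path)
# yields the blend *before* post-processing as a float tensor in [0, 1], because
# RaiseDownsample.forward packs exactly that tensor into the ImageException.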
def load_image(path):
    return T.functional.to_tensor(Image.open(path))
@torch.no_grad()
def load_dataset_images(imgs, dataset_path):
    net_trainer = Trainer()

    # Load the FFHQ sources and the generated pre-PP targets in parallel.
    source_images = Parallel(n_jobs=-1)(delayed(load_image)(os.path.join(args.FFHQ, img[0])) for img in tqdm(imgs))
    target_images = Parallel(n_jobs=-1)(delayed(load_image)(os.path.join(dataset_path, img[1])) for img in tqdm(imgs))

    tensors_dataloader = DataLoader(list(zip(source_images, target_images)), batch_size=64, pin_memory=False,
                                    shuffle=False, drop_last=False)
    source_files = [os.path.join(args.FFHQ, img[0]) for img in imgs]

    data = []
    total = 0
    for batch in tqdm(tensors_dataloader):
        source, target = [elem.to('cuda') for elem in batch]

        HS_D, _ = net_trainer.generate_mask(source)
        HT_D, HT_E = net_trainer.generate_mask(target)

        # Keep only the region not covered by either predicted mask.
        target_mask = (1 - HS_D) * (1 - HT_D)

        data.extend(list(zip(
            [source_files[total + i] for i in range(len(source))],
            net_trainer.downsample_256(target).clip(0, 1).cpu(),
            target_mask.cpu(),
            HT_E.cpu(),
        )))
        total += len(source)

    return data
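# Each entry returned above (and later written via torch.save) is a tuple of
# (source image path, downsampled target clipped to [0, 1], combined mask, HT_E),
# with all tensors moved to the CPU.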
def main(args):
    seed_everything(args.seed)

    # init HairFast
    model_parser = get_parser()
    model_args = model_parser.parse_args([])
    hair_fast = HairFast(model_args)
    hairfast_wo_pp(hair_fast)

    # generate dataset
    os.makedirs(args.output, exist_ok=True)
    images = list_image_files(args.FFHQ)

    # Sample --size random (face, shape, color) triplets from FFHQ.
    face, shape, color = np.array_split(np.random.choice(images, size=3 * args.size), 3)

    exps = []
    for exp in zip(face, shape, color):
        imgs = map(lambda im: im.split('.')[0], exp)
        exps.append([exp[0], f"{'_'.join(imgs)}.png", exp])

    # Process the experiments in chunks and save each chunk as a separate part file.
    batch = 5_000
    left, right, idx = 0, min(len(exps), batch), 1
    while left < len(exps):
        with tempfile.TemporaryDirectory() as temp_dir:
            for exp in tqdm(exps[left:right]):
                im1, im2, im3 = exp[-1]
                image = hair_fast(args.FFHQ / im1, args.FFHQ / im2, args.FFHQ / im3)
                save_image(image, os.path.join(temp_dir, exp[1]))

            batch_data = load_dataset_images(exps[left:right], temp_dir)
            torch.save(batch_data, os.path.join(args.output, f'pp_part_{idx}.dataset'))

        left = right
        right = min(len(exps), right + batch)
        idx += 1
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Blending dataset')
    parser.add_argument('--FFHQ', type=Path)
    parser.add_argument('--seed', type=int, default=3407)
    parser.add_argument('--size', type=int, default=10_000)
    parser.add_argument('--output', type=Path, default='input/pp_dataset')

    args = parser.parse_args()
    main(args)
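# Example invocation (a sketch; the script location and FFHQ path are assumptions,
# not documented in this file):
#   python scripts/pp_gen.py --FFHQ /path/to/FFHQ --size 10000 --output input/pp_dataset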