interpolate.py (forked from phizaz/diffae)
"""Interpolate between two images with a diffusion autoencoder (DiffAE)."""
# # (0). Imports
from templates import *
import torch  # explicit imports for names also pulled in by the wildcard above
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import argparse
import os

def parse_args():
    parser = argparse.ArgumentParser(description=globals()["__doc__"])
    parser.add_argument("--Te", type=int, default=250, help="Encoder Time Steps")
    parser.add_argument("--Tr", type=int, default=20, help="Render Time Steps")
    args = parser.parse_args()
    return args
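
# Example invocation (assumes the store/ directory layout used by the paths below):
#   python interpolate.py --Te 250 --Tr 20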

def main():
    args = parse_args()
    # # (1). Directory and device
    dir_pre = 'store/models/diffae/'
    dir_figs = 'store/output/diffae/interpolate/'
    os.makedirs(dir_figs, exist_ok=True)
    # device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
    device = 'cpu'
    print(f'Using device: {device}')
    if device == 'cuda':
        os.system('nvidia-smi')
    # # (2). Setup and load in model
    conf = ffhq256_autoenc()
    # print(conf.name)
    model = LitModel(conf)
    state = torch.load(f'{dir_pre}checkpoints/{conf.name}/last.ckpt', map_location='cpu')
    model.load_state_dict(state['state_dict'], strict=False)
    model.ema_model.eval()
    model.ema_model.to(device)
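    # Inference uses the EMA copy of the weights; strict=False lets load_state_dict
    # tolerate keys that differ between the Lightning checkpoint and the model.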
    # # (3). Set up data
    data = ImageDataset('imgs_interpolate', image_size=conf.img_size, exts=['jpg', 'JPG', 'png'], do_augment=False)
    batch = torch.stack([
        data[0]['img'],
        data[1]['img'],
    ])
    # plt.imshow(batch[0].permute([1, 2, 0]) / 2 + 0.5)
    # plt.show()
    # import IPython ; IPython.embed()
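    # ImageDataset returns tensors normalized to [-1, 1], which is why the preview
    # above and the plots below rescale with x / 2 + 0.5 (equivalently (x + 1) / 2).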
    # # (4). Encode
    cond = model.encode(batch.to(device))
    # Te was 250 originally; pass a smaller --Te to run faster on CPU.
    xT = model.encode_stochastic(batch.to(device), cond, T=args.Te)
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ori = (batch + 1) / 2
    ax[0].imshow(ori[0].permute(1, 2, 0).cpu())
    ax[1].imshow(xT[0].permute(1, 2, 0).cpu())
    plt.savefig(f'{dir_figs}encoding_Te{args.Te}.png')
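    # cond is the semantic latent (one vector per image) from the DiffAE encoder;
    # encode_stochastic runs the deterministic DDIM forward pass for Te steps to
    # recover xT, the noise-like stochastic code visualized in the right panel.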
    # # (5). Interpolate
    #
    # Semantic codes are interpolated using convex combination, while stochastic
    # codes are interpolated using spherical linear interpolation (slerp).
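    # For reference, the slerp implemented below is
    #   slerp(x0, x1, a) = (sin((1 - a) * theta) * x0 + sin(a * theta) * x1) / sin(theta),
    # with theta the angle between the flattened codes; this is the usual choice for
    # Gaussian noise, since a plain linear mix would shrink the norm of the codes.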
    alpha = torch.tensor(np.linspace(0, 1, 10, dtype=np.float32)).to(cond.device)
    intp = cond[0][None] * (1 - alpha[:, None]) + cond[1][None] * alpha[:, None]
    # import IPython; IPython.embed()

    def cos(a, b):
        a = a.view(-1)
        b = b.view(-1)
        a = F.normalize(a, dim=0)
        b = F.normalize(b, dim=0)
        return (a * b).sum()

    theta = torch.arccos(cos(xT[0], xT[1]))
    x_shape = xT[0].shape
    intp_x = (torch.sin((1 - alpha[:, None]) * theta) * xT[0].flatten(0, 2)[None] +
              torch.sin(alpha[:, None] * theta) * xT[1].flatten(0, 2)[None]) / torch.sin(theta)
    intp_x = intp_x.view(-1, *x_shape)
    # Tr was 20 originally; pass a smaller --Tr for a speedup on CPU.
    pred = model.render(intp_x, intp, T=args.Tr)
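    # render decodes each (stochastic code, semantic code) pair back to an image
    # with Tr DDIM steps; the outputs are already in [0, 1], so they are plotted directly.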
    # # (6). Plot interpolation results
    # torch.manual_seed(1)
    fig, ax = plt.subplots(1, 10, figsize=(5 * 10, 5))
    for i in range(len(alpha)):
        ax[i].imshow(pred[i].permute(1, 2, 0).cpu())
    plt.savefig(f'{dir_figs}interpolate_Te{args.Te}_Tr{args.Tr}.png')

if __name__ == "__main__":
    main()