
Commit 68dab2a

latest stuff added
1 parent fe9410b

File tree

8 files changed: +349 additions, -179 deletions


samples/.gitignore

Lines changed: 7 additions & 0 deletions
@@ -8,3 +8,10 @@ GAN_GEN_SHADOW_8.pth
 interpolation.mp4
 video_2.gif
 video_3.gif
+
+# ignore the new latent_space interpolation video
+new_interp.mp4
+
+frames_pro/
+frames_mine/
+M_GAN_GEN_SHADOW_8.pth

samples/celebA-HQ.png

-34.9 MB
Binary file not shown.

samples/demo.py

Lines changed: 105 additions & 44 deletions
@@ -1,63 +1,124 @@
+""" live (realtime) latent space interpolations of trained models """
+
+import argparse
 import torch as th
+import numpy as np
 import matplotlib.pyplot as plt
 from matplotlib.animation import FuncAnimation
-from pro_gan_pytorch import PRO_GAN as pg
-
-# ==========================================================================
-# Tweakable parameters
-# ==========================================================================
-depth = 8
-num_points = 12
-transition_points = 30
-# ==========================================================================
+from pro_gan_pytorch.PRO_GAN import Generator
+from torchvision.utils import make_grid
+from math import ceil, sqrt
+from scipy.ndimage import gaussian_filter

 # create the device for running the demo:
 device = th.device("cuda" if th.cuda.is_available() else "cpu")

-# load the model for the demo
-gen = th.nn.DataParallel(pg.Generator(depth=9))
-gen.load_state_dict(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device)))

+def parse_arguments():
+    """
+    command line arguments parser
+    :return: args => parsed command line arguments
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--generator_file", action="store", type=str,
+                        default=None, help="path to the trained generator model")
+
+    parser.add_argument("--depth", action="store", type=int,
+                        default=9, help="Depth of the network")
+
+    parser.add_argument("--latent_size", action="store", type=int,
+                        default=512, help="Latent size for the network")
+
+    parser.add_argument("--num_points", action="store", type=int,
+                        default=12, help="Number of samples to be seen")
+
+    parser.add_argument("--transition_points", action="store", type=int,
+                        default=30,
+                        help="Number of transition samples for interpolation")
+
+    parser.add_argument("--smoothing", action="store", type=float,
+                        default=1.0,
+                        help="amount of transitional smoothing")
+
+    args = parser.parse_args()
+
+    return args
+
+
+def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
+    """
+    adjust the dynamic colour range of the given input data
+    :param data: input image data
+    :param drange_in: original range of input
+    :param drange_out: required range of output
+    :return: img => colour range adjusted images
+    """
+    if drange_in != drange_out:
+        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
+                np.float32(drange_in[1]) - np.float32(drange_in[0]))
+        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
+        data = data * scale + bias
+    return th.clamp(data, min=0, max=1)
+
+
+def get_image(gen, point, depth, alpha):
+    """
+    obtain an All-resolution grid of images from the given point
+    :param gen: the generator object
+    :param point: random latent point for generation
+    :param depth: value of depth for image generation (0 indexed)
+    :param alpha: value of alpha for fade-in (between 0 and 1)
+    :return: img => generated image
+    """
+    image = gen(point, depth, alpha).detach()
+    image = adjust_dynamic_range(image).squeeze(dim=0)
+    return image.cpu().numpy().transpose(1, 2, 0)

-# function to generate an image given a latent_point
-def get_image(point):
-    img = gen(point, depth=depth, alpha=1).detach().squeeze(0).permute(1, 2, 0)
-    img = (img - img.min()) / (img.max() - img.min())
-    return img.cpu().numpy()

+def main(args):
+    """
+    Main function for the script
+    :param args: parsed command line arguments
+    :return: None
+    """

-# generate the set of points:
-fixed_points = th.randn(num_points, 512).to(device)
-fixed_points = (fixed_points / fixed_points.norm(dim=1, keepdim=True)) * (512 ** 0.5)
-points = []  # start with an empty list
-for i in range(len(fixed_points) - 1):
-    pt_1 = fixed_points[i].view(1, -1)
-    pt_2 = fixed_points[i + 1].view(1, -1)
-    direction = pt_2 - pt_1
-    for j in range(transition_points):
-        pt = pt_1 + ((direction / transition_points) * j)
-        pt = (pt / pt.norm()) * (512 ** 0.5)
-        points.append(pt)
-    # also append the final point:
-    points.append(pt_2)
+    # load the model for the demo
+    gen = th.nn.DataParallel(
+        Generator(
+            depth=args.depth,
+            latent_size=args.latent_size))
+    gen.load_state_dict(th.load(args.generator_file, map_location=str(device)))

-start_point = points[0]
-points = points[1:]
+    # generate the set of points:
+    total_frames = args.num_points * args.transition_points
+    all_latents = th.randn(total_frames, args.latent_size).to(device)
+    all_latents = th.from_numpy(
+        gaussian_filter(
+            all_latents.cpu(),
+            [args.smoothing * args.transition_points, 0], mode="wrap"))
+    all_latents = (all_latents /
+                   all_latents.norm(dim=-1, keepdim=True)) * sqrt(args.latent_size)

-fig, ax = plt.subplots()
-plt.axis("off")
-shower = plt.imshow(get_image(start_point))
+    start_point = th.unsqueeze(all_latents[0], dim=0)
+    points = all_latents[1:]

+    fig, ax = plt.subplots()
+    plt.axis("off")
+    shower = plt.imshow(get_image(gen, start_point, args.depth - 1, 1))

-def init():
-    return shower,
+    def init():
+        return shower,

+    def update(point):
+        shower.set_data(get_image(gen, th.unsqueeze(point, dim=0), args.depth - 1, 1))
+        return shower,

-def update(point):
-    shower.set_data(get_image(point))
-    return shower,
+    # define the animation function
+    ani = FuncAnimation(fig, update, frames=points,
+                        init_func=init)
+    plt.show(ani)


-ani = FuncAnimation(fig, update, frames=points,
-                    init_func=init, blit=False)
-plt.show()
+if __name__ == '__main__':
+    main(parse_arguments())
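The rewritten demo.py drops the old point-to-point linear interpolation and instead draws one latent per frame from a Gaussian-smoothed random walk on the latent hypersphere. A minimal standalone sketch of that smoothing step (not part of the commit; the values mirror the script's defaults, and numpy stands in for torch for brevity):

# Sketch of the latent-path smoothing used in the new demo.py (illustrative only).
import numpy as np
from scipy.ndimage import gaussian_filter

latent_size, num_points, transition_points, smoothing = 512, 12, 30, 1.0
total_frames = num_points * transition_points  # one latent vector per animation frame

latents = np.random.randn(total_frames, latent_size).astype(np.float32)
# sigma is given per axis: smooth along the frame axis only; mode="wrap" makes the path loop
latents = gaussian_filter(latents, [smoothing * transition_points, 0], mode="wrap")
# re-project every frame's latent onto the hypersphere of radius sqrt(latent_size)
latents = latents / np.linalg.norm(latents, axis=-1, keepdims=True) * np.sqrt(latent_size)
print(latents.shape)  # (360, 512)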

samples/faces_sheet_1.png

70.1 MB

samples/faces_sheet_2.png

70.7 MB

samples/generate_samples.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
+""" Generate single image samples from a particular depth of a model """
+
+import argparse
+import torch as th
+import numpy as np
+import os
+from torch.backends import cudnn
+from pro_gan_pytorch.PRO_GAN import Generator
+from torch.nn.functional import interpolate
+from scipy.misc import imsave
+from tqdm import tqdm
+
+# turn on the fast GPU processing mode on
+cudnn.benchmark = True
+
+
+# set the manual seed
+# th.manual_seed(3)
+
+
+def parse_arguments():
+    """
+    default command line argument parser
+    :return: args => parsed command line arguments
+    """
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--generator_file", action="store", type=str,
+                        help="pretrained weights file for generator", required=True)
+
+    parser.add_argument("--latent_size", action="store", type=int,
+                        default=256,
+                        help="latent size for the generator")
+
+    parser.add_argument("--depth", action="store", type=int,
+                        default=9,
+                        help="depth of the network. **Starts from 1")
+
+    parser.add_argument("--out_depth", action="store", type=int,
+                        default=6,
+                        help="output depth of images. **Starts from 0")
+
+    parser.add_argument("--num_samples", action="store", type=int,
+                        default=300,
+                        help="number of synchronized grids to be generated")
+
+    parser.add_argument("--out_dir", action="store", type=str,
+                        default="interp_animation_frames/",
+                        help="path to the output directory for the frames")
+
+    args = parser.parse_args()
+
+    return args
+
+
+def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
+    """
+    adjust the dynamic colour range of the given input data
+    :param data: input image data
+    :param drange_in: original range of input
+    :param drange_out: required range of output
+    :return: img => colour range adjusted images
+    """
+    if drange_in != drange_out:
+        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
+                np.float32(drange_in[1]) - np.float32(drange_in[0]))
+        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
+        data = data * scale + bias
+    return th.clamp(data, min=0, max=1)
+
+
+def main(args):
+    """
+    Main function for the script
+    :param args: parsed command line arguments
+    :return: None
+    """
+
+    print("Creating generator object ...")
+    # create the generator object
+    gen = th.nn.DataParallel(Generator(
+        depth=args.depth,
+        latent_size=args.latent_size
+    ))
+
+    print("Loading the generator weights from:", args.generator_file)
+    # load the weights into it
+    gen.load_state_dict(
+        th.load(args.generator_file)
+    )
+
+    # path for saving the files:
+    save_path = args.out_dir
+
+    print("Generating scale synchronized images ...")
+    for img_num in tqdm(range(1, args.num_samples + 1)):
+        # generate the images:
+        with th.no_grad():
+            point = th.randn(1, args.latent_size)
+            point = (point / point.norm()) * (args.latent_size ** 0.5)
+            ss_image = gen(point, depth=args.out_depth, alpha=1)
+        # color adjust the generated image:
+        ss_image = adjust_dynamic_range(ss_image)
+
+        # save the ss_image in the directory
+        imsave(os.path.join(save_path, str(img_num) + ".png"),
+               ss_image.squeeze(0).permute(1, 2, 0).cpu())
+
+    print("Generated %d images at %s" % (args.num_samples, save_path))
+
+
+if __name__ == '__main__':
+    main(parse_arguments())
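A possible way to drive the new script (not part of the commit): the flags correspond to parse_arguments above, the checkpoint path and latent size are placeholders for whatever model is actually being sampled, and the output directory is created up front because imsave will not create it.

# Usage sketch for samples/generate_samples.py (paths and sizes are placeholders).
import os
import subprocess

out_dir = "interp_animation_frames/"   # same as the --out_dir default above
os.makedirs(out_dir, exist_ok=True)    # imsave expects the directory to exist

subprocess.run([
    "python", "samples/generate_samples.py",
    "--generator_file", "GAN_GEN_SHADOW_8.pth",  # placeholder checkpoint file
    "--latent_size", "512",                      # must match the trained generator
    "--depth", "9",                              # network depth, starts from 1
    "--out_depth", "8",                          # output resolution depth, 0-indexed
    "--num_samples", "300",
    "--out_dir", out_dir,
], check=True)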
