Skip to content

Commit 14a32c8

Browse files
committed
more demo stuff added
1 parent be59f5d commit 14a32c8

File tree

4 files changed

+143
-3
lines changed

4 files changed

+143
-3
lines changed

samples/.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,9 @@
22
pro-gan_training_video_smaller.mp4
33

44
# also ignore the trained model weights
5-
GAN_GEN_SHADOW_8.pth
5+
GAN_GEN_SHADOW_8.pth
6+
7+
# ignore some huge videos:
8+
interpolation.mp4
9+
video_2.gif
10+
video_3.gif

samples/demo.gif

1.65 MB
Loading

samples/demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from matplotlib.animation import FuncAnimation
44
from pro_gan_pytorch import PRO_GAN as pg
55

6-
th.manual_seed(1)
7-
86
# ==========================================================================
97
# Tweakable parameters
108
# ==========================================================================
@@ -30,13 +28,15 @@ def get_image(point):
3028

3129
# generate the set of points:
3230
fixed_points = th.randn(num_points, 512).to(device)
31+
fixed_points = (fixed_points / fixed_points.norm(dim=1, keepdim=True)) * (512 ** 0.5)
3332
points = [] # start with an empty list
3433
for i in range(len(fixed_points) - 1):
3534
pt_1 = fixed_points[i].view(1, -1)
3635
pt_2 = fixed_points[i + 1].view(1, -1)
3736
direction = pt_2 - pt_1
3837
for j in range(transition_points):
3938
pt = pt_1 + ((direction / transition_points) * j)
39+
pt = (pt / pt.norm()) * (512 ** 0.5)
4040
points.append(pt)
4141
# also append the final point:
4242
points.append(pt_2)

samples/latent_sapce_interpolation.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
""" script for generating samples from a trained model """
2+
3+
import argparse
4+
import os
5+
6+
import torch as th
7+
8+
# define the device for the training script
9+
device = th.device("cuda" if th.cuda.is_available() else "cpu")
10+
11+
# set manual seed to 3
12+
th.manual_seed(3)
13+
14+
15+
def parse_arguments():
16+
"""
17+
command line arguments parser
18+
:return: args => parsed command line arguments
19+
"""
20+
parser = argparse.ArgumentParser()
21+
22+
parser.add_argument("--generator_file", action="store", type=str,
23+
help="pretrained weights file for generator", required=True)
24+
25+
parser.add_argument("--latent_size", action="store", type=int,
26+
default=512,
27+
help="latent size for the generator")
28+
29+
parser.add_argument("--depth", action="store", type=int,
30+
default=9, help="Value of depth for image generation (Resolution)")
31+
32+
parser.add_argument("--alpha", action="store", type=float,
33+
default=1, help="Value of alpha (Fade-in factor)")
34+
35+
parser.add_argument("--num_samples", action="store", type=int,
36+
default=64,
37+
help="number of samples in the sheet (preferably a square number)")
38+
39+
parser.add_argument("--time", action="store", type=float,
40+
default=1,
41+
help="Number of minutes for the video to make")
42+
43+
parser.add_argument("--traversal_time", action="store", type=float,
44+
default=3,
45+
help="Number of seconds to go from one point to another")
46+
47+
parser.add_argument("--static_time", action="store", type=float,
48+
default=1,
49+
help="Number of seconds to display a sample")
50+
51+
parser.add_argument("--fps", action="store", type=int,
52+
default=30, help="Frames per second in the video")
53+
54+
parser.add_argument("--out_dir", action="store", type=str,
55+
default="interp_animation_frames/",
56+
help="path to the output directory for the frames")
57+
58+
args = parser.parse_args()
59+
60+
return args
61+
62+
63+
def main(args):
64+
"""
65+
Main function of the script
66+
:param args: parsed commandline arguments
67+
:return: None
68+
"""
69+
from pro_gan_pytorch.PRO_GAN import Generator, ProGAN
70+
71+
# create generator object:
72+
print("Creating a generator object ...")
73+
generator = th.nn.DataParallel(
74+
Generator(depth=args.depth,
75+
latent_size=args.latent_size).to(device))
76+
77+
# load the trained generator weights
78+
print("loading the trained generator weights ...")
79+
generator.load_state_dict(th.load(args.generator_file))
80+
81+
# total_frames in the video:
82+
total_time_for_one_transition = args.traversal_time + args.static_time
83+
total_frames_for_one_transition = (total_time_for_one_transition * args.fps)
84+
number_of_transitions = int((args.time * 60) / total_time_for_one_transition)
85+
total_frames = int(number_of_transitions * total_frames_for_one_transition)
86+
87+
# Let's create the animation video from the latent space interpolation
88+
# I save the frames required for making the video here
89+
points_1 = th.randn(args.num_samples, args.latent_size).to(device)
90+
points_1 = (points_1 / points_1.norm(dim=1, keepdim=True)) * (args.latent_size ** 0.5)
91+
92+
# create output directory
93+
os.makedirs(args.out_dir, exist_ok=True)
94+
95+
# Run the main loop for the interpolation:
96+
global_frame_counter = 1 # counts number of frames
97+
while global_frame_counter <= total_frames:
98+
points_2 = th.randn(args.num_samples, args.latent_size).to(device)
99+
points_2 = (points_2 / points_2.norm(dim=1, keepdim=True)) * (args.latent_size ** 0.5)
100+
direction = points_2 - points_1
101+
102+
# create the points for images in this space:
103+
number_of_points = int(args.traversal_time * args.fps)
104+
for i in range(number_of_points):
105+
points = points_1 + ((direction / number_of_points) * i)
106+
points = (points / points.norm(dim=1, keepdim=True)) * (args.latent_size ** 0.5)
107+
108+
# generate the image for this point:
109+
img = generator(points, depth=4, alpha=args.alpha)
110+
111+
# save the image:
112+
ProGAN.create_grid(img, 1, os.path.join(args.out_dir, str(global_frame_counter) + ".png"))
113+
114+
# increment the counter:
115+
global_frame_counter += 1
116+
117+
# at point_2, now add static frames:
118+
img = generator(points_2, depth=4, alpha=args.alpha)
119+
120+
# now save the same image a number of times:
121+
for _ in range(int(args.static_time * args.fps)):
122+
ProGAN.create_grid(img, 1, os.path.join(args.out_dir, str(global_frame_counter) + ".png"))
123+
global_frame_counter += 1
124+
125+
# set the point_1 := point_2
126+
points_1 = points_2
127+
128+
print("Generated %d frames ..." % global_frame_counter)
129+
130+
# video frames have been generated
131+
print("Video frames have been generated at:", args.out_dir)
132+
133+
134+
if __name__ == "__main__":
135+
main(parse_arguments())

0 commit comments

Comments
 (0)