|
| 1 | +""" live (realtime) latent space interpolations of trained models """ |
| 2 | + |
| 3 | +import argparse |
1 | 4 | import torch as th
|
| 5 | +import numpy as np |
2 | 6 | import matplotlib.pyplot as plt
|
3 | 7 | from matplotlib.animation import FuncAnimation
|
4 |
| -from pro_gan_pytorch import PRO_GAN as pg |
5 |
| - |
6 |
| -# ========================================================================== |
7 |
| -# Tweakable parameters |
8 |
| -# ========================================================================== |
9 |
| -depth = 8 |
10 |
| -num_points = 12 |
11 |
| -transition_points = 30 |
12 |
| -# ========================================================================== |
| 8 | +from pro_gan_pytorch.PRO_GAN import Generator |
| 9 | +from torchvision.utils import make_grid |
| 10 | +from math import ceil, sqrt |
| 11 | +from scipy.ndimage import gaussian_filter |
13 | 12 |
|
# pick the device for running the demo: the GPU when torch can see one,
# otherwise fall back to the CPU
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")
|
16 | 15 |
|
17 |
| -# load the model for the demo |
18 |
| -gen = th.nn.DataParallel(pg.Generator(depth=9)) |
19 |
| -gen.load_state_dict(th.load("GAN_GEN_SHADOW_8.pth", map_location=str(device))) |
20 | 16 |
|
def parse_arguments(argv=None):
    """
    command line arguments parser
    :param argv: optional list of argument strings to parse; when None,
                 argparse reads the process command line (sys.argv[1:])
    :return: args => parsed command line arguments
    """
    parser = argparse.ArgumentParser()

    # the demo cannot run without trained weights, so let argparse reject a
    # missing path up front instead of failing later while loading the file
    parser.add_argument("--generator_file", action="store", type=str,
                        required=True,
                        help="path to the trained generator model")

    parser.add_argument("--depth", action="store", type=int,
                        default=9, help="Depth of the network")

    parser.add_argument("--latent_size", action="store", type=int,
                        default=512, help="Latent size for the network")

    parser.add_argument("--num_points", action="store", type=int,
                        default=12, help="Number of samples to be seen")

    parser.add_argument("--transition_points", action="store", type=int,
                        default=30,
                        help="Number of transition samples for interpolation")

    parser.add_argument("--smoothing", action="store", type=float,
                        default=1.0,
                        help="amount of transitional smoothing")

    return parser.parse_args(argv)
| 47 | + |
| 48 | + |
def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
    """
    adjust the dynamic colour range of the given input data
    :param data: input image data (a torch tensor; th.clamp is applied to it)
    :param drange_in: original (min, max) range of input
    :param drange_out: required (min, max) range of output
    :return: img => colour range adjusted images, clamped to drange_out
    :raises ValueError: if a rescale is needed but drange_in has zero width
    """
    if drange_in != drange_out:
        in_width = np.float32(drange_in[1]) - np.float32(drange_in[0])
        if in_width == 0:
            # a zero-width input range would otherwise divide by zero and
            # silently fill the image with inf / nan
            raise ValueError("drange_in must span a non-zero range")
        scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / in_width
        bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
        # plain python floats keep the tensor's own dtype in the arithmetic
        data = data * float(scale) + float(bias)
    # clamp to the range that was actually requested, not a fixed [0, 1]
    return th.clamp(data, min=float(drange_out[0]), max=float(drange_out[1]))
| 63 | + |
| 64 | + |
def get_image(gen, point, depth, alpha):
    """
    generate a single displayable image from the given latent point
    :param gen: the generator object, called as gen(point, depth, alpha)
    :param point: latent point for generation; the leading dimension of the
                  generator output is squeezed away, so a batch of one is
                  expected
    :param depth: value of depth for image generation (0 indexed)
    :param alpha: value of alpha for fade-in (between 0 and 1)
    :return: img => generated image as a numpy array, colour range adjusted
             to [0, 1], with the first remaining axis moved to the end
    """
    # inference only: skip autograd bookkeeping instead of building a graph
    # on every frame and throwing it away
    with th.no_grad():
        image = gen(point, depth, alpha).detach()
    image = adjust_dynamic_range(image).squeeze(dim=0)
    # assumes the generator output is (1, C, H, W) -- TODO confirm; the
    # transpose then yields the (H, W, C) layout that imshow expects
    return image.cpu().numpy().transpose(1, 2, 0)
21 | 77 |
|
22 |
| -# function to generate an image given a latent_point |
23 |
| -def get_image(point): |
24 |
| - img = gen(point, depth=depth, alpha=1).detach().squeeze(0).permute(1, 2, 0) |
25 |
| - img = (img - img.min()) / (img.max() - img.min()) |
26 |
| - return img.cpu().numpy() |
27 | 78 |
|
def main(args):
    """
    Main function for the script
    :param args: parsed command line arguments; reads generator_file, depth,
                 latent_size, num_points, transition_points and smoothing
    :return: None
    :raises ValueError: if args.generator_file is None
    """
    if args.generator_file is None:
        raise ValueError(
            "--generator_file must point to a trained generator model")

    # load the model for the demo
    gen = th.nn.DataParallel(
        Generator(
            depth=args.depth,
            latent_size=args.latent_size))
    gen.load_state_dict(th.load(args.generator_file, map_location=str(device)))
    # DataParallel requires its parameters to live on the primary device
    # when CUDA is in use; on a CPU-only host this is a no-op
    gen = gen.to(device)

    # generate the set of points:
    # sample on the CPU because scipy's gaussian_filter works on numpy data;
    # the latents are moved to the target device once, after smoothing
    total_frames = args.num_points * args.transition_points
    all_latents = th.randn(total_frames, args.latent_size)
    # smooth along the frame axis only (sigma 0 on the latent axis);
    # mode="wrap" treats the frame sequence as periodic
    all_latents = th.from_numpy(
        gaussian_filter(
            all_latents.numpy(),
            [args.smoothing * args.transition_points, 0], mode="wrap"))
    # rescale every latent to norm sqrt(latent_size)
    all_latents = (all_latents /
                   all_latents.norm(dim=-1, keepdim=True)) * sqrt(args.latent_size)
    all_latents = all_latents.to(device)

    start_point = th.unsqueeze(all_latents[0], dim=0)
    points = all_latents[1:]

    fig, ax = plt.subplots()
    plt.axis("off")
    shower = plt.imshow(get_image(gen, start_point, args.depth - 1, 1))

    def init():
        return shower,

    def update(point):
        shower.set_data(get_image(gen, th.unsqueeze(point, dim=0), args.depth - 1, 1))
        return shower,

    # define the animation function
    # keep a reference to the animation: it stops if it is garbage collected
    ani = FuncAnimation(fig, update, frames=points,
                        init_func=init)
    # show() takes no positional arguments; the animation is driven through
    # the figure it was attached to above
    plt.show()
59 | 121 |
|
60 | 122 |
|
61 |
| -ani = FuncAnimation(fig, update, frames=points, |
62 |
| - init_func=init, blit=False) |
63 |
| -plt.show() |
# script entry point: parse the command line, then launch the live demo
if __name__ == "__main__":
    cli_args = parse_arguments()
    main(cli_args)
0 commit comments