|
15 | 15 | "mpl.rcParams['figure.figsize'] = (12,12)\n",
|
16 | 16 | "mpl.rcParams['axes.grid'] = False\n",
|
17 | 17 | "\n",
|
18 |
| - "import time\n", |
19 |
| - "import IPython.display as display\n", |
| 18 | + "from util import imshow, load_img, save_img, apply_lum, match_lum\n", |
| 19 | + "from model import StyleTransferModel\n", |
| 20 | + "from losses import style_loss\n", |
| 21 | + "from train import train\n", |
20 | 22 | "\n",
|
21 |
| - "from util import imshow, load_img, save_img\n", |
22 |
| - "from model import StyleTransferModel, print_stats\n", |
23 |
| - "from losses import clip_0_1, style_loss\n", |
| 23 | + "# https://images1.novica.net/pictures/10/p348189_2a.jpg, https://www.novica.com/p/impressionist-painting-in-delod-pangkung/348189/\n", |
| 24 | + "style_path = tf.keras.utils.get_file('impressionist-bali.jpg','file:///home/jupyter/pictures/impressionist-bali.jpg')\n", |
24 | 25 | "\n",
|
25 |
| - "# load input images\n", |
26 |
| - "content_path = tf.keras.utils.get_file('neckarfront.jpg','https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg')\n", |
27 |
| - "style_path = tf.keras.utils.get_file('starry-night.jpg','https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')\n", |
28 |
| - "content_image = load_img(content_path)\n", |
29 |
| - "style_image = load_img(style_path)\n", |
30 |
| - "\n", |
31 |
| - "plt.subplot(1, 2, 1)\n", |
32 |
| - "imshow(content_image, 'Content Image')\n", |
33 |
| - "\n", |
34 |
| - "plt.subplot(1, 2, 2)\n", |
35 |
| - "imshow(style_image, 'Style Image')\n", |
| 26 | + "style_img = load_img(style_path, max_dim=512)\n", |
| 27 | + "imshow(style_img, 'Style Image')\n", |
36 | 28 | "\n",
|
37 | 29 | "plt.show()"
|
38 | 30 | ]
|
|
51 | 43 | " 'block5_conv1']\n",
|
52 | 44 | "\n",
|
53 | 45 | "\n",
|
54 |
| - "opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)\n", |
55 |
| - "\n", |
56 | 46 | "for idx in range(len(style_layers)):\n",
|
57 | 47 | " extractor = StyleTransferModel(style_layers[:idx+1], ['block1_conv1'])\n",
|
58 |
| - " results = extractor(tf.constant(style_image))\n", |
59 | 48 | "\n",
|
60 |
| - " # the variable to optimize\n", |
61 |
| - " image = tf.Variable(tf.random.uniform(style_image.shape))\n", |
62 |
| - "\n", |
63 |
| - " style_targets = extractor(style_image)['style']\n", |
64 |
| - "\n", |
65 |
| - " # style_weights = [ 1e3/n**2 for n in [64, 128, 256, 512, 512] ]\n", |
| 49 | + " style_targets = extractor(style_img)['style']\n", |
66 | 50 | " style_weights = [ 1.0, 1.0, 1.0, 1.0, 1.0 ]\n",
|
67 |
| - "\n", |
68 | 51 | " style_weights = style_weights[:idx+1]\n",
|
| 52 | + " style_weights = [ w/sum(style_weights) for w in style_weights ] # normalize weights\n", |
69 | 53 | "\n",
|
70 |
| - " # the weights are normalized\n", |
71 |
| - " style_weights = [ w/sum(style_weights) for w in style_weights ]\n", |
72 |
| - " style_weights = tf.constant(style_weights)\n", |
| 54 | + " # initialize the gradients with random noise\n", |
| 55 | + " initial_gradients = tf.Variable(tf.random.uniform(style_img.shape))\n", |
73 | 56 | "\n",
|
74 |
| - "\n", |
75 |
| - " @tf.function()\n", |
76 |
| - " def train_step(image):\n", |
77 |
| - " with tf.GradientTape() as tape:\n", |
| 57 | + " def loss_func(image):\n", |
78 | 58 | " outputs = extractor(image)\n",
|
79 |
| - " total_loss = style_loss(outputs['style'], style_targets, style_weights)\n", |
80 |
| - "\n", |
81 |
| - " grad = tape.gradient(total_loss, image)\n", |
82 |
| - " opt.apply_gradients([(grad, image)])\n", |
83 |
| - " image.assign(clip_0_1(image))\n", |
84 |
| - "\n", |
85 |
| - " start = time.time()\n", |
86 |
| - "\n", |
87 |
| - " epochs = 20\n", |
88 |
| - " steps_per_epoch = 100\n", |
89 |
| - "\n", |
90 |
| - " step = 0\n", |
91 |
| - " for n in range(epochs):\n", |
92 |
| - " for m in range(steps_per_epoch):\n", |
93 |
| - " step += 1\n", |
94 |
| - " train_step(image)\n", |
95 |
| - " print(\".\", end='')\n", |
96 |
| - " display.clear_output(wait=True)\n", |
97 |
| - " imshow(image.read_value())\n", |
98 |
| - " plt.title(\"Train step: {}\".format(step))\n", |
99 |
| - " print(style_layers[:idx+1])\n", |
100 |
| - " plt.show()\n", |
| 59 | + " loss = style_loss(outputs['style'], style_targets, style_weights)\n", |
| 60 | + " return loss\n", |
101 | 61 | "\n",
|
102 |
| - " end = time.time()\n", |
103 |
| - " print(\"Total time: {:.1f}\".format(end-start))\n", |
| 62 | + " result = train(loss_func, extractor, initial_gradients, epochs=20)\n", |
104 | 63 | "\n",
|
105 |
| - " save_img(image[0], 'style_{}.png'.format(style_layers[idx]))" |
| 64 | + " save_img(result[0], 'style_{}.png'.format(style_layers[idx]))" |
106 | 65 | ]
|
107 | 66 | },
|
108 | 67 | {
|
|
0 commit comments