Skip to content

Commit 02ebbd3

Browse files
Refactored layer generators
1 parent 3f180a6 commit 02ebbd3

File tree

2 files changed

+50
-112
lines changed

2 files changed

+50
-112
lines changed

styletransfer/generate-content-layer-outputs.ipynb

Lines changed: 33 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,16 @@
1515
"mpl.rcParams['figure.figsize'] = (12,12)\n",
1616
"mpl.rcParams['axes.grid'] = False\n",
1717
"\n",
18-
"import time\n",
19-
"import IPython.display as display\n",
18+
"from util import imshow, load_img, save_img, apply_lum, match_lum\n",
19+
"from model import StyleTransferModel\n",
20+
"from losses import content_loss\n",
21+
"from train import train\n",
2022
"\n",
21-
"from util import imshow, load_img, save_img\n",
22-
"from model import StyleTransferModel, print_stats\n",
23-
"from losses import clip_0_1, content_loss\n",
23+
"# https://www.positive.news/wp-content/uploads/2019/03/feat-1800x0-c-center.jpg\n",
24+
"content_path = tf.keras.utils.get_file('forest.jpg','file:///home/jupyter/pictures/forest.jpg')\n",
2425
"\n",
25-
"# load input images\n",
26-
"content_path = tf.keras.utils.get_file('neckarfront.jpg','https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg')\n",
27-
"style_path = tf.keras.utils.get_file('starry-night.jpg','https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')\n",
28-
"content_image = load_img(content_path)\n",
29-
"style_image = load_img(style_path)\n",
30-
"\n",
31-
"plt.subplot(1, 2, 1)\n",
32-
"imshow(content_image, 'Content Image')\n",
33-
"\n",
34-
"plt.subplot(1, 2, 2)\n",
35-
"imshow(style_image, 'Style Image')\n",
26+
"content_img = load_img(content_path, max_dim=512)\n",
27+
"imshow(content_img, 'Content Image')\n",
3628
"\n",
3729
"plt.show()"
3830
]
@@ -45,53 +37,40 @@
4537
"source": [
4638
"# reconstruct content, for every layer\n",
4739
"content_layers = ['block1_conv1',\n",
48-
" 'block2_conv1',\n",
49-
" 'block3_conv1', \n",
50-
" 'block4_conv1', \n",
51-
" 'block5_conv1']\n",
52-
"\n",
53-
"opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)\n",
40+
" 'block2_conv1',\n",
41+
" 'block3_conv1', \n",
42+
" 'block4_conv1', \n",
43+
" 'block5_conv1']\n",
5444
"\n",
5545
"for content_layer in content_layers:\n",
5646
" extractor = StyleTransferModel(['block1_conv1'], [content_layer])\n",
57-
" results = extractor(tf.constant(content_image))\n",
5847
"\n",
59-
" # the variable to optimize\n",
60-
" image = tf.Variable(tf.random.uniform(content_image.shape))\n",
48+
" content_targets = extractor(content_img)['content']\n",
49+
" content_weights = [1.0]\n",
50+
"\n",
51+
" # initialize the gradients with random noise\n",
52+
" initial_gradients = tf.Variable(tf.random.uniform(content_img.shape))\n",
6153
"\n",
62-
" content_targets = extractor(content_image)['content']\n",
63-
" content_weights = tf.constant([ 1e10 ])\n",
64-
" \n",
65-
" @tf.function()\n",
66-
" def train_step(image):\n",
67-
" with tf.GradientTape() as tape:\n",
54+
" def loss_func(image):\n",
6855
" outputs = extractor(image)\n",
6956
" loss = content_loss(outputs['content'], content_targets, content_weights)\n",
57+
" return loss\n",
7058
"\n",
71-
" grad = tape.gradient(loss, image)\n",
72-
" opt.apply_gradients([(grad, image)])\n",
73-
" image.assign(clip_0_1(image))\n",
74-
"\n",
75-
" start = time.time()\n",
76-
"\n",
77-
" epochs = 20\n",
78-
" steps_per_epoch = 100\n",
79-
"\n",
80-
" step = 0\n",
81-
" for n in range(epochs):\n",
82-
" for m in range(steps_per_epoch):\n",
83-
" step += 1\n",
84-
" train_step(image)\n",
85-
" print(\".\", end='')\n",
86-
" display.clear_output(wait=True)\n",
87-
" imshow(image.read_value())\n",
88-
" plt.title(\"Train step: {}\".format(step))\n",
89-
" plt.show()\n",
90-
"\n",
91-
" end = time.time()\n",
92-
" print(\"Total time: {:.1f}\".format(end-start))\n",
59+
" result = train(loss_func, extractor, initial_gradients, epochs=20)\n",
9360
"\n",
94-
" save_img(image[0], 'content_{}.png'.format(content_layer))"
61+
" save_img(result[0], 'content_{}.png'.format(content_layer))"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {},
68+
"outputs": [],
69+
"source": [
70+
"for i, content_layer in enumerate(content_layers):\n",
71+
" plt.subplot(3, 2, i+1)\n",
72+
" img = load_img('output/content_{}.png'.format(content_layer), max_dim=512)\n",
73+
" imshow(result, 'Content Layer {}'.format(content_layer))"
9574
]
9675
},
9776
{

styletransfer/generate-style-layer-outputs.ipynb

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,16 @@
1515
"mpl.rcParams['figure.figsize'] = (12,12)\n",
1616
"mpl.rcParams['axes.grid'] = False\n",
1717
"\n",
18-
"import time\n",
19-
"import IPython.display as display\n",
18+
"from util import imshow, load_img, save_img, apply_lum, match_lum\n",
19+
"from model import StyleTransferModel\n",
20+
"from losses import style_loss\n",
21+
"from train import train\n",
2022
"\n",
21-
"from util import imshow, load_img, save_img\n",
22-
"from model import StyleTransferModel, print_stats\n",
23-
"from losses import clip_0_1, style_loss\n",
23+
"# https://images1.novica.net/pictures/10/p348189_2a.jpg, https://www.novica.com/p/impressionist-painting-in-delod-pangkung/348189/\n",
24+
"style_path = tf.keras.utils.get_file('impressionist-bali.jpg','file:///home/jupyter/pictures/impressionist-bali.jpg')\n",
2425
"\n",
25-
"# load input images\n",
26-
"content_path = tf.keras.utils.get_file('neckarfront.jpg','https://upload.wikimedia.org/wikipedia/commons/0/00/Tuebingen_Neckarfront.jpg')\n",
27-
"style_path = tf.keras.utils.get_file('starry-night.jpg','https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1280px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg')\n",
28-
"content_image = load_img(content_path)\n",
29-
"style_image = load_img(style_path)\n",
30-
"\n",
31-
"plt.subplot(1, 2, 1)\n",
32-
"imshow(content_image, 'Content Image')\n",
33-
"\n",
34-
"plt.subplot(1, 2, 2)\n",
35-
"imshow(style_image, 'Style Image')\n",
26+
"style_img = load_img(style_path, max_dim=512)\n",
27+
"imshow(style_img, 'Style Image')\n",
3628
"\n",
3729
"plt.show()"
3830
]
@@ -51,58 +43,25 @@
5143
" 'block5_conv1']\n",
5244
"\n",
5345
"\n",
54-
"opt = tf.optimizers.Adam(learning_rate=0.02, beta_1=0.99, epsilon=1e-1)\n",
55-
"\n",
5646
"for idx in range(len(style_layers)):\n",
5747
" extractor = StyleTransferModel(style_layers[:idx+1], ['block1_conv1'])\n",
58-
" results = extractor(tf.constant(style_image))\n",
5948
"\n",
60-
" # the variable to optimize\n",
61-
" image = tf.Variable(tf.random.uniform(style_image.shape))\n",
62-
"\n",
63-
" style_targets = extractor(style_image)['style']\n",
64-
"\n",
65-
" # style_weights = [ 1e3/n**2 for n in [64, 128, 256, 512, 512] ]\n",
49+
" style_targets = extractor(style_img)['style']\n",
6650
" style_weights = [ 1.0, 1.0, 1.0, 1.0, 1.0 ]\n",
67-
"\n",
6851
" style_weights = style_weights[:idx+1]\n",
52+
" style_weights = [ w/sum(style_weights) for w in style_weights ] # normalize weights\n",
6953
"\n",
70-
" # the weights are normalized\n",
71-
" style_weights = [ w/sum(style_weights) for w in style_weights ]\n",
72-
" style_weights = tf.constant(style_weights)\n",
54+
" # initialize the gradients with random noise\n",
55+
" initial_gradients = tf.Variable(tf.random.uniform(style_img.shape))\n",
7356
"\n",
74-
"\n",
75-
" @tf.function()\n",
76-
" def train_step(image):\n",
77-
" with tf.GradientTape() as tape:\n",
57+
" def loss_func(image):\n",
7858
" outputs = extractor(image)\n",
79-
" total_loss = style_loss(outputs['style'], style_targets, style_weights)\n",
80-
"\n",
81-
" grad = tape.gradient(total_loss, image)\n",
82-
" opt.apply_gradients([(grad, image)])\n",
83-
" image.assign(clip_0_1(image))\n",
84-
"\n",
85-
" start = time.time()\n",
86-
"\n",
87-
" epochs = 20\n",
88-
" steps_per_epoch = 100\n",
89-
"\n",
90-
" step = 0\n",
91-
" for n in range(epochs):\n",
92-
" for m in range(steps_per_epoch):\n",
93-
" step += 1\n",
94-
" train_step(image)\n",
95-
" print(\".\", end='')\n",
96-
" display.clear_output(wait=True)\n",
97-
" imshow(image.read_value())\n",
98-
" plt.title(\"Train step: {}\".format(step))\n",
99-
" print(style_layers[:idx+1])\n",
100-
" plt.show()\n",
59+
" loss = style_loss(outputs['style'], style_targets, style_weights)\n",
60+
" return loss\n",
10161
"\n",
102-
" end = time.time()\n",
103-
" print(\"Total time: {:.1f}\".format(end-start))\n",
62+
" result = train(loss_func, extractor, initial_gradients, epochs=20)\n",
10463
"\n",
105-
" save_img(image[0], 'style_{}.png'.format(style_layers[idx]))"
64+
" save_img(result[0], 'style_{}.png'.format(style_layers[idx]))"
10665
]
10766
},
10867
{

0 commit comments

Comments
 (0)