Improving new matching code so it works more precisely according to the semantic map by default.
alexjc committed Mar 25, 2016
1 parent b151858 commit 5bb9261
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions doodle.py
@@ -23,10 +23,10 @@
 add_arg('--content-weight', default=10.0, type=float, help='Weight of content relative to style.')
 add_arg('--content-layers', default='4_2', type=str, help='The layer with which to match content.')
 add_arg('--style', required=True, type=str, help='Style image path to extract patches.')
-add_arg('--style-weight', default=50.0, type=float, help='Weight of style relative to content.')
+add_arg('--style-weight', default=25.0, type=float, help='Weight of style relative to content.')
 add_arg('--style-layers', default='3_1,4_1', type=str, help='The layers to match style patches.')
 add_arg('--semantic-ext', default='_sem.png', type=str, help='File extension for the semantic maps.')
-add_arg('--semantic-weight', default=1.0, type=float, help='Global weight of semantics vs. features.')
+add_arg('--semantic-weight', default=10.0, type=float, help='Global weight of semantics vs. features.')
 add_arg('--output', default='output.png', type=str, help='Output image path to save once done.')
 add_arg('--resolutions', default=3, type=int, help='Number of image scales to process.')
 add_arg('--smoothness', default=1E+0, type=float, help='Weight of image smoothing scheme.')
@@ -124,16 +124,20 @@ def setup_model(self):
 
         # Second network for the semantic layers. This dynamically downsamples the map and concatenates it.
         net['map'] = InputLayer((1, 3, None, None))
+        net['map1_1'] = PoolLayer(net['map'], 2, mode='average_exc_pad')
         net['map2_1'] = PoolLayer(net['map'], 2, mode='average_exc_pad')
         net['map3_1'] = PoolLayer(net['map'], 4, mode='average_exc_pad')
         net['map4_1'] = PoolLayer(net['map'], 8, mode='average_exc_pad')
 
         # Third network for the nearest neighbors; it's a default size for now, updated once we know more.
         net['nn1_1'] = ConvLayer(net['conv1_1'], 1, 3, b=None, pad=0)
         net['nn2_1'] = ConvLayer(net['conv2_1'], 1, 3, b=None, pad=0)
-        net['mm2_1'] = ConvLayer(net['map2_1'], 1, 3, b=None, pad=0)
         net['nn3_1'] = ConvLayer(net['conv3_1'], 1, 3, b=None, pad=0)
-        net['mm3_1'] = ConvLayer(net['map3_1'], 1, 3, b=None, pad=0)
         net['nn4_1'] = ConvLayer(net['conv4_1'], 1, 3, b=None, pad=0)
+
+        net['mm1_1'] = ConvLayer(net['map1_1'], 1, 3, b=None, pad=0)
+        net['mm2_1'] = ConvLayer(net['map2_1'], 1, 3, b=None, pad=0)
+        net['mm3_1'] = ConvLayer(net['map3_1'], 1, 3, b=None, pad=0)
+        net['mm4_1'] = ConvLayer(net['map4_1'], 1, 3, b=None, pad=0)
 
         self.network = net
@@ -233,11 +237,11 @@ def __init__(self):
 
         if self.content_map_original is None:
             self.content_map_original = np.zeros(self.content_img_original.shape[:2]+(1,))
-            self.semantic_weight = 0.0
+            args.semantic_weight = 0.0
 
         if self.style_map_original is None:
             self.style_map_original = np.zeros(self.style_img_original.shape[:2]+(1,))
-            self.semantic_weight = 0.0
+            args.semantic_weight = 0.0
 
         if self.content_img_original is None:
             self.content_img_original = np.zeros(self.content_map_original.shape[:2]+(3,))
Expand Down Expand Up @@ -277,7 +281,7 @@ def prepare_content(self, scale=1.0):
content_image = skimage.transform.rescale(self.content_img_original, scale) * 255.0
self.content_image = self.model.prepare_image(content_image)

content_map = skimage.transform.rescale(self.content_map_original * args.semantic_weight, scale) * 255.0
content_map = skimage.transform.rescale(self.content_map_original, scale) * 255.0
self.content_map = content_map.transpose((2, 0, 1))[np.newaxis].astype(np.float32)

def prepare_style(self, scale=1.0):
@@ -287,7 +291,7 @@ def prepare_style(self, scale=1.0):
         style_image = skimage.transform.rescale(self.style_img_original, scale) * 255.0
         self.style_image = self.model.prepare_image(style_image)
 
-        style_map = skimage.transform.rescale(self.style_map_original * args.semantic_weight, scale) * 255.0
+        style_map = skimage.transform.rescale(self.style_map_original, scale) * 255.0
         self.style_map = style_map.transpose((2, 0, 1))[np.newaxis].astype(np.float32)
 
         # Workaround for Issue #8. Not clear what this is caused by, NaN seems to happen in convolution node
@@ -390,21 +394,25 @@ def style_loss(self):
             return style_loss
 
         # Extract the patches from the current image, as well as their magnitude.
-        result = self.extract_patches([self.model.tensor_outputs['conv'+l] for l in self.style_layers])
+        result = self.extract_patches([self.model.tensor_outputs['conv'+l] for l in self.style_layers]
+                                      + [self.model.tensor_outputs['map'+l] for l in self.style_layers])
 
+        result_nn = result[:len(self.style_layers)*2]
+        result_mm = result[len(self.style_layers)*2:]
         # Multiple style layers are optimized separately, usually sem3_1 and sem4_1.
-        for l, patches, norms in zip(self.style_layers, result[::2], result[1::2]):
+        for l, patches, norms, sem_norms in zip(self.style_layers, result_nn[::2], result_nn[1::2], result_mm[1::2]):
             # Compute the result of the normalized cross-correlation, using results from the nearest-neighbor
-            # layers called 'nn3_1' and 'nn4_1'.
+            # layers called 'nn3_1' and 'nn4_1' (for example).
             layer = self.model.network['nn'+l]
             dist = self.model.tensor_outputs['nn'+l]
             dist = dist.reshape((dist.shape[1], -1)) / norms.reshape((1,-1)) / layer.N.reshape((-1,1))
 
+            sem_layer = self.model.network['mm'+l]
             sem = self.model.tensor_outputs['mm'+l]
-            sem = sem.reshape((sem.shape[1], -1))
+            sem = sem.reshape((sem.shape[1], -1)) / sem_norms.reshape((1,-1)) / sem_layer.N.reshape((-1,1))
 
             # Pick the best style patches for each patch in the current image, the result is an array of indices.
-            best = (dist * sem).argmax(axis=0)
+            best = (dist + args.semantic_weight * sem).argmax(axis=0)
 
             # Compute the mean squared error between the current patch and the best matching style patch.
             # Ignore the last channels (from semantic map) so errors returned are indicative of image only.
@@ -487,6 +495,7 @@ def run(self):
 
             # Now setup the model with the new data, ready for the optimization loop.
             self.model.setup(layers=['conv'+l for l in self.style_layers] +
+                                    ['map'+l for l in self.style_layers] +
                                     ['nn'+l for l in self.style_layers] +
                                     ['mm'+l for l in self.style_layers] +
                                     ['conv'+l for l in self.content_layers])
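
For reference, here is a minimal NumPy sketch of the matching rule this commit switches to: feature similarity and semantic-map similarity are scored separately and blended additively with args.semantic_weight at match time, rather than pre-multiplying the maps into the features. This is an illustration only, not the Theano/Lasagne code above; the function name, shapes, and values are hypothetical.

# Minimal NumPy sketch of the additive semantic matching rule; the function
# name, shapes, and values are hypothetical, for illustration only.
import numpy as np

def best_matches(dist, sem, semantic_weight=10.0):
    # dist, sem: (num_style_patches, num_image_patches) normalized
    # cross-correlation scores for conv features and semantic maps.
    score = dist + semantic_weight * sem
    return score.argmax(axis=0)  # index of the best style patch per image patch

# Toy example: 4 style patches scored against 6 image patches.
dist = np.random.uniform(-1.0, 1.0, size=(4, 6))
sem = np.random.uniform(-1.0, 1.0, size=(4, 6))
print(best_matches(dist, sem))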