Skip to content

Commit

Permalink
Prototype for adding more variety in patch selection; decreases cost …
Browse files Browse the repository at this point in the history
…of all patches based on how well they already match, leveling the playing field. Also disables certain code paths when they are not required.
  • Loading branch information
alexjc committed Apr 21, 2016
1 parent 4b3ad6d commit 57fe15e
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions doodle.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
add_arg('--content', default=None, type=str, help='Content image path as optimization target.')
add_arg('--content-weight', default=10.0, type=float, help='Weight of content relative to style.')
add_arg('--content-layers', default='4_2', type=str, help='The layer with which to match content.')
add_arg('--style', required=True, type=str, help='Style image path to extract patches.')
add_arg('--style', default=None, type=str, help='Style image path to extract patches.')
add_arg('--style-weight', default=25.0, type=float, help='Weight of style relative to content.')
add_arg('--style-layers', default='3_1,4_1', type=str, help='The layers to match style patches.')
add_arg('--semantic-ext', default='_sem.png', type=str, help='File extension for the semantic maps.')
Expand All @@ -31,6 +31,7 @@
add_arg('--output-size', default=None, type=str, help='Size of the output image, e.g. 512x512.')
add_arg('--phases', default=3, type=int, help='Number of image scales to process in phases.')
add_arg('--smoothness', default=1E+0, type=float, help='Weight of image smoothing scheme.')
add_arg('--variety', default=0.0, type=float, help='Bias toward more diverse patch selection.')
add_arg('--seed', default='noise', type=str, help='Seed image path, "noise" or "content".')
add_arg('--seed-range', default='16:240', type=str, help='Random colors chosen in range, e.g. 0:255.')
add_arg('--iterations', default=100, type=int, help='Number of iterations to run each resolution.')
Expand Down Expand Up @@ -323,13 +324,15 @@ def build(layer, prefix, name, patches, norms):
l.num_filters = patches.shape[0]
print(' - {} layer {}: {} patches in {:,}kb.'.format(name, layer, patches.shape[0], patches.size//1000))

result_nn = result[:len(self.style_layers)*2]
for layer, *data in zip(self.style_layers, result_nn[::2], result_nn[1::2]):
build(layer, 'nn', 'Style', *data)
if args.style_weight > 0.0:
result_nn = result[:len(self.style_layers)*2]
for layer, *data in zip(self.style_layers, result_nn[::2], result_nn[1::2]):
build(layer, 'nn', 'Style', *data)

result_mm = result[len(self.style_layers)*2:]
for layer, *data in zip(self.style_layers, result_mm[::2], result_mm[1::2]):
build(layer, 'mm', 'Semantic', *data)
if args.semantic_weight > 0.0:
result_mm = result[len(self.style_layers)*2:]
for layer, *data in zip(self.style_layers, result_mm[::2], result_mm[1::2]):
build(layer, 'mm', 'Semantic', *data)


def extract_patches(self, layers, size=3, stride=1):
Expand Down Expand Up @@ -412,16 +415,21 @@ def style_loss(self):
dist = self.model.tensor_outputs['nn'+l]
dist = dist.reshape((dist.shape[1], -1)) / norms.reshape((1,-1)) / layer.N.reshape((-1,1))

sem_layer = self.model.network['mm'+l]
sem = self.model.tensor_outputs['mm'+l]
sem = sem.reshape((sem.shape[1], -1)) / sem_norms.reshape((1,-1)) / sem_layer.N.reshape((-1,1))
if args.semantic_weight:
sem_layer = self.model.network['mm'+l]
sem = self.model.tensor_outputs['mm'+l]
sem = sem.reshape((sem.shape[1], -1)) / sem_norms.reshape((1,-1)) / sem_layer.N.reshape((-1,1))
else:
sem = 1.0

# Pick the best style patches for each patch in the current image, the result is an array of indices.
best = (dist + args.semantic_weight * sem).argmax(axis=0)
scores = dist + args.semantic_weight * sem
offset = scores.max(axis=1).reshape((-1,1)) if args.variety else 0.0
matches = (scores - offset * args.variety).argmax(axis=0)

# Compute the mean squared error between the current patch and the best matching style patch.
# Ignore the last channels (from semantic map) so errors returned are indicative of image only.
loss = T.mean((patches - layer.W[best]) ** 2.0)
loss = T.mean((patches - layer.W[matches]) ** 2.0)
style_loss.append(('style', l, args.style_weight * loss))

return style_loss
Expand Down

0 comments on commit 57fe15e

Please sign in to comment.