Skip to content

Commit

Permalink
Fixes for semantic weight calculation.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed Apr 23, 2016
1 parent 67cf246 commit ceb46d5
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions doodle.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,25 +445,26 @@ def evaluate(self, Xn):
current_features = self.compute_features(current_img, self.content_map)

# Iterate through each of the style layers one by one, computing best matches.
current_best, semantic_weight = [], math.sqrt(args.semantic_weight)
current_best, semantic_weight = [], math.sqrt(9.0 / args.semantic_weight)
assert semantic_weight > 0.0

for l, f in zip(self.style_layers, current_features):
layer = self.model.network['nn'+l]
patches, norms_m, norms_s = self.style_data[l]

patches = patches.astype(np.float32)
patches[:,:-3] /= (3.0 * norms_m.astype(np.float32))
patches[:,-3:] /= (3.0 * norms_s.astype(np.float32) * semantic_weight)
patches[:,:-3] /= (norms_m * 3.0)
patches[:,-3:] /= (norms_s * semantic_weight)
layer.W.set_value(patches)

nm = np.sqrt(np.sum(f[:,:-3] ** 2.0, axis=(1,), keepdims=True))
ns = np.sqrt(np.sum(f[:,-3:] ** 2.0, axis=(1,), keepdims=True))

f[:,:-3] /= (3.0 * nm) # TODO: Use exact number of channels.
f[:,-3:] /= (3.0 * ns * semantic_weight)
f[:,:-3] /= (nm * 3.0) # TODO: Use exact number of channels.
f[:,-3:] /= (ns * semantic_weight)

best, cost = self.compute_matches[l](f)
print('best', best[:25], '\ncost', cost[:25])
current_best.append(patches[best])

grads, *losses = self.compute_grad_and_losses(current_img, self.content_map, *current_best)
Expand Down

0 comments on commit ceb46d5

Please sign in to comment.