Skip to content

Commit

Permalink
It optimizes toward something but quality is gone.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed Apr 23, 2016
1 parent ceb46d5 commit ab95b0f
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions doodle.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def __init__(self):
else:
shape = self.style_img_original.shape[:2]

self.content_map_original = np.zeros(shape+(1,))
self.content_map_original = np.zeros(shape+(3,))
args.semantic_weight = 0.0

if self.style_map_original is None:
self.style_map_original = np.zeros(self.style_img_original.shape[:2]+(1,))
self.style_map_original = np.zeros(self.style_img_original.shape[:2]+(3,))
args.semantic_weight = 0.0

if self.content_img_original is None:
Expand Down Expand Up @@ -445,27 +445,25 @@ def evaluate(self, Xn):
current_features = self.compute_features(current_img, self.content_map)

# Iterate through each of the style layers one by one, computing best matches.
current_best, semantic_weight = [], math.sqrt(9.0 / args.semantic_weight)
assert semantic_weight > 0.0
current_best, semantic_weight = [], math.sqrt(9.0 / args.semantic_weight) if args.semantic_weight else None

for l, f in zip(self.style_layers, current_features):
layer = self.model.network['nn'+l]
patches, norms_m, norms_s = self.style_data[l]

patches = patches.astype(np.float32)
patches[:,:-3] /= (norms_m * 3.0)
patches[:,-3:] /= (norms_s * semantic_weight)
weights = patches.astype(np.float32)
weights[:,:-3] /= (norms_m * 3.0)
if semantic_weight: weights[:,-3:] /= (norms_s * semantic_weight)
layer.W.set_value(patches)

nm = np.sqrt(np.sum(f[:,:-3] ** 2.0, axis=(1,), keepdims=True))
ns = np.sqrt(np.sum(f[:,-3:] ** 2.0, axis=(1,), keepdims=True))

f[:,:-3] /= (nm * 3.0) # TODO: Use exact number of channels.
f[:,-3:] /= (ns * semantic_weight)
if semantic_weight: f[:,-3:] /= (ns * semantic_weight)

best, cost = self.compute_matches[l](f)
print('best', best[:25], '\ncost', cost[:25])
current_best.append(patches[best])
current_best.append(patches[best].astype(np.float32))

grads, *losses = self.compute_grad_and_losses(current_img, self.content_map, *current_best)
if np.isnan(grads).any():
Expand Down

0 comments on commit ab95b0f

Please sign in to comment.