Extracted slice processing and supporting extra pass for correct warm-up, helps L-BFGS converge.
alexjc committed Apr 25, 2016
1 parent d6c5a10 commit f789337
57 changes: 33 additions & 24 deletions doodle.py
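
Before the diff itself, a rough, self-contained NumPy sketch of the idea the commit implements: the style patches are scanned one slice of `num_filters` rows at a time, the best-scoring match per image feature is tracked across slices, and on the very first iteration the pass runs twice so the match history is already warmed up before L-BFGS consumes the result. The helper name `best_match_over_slices`, the toy dot-product score, and the history update below are illustrative stand-ins, not the actual Theano `compute_matches` path in doodle.py.

```python
import numpy as np

def best_match_over_slices(features, patches, history, num_filters):
    """Scan the patch bank one slice (num_filters rows) at a time and keep,
    for every feature vector, the index of the best-scoring patch seen so far."""
    best_idx, best_val = None, None
    for start in range(0, patches.shape[0], num_filters):
        idx = np.arange(start, min(start + num_filters, patches.shape[0]))
        # Toy score: dot-product similarity against this slice, discounted by
        # how strongly each patch matched in earlier passes (the history).
        scores = features @ patches[idx].T - history[idx]
        cur_idx, cur_val = scores.argmax(axis=1), scores.max(axis=1)
        if best_idx is None:
            best_idx, best_val = idx[cur_idx], cur_val
        else:
            better = np.where(cur_val > best_val)
            best_idx[better] = idx[cur_idx[better]]
            best_val[better] = cur_val[better]
        history[idx] = scores.max(axis=0)   # stand-in for storing cur_match
    return best_idx

features = np.random.rand(16, 8).astype(np.float32)   # one row per image patch
patches  = np.random.rand(12, 8).astype(np.float32)   # 3 slices of 4 filters each
history  = np.zeros(12, dtype=np.float32)

# Warm-up: on the first evaluation run the pass twice, so the history is
# already populated when the gradients that depend on it are computed.
iteration, variety = 0, 0.1
for _ in range(2 if (variety > 0.0 and iteration == 0) else 1):
    best_idx = best_match_over_slices(features, patches, history, num_filters=4)
print(best_idx.shape)   # (16,) -- one chosen style patch per image patch
```

In the real code this loop is the new `evaluate_slices()` method, which batches via `iterate_batches` and stores `cur_match` into the history; the double pass is what the commit message calls the warm-up.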
@@ -313,7 +313,9 @@ def prepare_style(self, scale=1.0):
 
         self.style_data = {}
         for layer, *data in zip(self.style_layers, result[0::3], result[1::3], result[2::3]):
-            l, patches = self.model.network['nn'+layer], data[0]
+            l = self.model.network['nn'+layer]
+            data[0] = data[0][:l.num_filters*args.slices]
+            patches = data[0]
             l.num_filters = patches.shape[0] // args.slices
             self.style_data[layer] = [d.astype(np.float16) for d in data] + [np.zeros((patches.shape[0],), dtype=np.float16)]
             print(' - Style layer {}: {} patches in {:,}kb.'.format(layer, patches.shape[0], patches.size//1000))
@@ -450,6 +452,29 @@ def iterate_batches(self, *arrays, batch_size):
             excerpt = indices[index:index + batch_size]
             yield excerpt, [a[excerpt] for a in arrays]
 
+    def evaluate_slices(self, f, l, semantic_weight):
+        layer, data = self.model.network['nn'+l], self.style_data[l]
+        history = data[-1]
+
+        best_idx, best_val = None, 0.0
+        for idx, (bp, bm, bs, bh) in self.iterate_batches(*data, batch_size=layer.num_filters):
+            weights = bp.astype(np.float32)
+            weights[:,:-3] /= (bm * 3.0) # TODO: Use exact number of channels.
+            if semantic_weight: weights[:,-3:] /= (bs * semantic_weight)
+            layer.W.set_value(weights)
+
+            cur_idx, cur_val, cur_match = self.compute_matches[l](f, history[idx])
+            if best_idx is None:
+                best_idx = cur_idx
+                best_val = cur_val
+            else:
+                i = np.where(cur_val > best_val)
+                best_idx[i] = idx[cur_idx[i]]
+                best_val[i] = cur_val[i]
+
+            history[idx] = cur_match
+        return best_idx
+
     def evaluate(self, Xn):
         """Callback for the L-BFGS optimization that computes the loss and gradients on the GPU.
         """
@@ -462,36 +487,19 @@ def evaluate(self, Xn):
         current_best, semantic_weight = [], math.sqrt(9.0 / args.semantic_weight) if args.semantic_weight else None
 
         for l, f in zip(self.style_layers, current_features):
-            layer = self.model.network['nn'+l]
-            patches, norms_m, norms_s, history = self.style_data[l]
-            patches = patches[:layer.num_filters*args.slices]
-
             # Helper for normalizing an array?
             nm = np.sqrt(np.sum(f[:,:-3] ** 2.0, axis=(1,), keepdims=True))
             ns = np.sqrt(np.sum(f[:,-3:] ** 2.0, axis=(1,), keepdims=True))
 
             f[:,:-3] /= (nm * 3.0) # TODO: Use exact number of channels.
             if semantic_weight: f[:,-3:] /= (ns * semantic_weight)
 
-            best_idx, best_val = None, 0.0
-            for idx, (bp, bm, bs, bh) in self.iterate_batches(patches, norms_m, norms_s, history, batch_size=layer.num_filters):
-
-                weights = bp.astype(np.float32)
-                weights[:,:-3] /= (bm * 3.0) # TODO: Use exact number of channels.
-                if semantic_weight: weights[:,-3:] /= (bs * semantic_weight)
-                layer.W.set_value(weights)
-
-                cur_idx, cur_val, cur_match = self.compute_matches[l](f, history[idx])
-                if best_idx is None:
-                    best_idx = cur_idx
-                    best_val = cur_val
-                else:
-                    i = np.where(cur_val > best_val)
-                    best_idx[i] = idx[cur_idx[i]]
-                    best_val[i] = cur_val[i]
-
-                history[idx] = cur_match
+            # Compute the best matching patches for this style layer, going through all slices.
+            warmup = bool(args.variety > 0.0 and self.iteration == 0)
+            for _ in range(2 if warmup else 1):
+                best_idx = self.evaluate_slices(f, l, semantic_weight)
 
+            patches = self.style_data[l][0]
             current_best.append(patches[best_idx].astype(np.float32))
 
         grads, *losses = self.compute_grad_and_losses(current_img, self.content_map, *current_best)
@@ -529,6 +537,7 @@ def evaluate(self, Xn):
 
         # Return the data in the right format for L-BFGS.
         self.frame += 1
+        self.iteration += 1
         return loss, np.array(grads).flatten().astype(np.float64)
 
     def run(self):
@@ -579,7 +588,7 @@ def run(self):
         data_bounds = np.zeros((np.product(Xn.shape), 2), dtype=np.float64)
         data_bounds[:] = (0.0, 255.0)
 
-        self.iter_time, interrupt = time.time(), False
+        self.iter_time, self.iteration, interrupt = time.time(), 0, False
         try:
             Xn, Vn, info = scipy.optimize.fmin_l_bfgs_b(
                             self.evaluate,
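
For reference, this is the contract `evaluate()` has to satisfy as the objective passed to `scipy.optimize.fmin_l_bfgs_b`: it receives the flattened image and returns the scalar loss together with the gradient flattened to float64, while the per-pixel bounds are supplied separately. A minimal sketch with a stand-in quadratic loss (the real losses come from the network):

```python
import numpy as np
import scipy.optimize

def evaluate(xn):
    # Stand-in for doodle.py's evaluate(): take the flattened image, return
    # the scalar loss and the gradient flattened as float64 -- the format
    # fmin_l_bfgs_b expects when it is not approximating gradients itself.
    x = xn.reshape(4, 4)
    loss = float(np.sum((x - 128.0) ** 2))
    grads = 2.0 * (x - 128.0)
    return loss, grads.flatten().astype(np.float64)

x0 = np.zeros((4, 4)).flatten()
data_bounds = np.zeros((x0.size, 2), dtype=np.float64)
data_bounds[:] = (0.0, 255.0)          # keep every pixel in the valid range

xn, vn, info = scipy.optimize.fmin_l_bfgs_b(evaluate, x0, bounds=data_bounds, maxfun=50)
print(vn, info['funcalls'])            # final loss and number of evaluations
```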