Skip to content

Commit

Permalink
Iterating over the patches in slices.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed Apr 24, 2016
1 parent 346eac2 commit 31e9350
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions doodle.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,14 @@ def total_variation_loss(self):
# Optimization Loop
#------------------------------------------------------------------------------------------------------------------

def iterate_batches(self, *arrays, batch_size=1024):
total_size = arrays[0].shape[0]
indices = np.arange(total_size)

for index in range(0, total_size, batch_size):
excerpt = indices[index:index + batch_size]
yield excerpt, [a[excerpt] for a in arrays]

def evaluate(self, Xn):
"""Callback for the L-BFGS optimization that computes the loss and gradients on the GPU.
"""
Expand All @@ -456,20 +464,31 @@ def evaluate(self, Xn):
layer = self.model.network['nn'+l]
patches, norms_m, norms_s, history = self.style_data[l]

weights = patches.astype(np.float32)
weights[:,:-3] /= (norms_m * 3.0)
if semantic_weight: weights[:,-3:] /= (norms_s * semantic_weight)
layer.W.set_value(weights)

# Helper for normalizing an array?
nm = np.sqrt(np.sum(f[:,:-3] ** 2.0, axis=(1,), keepdims=True))
ns = np.sqrt(np.sum(f[:,-3:] ** 2.0, axis=(1,), keepdims=True))

f[:,:-3] /= (nm * 3.0) # TODO: Use exact number of channels.
if semantic_weight: f[:,-3:] /= (ns * semantic_weight)

best_idx, best_val, best_match = self.compute_matches[l](f, history)
best_idx, best_val = None, 0.0
for idx, (bp, bm, bs, bh) in self.iterate_batches(patches, norms_m, norms_s, history):
weights = bp.astype(np.float32)
weights[:,:-3] /= (bm * 3.0) # TODO: Use exact number of channels.
if semantic_weight: weights[:,-3:] /= (bs * semantic_weight)
layer.W.set_value(weights)

cur_idx, cur_val, cur_match = self.compute_matches[l](f, history[idx])
if best_idx is None:
best_idx = cur_idx
best_val = cur_val
else:
i = np.where(cur_val > best_val)
best_idx[i] = idx[cur_idx[i]]
best_val[i] = cur_val[i]

history[idx] = cur_match

history[:] = best_match
current_best.append(patches[best_idx].astype(np.float32))

grads, *losses = self.compute_grad_and_losses(current_img, self.content_map, *current_best)
Expand Down

0 comments on commit 31e9350

Please sign in to comment.