Skip to content

Commit

Permalink
Scales up to 1024x1024 as long as you slice!
Browse files Browse the repository at this point in the history
  • Loading branch information
alexjc committed Apr 24, 2016
1 parent 31e9350 commit 9808696
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions doodle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
add_arg('--output', default='output.png', type=str, help='Output image path to save once done.')
add_arg('--output-size', default=None, type=str, help='Size of the output image, e.g. 512x512.')
add_arg('--phases', default=3, type=int, help='Number of image scales to process in phases.')
add_arg('--slices', default=2, type=int, help='Split patches up into this number of batches.')
add_arg('--smoothness', default=1E+0, type=float, help='Weight of image smoothing scheme.')
add_arg('--variety', default=0.0, type=float, help='Bias toward selecting diverse patches, e.g. 0.5.')
add_arg('--seed', default='noise', type=str, help='Seed image path, "noise" or "content".')
Expand Down Expand Up @@ -71,7 +72,7 @@ def error(message, *lines):
# Load the underlying deep learning libraries based on the device specified. If you specify THEANO_FLAGS manually,
# the code assumes you know what you are doing and they are not overridden!
os.environ.setdefault('THEANO_FLAGS', 'floatX=float32,device={},force_device=True,'\
'print_active_device=False,optimizer=fast_compile'.format(args.device))
'print_active_device=False'.format(args.device))

# Scientific Libraries
import numpy as np
Expand Down Expand Up @@ -313,7 +314,7 @@ def prepare_style(self, scale=1.0):
self.style_data = {}
for layer, *data in zip(self.style_layers, result[0::3], result[1::3], result[2::3]):
l, patches = self.model.network['nn'+layer], data[0]
l.num_filters = patches.shape[0] # TODO: This is the number of slices.
l.num_filters = patches.shape[0] // args.slices
self.style_data[layer] = [d.astype(np.float16) for d in data] + [np.zeros((patches.shape[0],), dtype=np.float16)]
print(' - Style layer {}: {} patches in {:,}kb.'.format(layer, patches.shape[0], patches.size//1000))

Expand Down Expand Up @@ -354,7 +355,7 @@ def prepare_optimization(self):
nn_layers = [self.model.network['nn'+l] for l in self.style_layers]
self.matcher_outputs = dict(zip(self.style_layers, lasagne.layers.get_output(nn_layers, self.matcher_inputs)))

self.compute_matches = {l: theano.function([self.matcher_tensors[l], self.matcher_history[l]],
self.compute_matches = {l: theano.function([self.matcher_tensors[l]], # self.matcher_history[l]],
self.do_match_patches(l)) for l in self.style_layers}

self.tensor_matches = [T.tensor4() for l in self.style_layers]
Expand All @@ -373,8 +374,8 @@ def do_match_patches(self, layer):
dist = self.matcher_outputs[layer]
dist = dist.reshape((dist.shape[1], -1))

offset = self.matcher_history[layer].reshape((-1, 1))
scores = (dist - offset * args.variety)
# offset = self.matcher_history[layer].reshape((-1, 1))
scores = dist # (dist - offset * args.variety)
matches = scores.argmax(axis=0)

# Pick the best style patches for each patch in the current image; the result is an array of indices.
Expand Down Expand Up @@ -441,7 +442,7 @@ def total_variation_loss(self):
# Optimization Loop
#------------------------------------------------------------------------------------------------------------------

def iterate_batches(self, *arrays, batch_size=1024):
def iterate_batches(self, *arrays, batch_size):
total_size = arrays[0].shape[0]
indices = np.arange(total_size)

Expand All @@ -463,6 +464,7 @@ def evaluate(self, Xn):
for l, f in zip(self.style_layers, current_features):
layer = self.model.network['nn'+l]
patches, norms_m, norms_s, history = self.style_data[l]
patches = patches[:layer.num_filters*args.slices]

# Helper for normalizing an array?
nm = np.sqrt(np.sum(f[:,:-3] ** 2.0, axis=(1,), keepdims=True))
Expand All @@ -472,13 +474,14 @@ def evaluate(self, Xn):
if semantic_weight: f[:,-3:] /= (ns * semantic_weight)

best_idx, best_val = None, 0.0
for idx, (bp, bm, bs, bh) in self.iterate_batches(patches, norms_m, norms_s, history):
for idx, (bp, bm, bs, bh) in self.iterate_batches(patches, norms_m, norms_s, history, batch_size=layer.num_filters):

weights = bp.astype(np.float32)
weights[:,:-3] /= (bm * 3.0) # TODO: Use exact number of channels.
if semantic_weight: weights[:,-3:] /= (bs * semantic_weight)
layer.W.set_value(weights)

cur_idx, cur_val, cur_match = self.compute_matches[l](f, history[idx])
cur_idx, cur_val, cur_match = self.compute_matches[l](f) #, history[idx])
if best_idx is None:
best_idx = cur_idx
best_val = cur_val
Expand Down

0 comments on commit 9808696

Please sign in to comment.