Skip to content

Commit

Permalink
Remove reflection layer, important but not the bottleneck right now a…
Browse files Browse the repository at this point in the history
…nd quite slower.
  • Loading branch information
alexjc committed Nov 4, 2016
1 parent 9bba985 commit da354f3
Showing 1 changed file with 1 addition and 30 deletions.
31 changes: 1 addition & 30 deletions enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,34 +242,6 @@ def get_output_for(self, input, deterministic=False, **kwargs):
return out


class ReflectLayer(lasagne.layers.Layer):
"""Based on more code by ajbrock: https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f
"""

def __init__(self, incoming, pad, batch_ndim=2, **kwargs):
super(ReflectLayer, self).__init__(incoming, **kwargs)
self.pad = (pad, pad)
self.batch_ndim = batch_ndim

def get_output_shape_for(self, input_shape):
output_shape = list(input_shape)
for k, p in enumerate(self.pad):
if output_shape[k + self.batch_ndim] is None: continue
l, r = p, p
output_shape[k + self.batch_ndim] += l + r
return tuple(output_shape)

def get_output_for(self, x, **kwargs):
out = T.zeros(self.get_output_shape_for(x.shape))
p0, p1 = self.pad
out = T.set_subtensor(out[:,:,:p0,p1:-p1], x[:,:,p0:0:-1,:])
out = T.set_subtensor(out[:,:,-p0:,p1:-p1], x[:,:,-2:-(2+p0):-1,:])
out = T.set_subtensor(out[:,:,p0:-p0,p1:-p1], x)
out = T.set_subtensor(out[:,:,:,:p1], out[:,:,:,(2*p1):p1:-1])
out = T.set_subtensor(out[:,:,:,-p1:], out[:,:,:,-(p1+2):-(2*p1+2):-1])
return out


class Model(object):

def __init__(self):
Expand All @@ -296,8 +268,7 @@ def last_layer(self):
return list(self.network.values())[-1]

def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
reflected = ReflectLayer(input, pad=pad[0]) if pad[0] > 0 else input
conv = ConvLayer(reflected, units, filter_size, stride=stride, pad=(0,0), nonlinearity=None)
conv = ConvLayer(input, units, filter_size, stride=stride, pad=(0,0), nonlinearity=None)
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
self.network[name+'x'] = conv
self.network[name+'>'] = prelu
Expand Down

0 comments on commit da354f3

Please sign in to comment.