Commit b4bc50a

Update train.py
1 parent 7843f5e commit b4bc50a

File tree

1 file changed: +1 -146 lines changed


train/train.py

Lines changed: 1 addition & 146 deletions
@@ -5,6 +5,7 @@
 import mlp.data_providers as data_providers

 from image_processing import *
+from layers import *


@@ -32,152 +33,6 @@



-# define mixed max-average pooling layer
-def mixed_pooling(inputs, alpha, size=2):
-    """Mixed pooling operation, non-responsive.
-
-    Combines max pooling and average pooling in a fixed proportion specified by alpha:
-        f_mixed(x) = alpha * f_max(x) + (1 - alpha) * f_avg(x)
-
-    arguments:
-        inputs: tensor of shape [batch_size, height, width, channels]
-        size: an integer, width and height of the pooling filter
-        alpha: the scalar mixing proportion, in the range [0, 1];
-            pass -1 to make alpha a trainable variable
-    return:
-        outputs: tensor of shape [batch_size, height//size, width//size, channels]
-    """
-    if alpha == -1:
-        alpha = tf.Variable(0.0)
-    x1 = tf.contrib.layers.max_pool2d(inputs=inputs, kernel_size=[size, size], stride=size, padding='VALID')
-    x2 = tf.contrib.layers.avg_pool2d(inputs=inputs, kernel_size=[size, size], stride=size, padding='VALID')
-    outputs = tf.add(tf.multiply(x1, alpha), tf.multiply(x2, (1 - alpha)))
-
-    return [alpha, outputs]
-
-
-
-
-
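
Since tf.contrib was removed in TensorFlow 2.0, a minimal sketch of the same mixed-pooling formula against the current tf.nn API might look as follows; the name mixed_pool is illustrative, not from this repo:

import tensorflow as tf

def mixed_pool(inputs, alpha, size=2):
    # f_mixed(x) = alpha * f_max(x) + (1 - alpha) * f_avg(x)
    x_max = tf.nn.max_pool2d(inputs, ksize=size, strides=size, padding='VALID')
    x_avg = tf.nn.avg_pool2d(inputs, ksize=size, strides=size, padding='VALID')
    return alpha * x_max + (1.0 - alpha) * x_avg
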
-# define gated max-average pooling layer
-def gated_pooling(inputs, filter, size=2, learn_option='l/c'):
-    """Gated pooling operation, responsive.
-
-    Combines max pooling and average pooling in a mixing proportion that is
-    obtained from the inner product between a gating mask and the region
-    being pooled, fed through a sigmoid:
-        f_gate(x) = sigmoid(w*x) * f_max(x) + (1 - sigmoid(w*x)) * f_avg(x)
-
-    arguments:
-        inputs: input of shape [batch_size, height, width, channels]
-        filter: channel count of the input layer, used to initialize the gating mask
-        size: an integer, width and height of the pooling filter
-        learn_option: learning option of gated pooling, one of:
-            'l': learn one mask per layer, shared by all channels and regions
-            'l/c': learn a mask per layer/channel
-            'l/r/c': learn a mask per layer/pooling region/channel combined
-    return:
-        outputs: tensor of shape [batch_size, height//size, width//size, channels]
-    """
-    if learn_option == 'l':
-        gating_mask = all_channel_connected2d(inputs)
-    if learn_option == 'l/c':
-        w_gated = tf.Variable(tf.truncated_normal([size, size, filter, filter], stddev=2/(size*size*filter*2)**0.5))
-        gating_mask = tf.nn.conv2d(inputs, w_gated, strides=[1, size, size, 1], padding='VALID')
-    if learn_option == 'l/r/c':
-        gating_mask = locally_connected2d(inputs)
-
-    alpha = tf.sigmoid(gating_mask)
-
-    x1 = tf.contrib.layers.max_pool2d(inputs=inputs, kernel_size=[size, size], stride=size, padding='VALID')
-    x2 = tf.contrib.layers.avg_pool2d(inputs=inputs, kernel_size=[size, size], stride=size, padding='VALID')
-    outputs = tf.add(tf.multiply(x1, alpha), tf.multiply(x2, (1 - alpha)))
-    return outputs
-
-
-
-
-
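
For the 'l/c' option, the gate is just a strided convolution over each pooling region squashed through a sigmoid. A self-contained TF 2.x sketch, assuming a Keras layer (the class name GatedPool2D and the stddev value are illustrative):

import tensorflow as tf

class GatedPool2D(tf.keras.layers.Layer):
    # f_gate(x) = sigmoid(w*x) * f_max(x) + (1 - sigmoid(w*x)) * f_avg(x)
    def __init__(self, size=2, **kwargs):
        super().__init__(**kwargs)
        self.size = size

    def build(self, input_shape):
        channels = int(input_shape[-1])
        # One [size, size, channels, channels] gating filter: the 'l/c'
        # option, i.e. a mask learned per layer/channel.
        self.w = self.add_weight(
            name='gate',
            shape=(self.size, self.size, channels, channels),
            initializer=tf.keras.initializers.TruncatedNormal(stddev=0.1))

    def call(self, inputs):
        # Strided conv over each pooling region, mapped to a (0, 1) mix ratio.
        gate = tf.sigmoid(tf.nn.conv2d(inputs, self.w,
                                       strides=self.size, padding='VALID'))
        x_max = tf.nn.max_pool2d(inputs, self.size, self.size, 'VALID')
        x_avg = tf.nn.avg_pool2d(inputs, self.size, self.size, 'VALID')
        return gate * x_max + (1.0 - gate) * x_avg
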
-# locally connected layer (unshared-weights conv layer),
-# designed for gated pooling: learns a parameter "per layer/region/channel"
-def locally_connected2d(x, size=2):
-    """
-    The `LocallyConnected2D` layer works similarly to the `Convolution2D`
-    layer, except that weights are unshared: a different set of filters
-    is applied at each patch of the input.
-
-    NOTE: No bias or activation function applied. No overlapping between sub-regions.
-
-    arguments:
-        x: 4D tensor with shape [samples, rows, cols, channels]
-        size: width and height of the filter (default 2x2 filter);
-            this is also the stride length, to ensure no overlapping
-    returns:
-        4D tensor with shape [samples, new_rows, new_cols, nb_filter];
-        `rows` and `cols` values might have changed due to padding.
-    """
-    xs = []
-    _, input_row, input_col, nb_filter = x.get_shape().as_list()
-    output_row = input_row // size
-    output_col = input_col // size
-    nb_row = size
-    nb_col = size
-    stride_row = size
-    stride_col = size
-    feature_dim = nb_row * nb_col * nb_filter
-
-    w_shape = (output_row * output_col,
-               nb_row * nb_col * nb_filter,
-               nb_filter)
-    mask = tf.Variable(tf.truncated_normal(w_shape, stddev=2./(w_shape[0]*w_shape[1]*2)**0.5))
-    for i in range(output_row):
-        for j in range(output_col):
-            slice_row = slice(i * stride_row, i * stride_row + nb_row)
-            slice_col = slice(j * stride_col, j * stride_col + nb_col)
-            xs.append(tf.reshape(x[:, slice_row, slice_col, :], (1, -1, feature_dim)))
-    x_aggregate = tf.concat(xs, 0)
-    output = tf.matmul(x_aggregate, mask)
-    output = tf.reshape(output, (output_row, output_col, -1, nb_filter))
-    output = tf.transpose(output, perm=[2, 0, 1, 3])
-
-    return output
-
-
-
-
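
The two nested loops above just gather non-overlapping patches; the same unshared-weights computation can be written without Python loops using tf.image.extract_patches and tf.einsum. A sketch under the same shape assumptions (the helper name locally_connected2d_v2 and the stddev are hypothetical, and a real layer would create the variable once rather than per call):

import tensorflow as tf

def locally_connected2d_v2(x, size=2):
    _, rows, cols, channels = x.shape
    # Gather non-overlapping [size, size] patches, flattened row-major with
    # channels fastest: [batch, rows//size, cols//size, size*size*channels].
    patches = tf.image.extract_patches(
        x, sizes=[1, size, size, 1], strides=[1, size, size, 1],
        rates=[1, 1, 1, 1], padding='VALID')
    n = (rows // size) * (cols // size)
    patches = tf.reshape(patches, (-1, n, size * size * channels))
    # One independent [size*size*channels, channels] weight matrix per region.
    mask = tf.Variable(tf.random.truncated_normal(
        (n, size * size * channels, channels), stddev=0.1))
    out = tf.einsum('bnf,nfc->bnc', patches, mask)
    return tf.reshape(out, (-1, rows // size, cols // size, channels))
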
-# designed for gated pooling: learns a parameter "per layer"
-def all_channel_connected2d(x, size=2):
-    """
-    The all-channel connected layer is a modified version of the
-    convolutional layer which shares the same weights not only between
-    patches but also between all channels of the layer input; that is,
-    the whole layer has only one filter.
-
-    NOTE: 'VALID' padding, no bias, no activation function.
-
-    arguments:
-        x: 4D tensor with shape [batch_size, rows, cols, channels]
-        size: width and height of the filter (default 2x2 filter);
-            this is also the stride length, to ensure no overlapping
-    returns:
-        4D tensor with shape [batch_size, new_rows, new_cols, nb_filter]
-    """
-    nb_batch, input_row, input_col, nb_filter = x.get_shape().as_list()
-    output_size = input_row // size
-    mask = tf.Variable(tf.truncated_normal([size, size, 1, 1], stddev=2./(size*size*2)**0.5))
-
-    xs = []
-    for c in tf.split(x, nb_filter, 3):
-        xs.append(tf.nn.conv2d(c, mask, strides=[1, size, size, 1], padding='VALID'))
-    output = tf.reshape(tf.concat(xs, 3), [nb_batch, output_size, output_size, nb_filter])
-
-    return output
-




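As a quick sanity check on the mixing formula, the illustrative mixed_pool sketch above reduces to pure max pooling at alpha = 1 and pure average pooling at alpha = 0:

import numpy as np
import tensorflow as tf

x = tf.random.normal([1, 4, 4, 3])
assert np.allclose(mixed_pool(x, 1.0), tf.nn.max_pool2d(x, 2, 2, 'VALID'))
assert np.allclose(mixed_pool(x, 0.0), tf.nn.avg_pool2d(x, 2, 2, 'VALID'))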