6
6
7
7
8
8
# Define some constants
9
+ from model .RG import RG
10
+
9
11
# Convolution hyper-parameters shared by every discriminative / generative block
KERNEL_SIZE = 3
PADDING = KERNEL_SIZE // 2  # "same"-style padding for an odd kernel
KERNEL_STRIDE = 2  # stride 2 halves the spatial resolution per layer
14
16
15
17
class Model02 (nn .Module ):
16
18
"""
17
- Generate a constructor for model_01 type of network
19
+ Generate a constructor for model_02 type of network
18
20
"""
19
21
20
22
def __init__ (self , network_size : tuple , input_spatial_size : tuple ) -> None :
@@ -29,7 +31,7 @@ def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
29
31
super ().__init__ ()
30
32
self .hidden_layers = len (network_size ) - 2
31
33
32
- print ('\n {:-^80}' .format (' Building model ' ))
34
+ print ('\n {:-^80}' .format (' Building model Model02 ' ))
33
35
print ('Hidden layers:' , self .hidden_layers )
34
36
print ('Net sizing:' , network_size )
35
37
print ('Input spatial size: {} x {}' .format (network_size [0 ], input_spatial_size ))
@@ -91,21 +93,114 @@ def forward(self, x, state):
91
93
return (x , state ), (x_mean , video_index )
92
94
93
95
94
- def _test_model ():
96
class Model02RG(nn.Module):
    """
    Generate a constructor for model_02_rg type of network.

    Same discriminative/generative ladder as Model02, but the generative
    blocks are ``RG`` (recurrent generative) modules that carry their own
    hidden state across time steps.
    """

    def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
        """
        Initialise Model02RG constructor

        :param network_size: (n, h1, h2, ..., emb_size, nb_videos)
        :type network_size: tuple
        :param input_spatial_size: (height, width)
        :type input_spatial_size: tuple
        """
        super().__init__()
        self.hidden_layers = len(network_size) - 2

        print('\n{:-^80}'.format(' Building model Model02RG '))
        print('Hidden layers:', self.hidden_layers)
        print('Net sizing:', network_size)
        print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size))

        # main auto-encoder blocks
        self.activation_size = [input_spatial_size]
        for layer in range(0, self.hidden_layers):
            # print some annotation when building model
            print('{:-<80}'.format('Layer ' + str(layer + 1) + ' '))
            print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1]))
            # stride-2 convolution halves each spatial dimension (ceil for odd sizes)
            self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer]))
            print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1]))

            # init D (discriminative) blocks
            # D_n, n > 1, has intra-layer feedback, so its input channels double
            # (was the `layer and 2 or 1` anti-pattern; conditional expression is equivalent)
            multiplier = 2 if layer else 1
            setattr(self, 'D_' + str(layer + 1), nn.Conv2d(
                in_channels=network_size[layer] * multiplier, out_channels=network_size[layer + 1],
                kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
            ))
            setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1]))

            # init G (generative) blocks — recurrent, mirror image of D
            setattr(self, 'G_' + str(layer + 1), RG(
                in_channels=network_size[layer + 1], out_channels=network_size[layer],
                kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
            ))
            setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer]))

        # init auxiliary classifier
        print('{:-<80}'.format('Classifier '))
        print(network_size[-2], '-->', network_size[-1])
        self.average = nn.AvgPool2d(self.activation_size[-1])
        self.stabiliser = nn.Linear(network_size[-2], network_size[-1])
        print(80 * '-', end='\n\n')

    def forward(self, x, state):
        """
        Run one time step through the discriminative and generative ladders.

        :param x: input frame batch
        :param state: ``[net_states, gen_states]`` from the previous call, or None
        :return: ``(x_hat, state), (x_mean, video_index)`` — reconstruction plus
                 new state, embedding, and auxiliary classifier logits
        """
        activation_sizes = [x.size()]  # start from the input
        residuals = list()
        # state[0] --> network layer state; state[1] --> generative state
        state = state or [[None] * (self.hidden_layers - 1), [None] * self.hidden_layers]
        for layer in range(0, self.hidden_layers):  # connect discriminative blocks
            if layer:  # concat the input with the state for D_n, n > 1
                # NOTE(review): `or` relies on Variable truthiness — zero tensor
                # fallback only when the state is None; confirm on newer torch
                s = state[0][layer - 1] or V(x.data.clone().zero_())
                x = torch.cat((x, s), 1)
            x = getattr(self, 'D_' + str(layer + 1))(x)
            residuals.append(x)
            x = f.relu(x)
            x = getattr(self, 'BN_D_' + str(layer + 1))(x)
            activation_sizes.append(x.size())  # cache output size for later retrieval
        for layer in reversed(range(0, self.hidden_layers)):  # connect generative blocks
            x = getattr(self, 'G_' + str(layer + 1))((x, activation_sizes[layer]), state[1][layer])
            state[1][layer] = x  # h[t - 1] <- h[t]
            if layer:
                state[0][layer - 1] = x
                x += residuals[layer - 1]
                x = f.relu(x)
                x = getattr(self, 'BN_G_' + str(layer + 1))(x)
        x_mean = self.average(residuals[-1])
        video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1))

        return (x, state), (x_mean, video_index)
175
+
176
+
177
def _test_models():
    """Run the single-step smoke test on every model variant."""
    for model_class in (Model02, Model02RG):
        _test_model(model_class)
180
+
181
+
182
+ def _test_model (Model ):
95
183
big_t = 2
96
184
x = torch .rand (big_t + 1 , 1 , 3 , 4 * 2 ** 3 + 3 , 6 * 2 ** 3 + 5 )
97
185
big_k = 10
98
186
y = torch .LongTensor (big_t , 1 ).random_ (big_k )
99
- model_01 = Model02 (network_size = (3 , 6 , 12 , 18 , big_k ), input_spatial_size = x [0 ].size ()[2 :])
187
+ model = Model (network_size = (3 , 6 , 12 , 18 , big_k ), input_spatial_size = x [0 ].size ()[2 :])
100
188
101
189
state = None
102
- (x_hat , state ), (emb , idx ) = model_01 (V (x [0 ]), state )
190
+ (x_hat , state ), (emb , idx ) = model (V (x [0 ]), state )
103
191
104
192
print ('Input size:' , tuple (x .size ()))
105
193
print ('Output size:' , tuple (x_hat .data .size ()))
106
194
print ('Video index size:' , tuple (idx .size ()))
107
195
for i , s in enumerate (state ):
108
- print ('State' , i + 1 , 'has size:' , tuple (s .size ()))
196
+ if isinstance (s , list ):
197
+ for i , s in enumerate (state [0 ]):
198
+ print ('Net state' , i + 1 , 'has size:' , tuple (s .size ()))
199
+ for i , s in enumerate (state [1 ]):
200
+ print ('G' , i + 1 , 'state has size:' , tuple (s .size ()))
201
+ break
202
+ else :
203
+ print ('State' , i + 1 , 'has size:' , tuple (s .size ()))
109
204
print ('Embedding has size:' , emb .data .numel ())
110
205
111
206
mse = nn .MSELoss ()
@@ -118,7 +213,7 @@ def _test_model():
118
213
show_graph (loss_t1 )
119
214
120
215
# run one more time
121
- (x_hat , _ ), (_ , idx ) = model_01 (V (x [1 ]), state )
216
+ (x_hat , _ ), (_ , idx ) = model (V (x [1 ]), state )
122
217
123
218
x_next = V (x [2 ])
124
219
y_var = V (y [1 ])
@@ -128,7 +223,12 @@ def _test_model():
128
223
show_graph (loss_tot )
129
224
130
225
131
- def _test_training ():
226
def _test_training_models():
    """Run the training smoke test on every model variant."""
    for model_class in (Model02, Model02RG):
        _test_training(model_class)
229
+
230
+
231
+ def _test_training (Model ):
132
232
big_k = 10 # number of training videos
133
233
network_size = (3 , 6 , 12 , 18 , big_k )
134
234
big_t = 6 # sequence length
@@ -147,7 +247,7 @@ def _test_training():
147
247
print ('Target index has size' , tuple (y .size ()))
148
248
149
249
print ('Define model' )
150
- model = Model02 (network_size = network_size , input_spatial_size = x [0 ].size ()[2 :])
250
+ model = Model (network_size = network_size , input_spatial_size = x [0 ].size ()[2 :])
151
251
152
252
print ('Create a MSE and NLL criterions' )
153
253
mse = nn .MSELoss ()
@@ -175,13 +275,13 @@ def _test_training():
175
275
176
276
177
277
if __name__ == '__main__':
    # smoke-test both model variants: one forward/backward step, then training
    _test_models()
    _test_training_models()
180
280
181
281
182
282
# Module metadata
__author__ = "Alfredo Canziani"
__credits__ = ["Alfredo Canziani"]
__maintainer__ = "Alfredo Canziani"
__email__ = "alfredo.canziani@gmail.com"
__status__ = "Production"  # "Prototype", "Development", or "Production"
__date__ = "Feb, Mar 17"
0 commit comments