Skip to content

Commit 8645cb9

Browse files
committed
Add Model02 recurrent generator: Model02RG
1 parent 8cbd334 commit 8645cb9

File tree

3 files changed

+154
-14
lines changed

3 files changed

+154
-14
lines changed

main.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ def main():
108108
from model.Model01 import Model01 as Model
109109
elif args.model == 'model_02':
110110
from model.Model02 import Model02 as Model
111+
elif args.model == 'model_02_rg':
112+
from model.Model02 import Model02RG as Model
111113
else:
112114
print('\n{:#^80}\n'.format(' Please select a valid model '))
113115
exit()
@@ -164,8 +166,14 @@ def adjust_learning_rate(opt, epoch):
164166
def selective_zero(s, new):
    """Zero the state entries of batch items whose video just changed.

    :param s: network state: either a flat list of state tensors (simple
        convolutive G) or a pair ``[net_states, gen_states]`` of such lists
        (recurrent G) — detected by checking whether ``s[0]`` is a list
    :param new: byte/bool tensor flagging, per batch index, which videos
        have just been replaced
    """
    if new.any():  # if at least one video changed
        b = new.nonzero().squeeze(1)  # get the list of indices
        # Normalise both layouts to a tuple of state lists, then run ONE
        # loop instead of three copies of the same masking code.
        groups = s if isinstance(s[0], list) else (s,)  # recurrent G vs simple convolutive G
        for group in groups:
            for layer in range(len(group)):  # for every layer having a state
                group[layer] = group[layer].index_fill(0, V(b), 0)  # mask state, zero selected indices
169177

170178

171179
def selective_match(x_hat, x, new):

model/Model02.py

Lines changed: 112 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77

88
# Define some constants
9+
from model.RG import RG
10+
911
KERNEL_SIZE = 3
1012
PADDING = KERNEL_SIZE // 2
1113
KERNEL_STRIDE = 2
@@ -14,7 +16,7 @@
1416

1517
class Model02(nn.Module):
1618
"""
17-
Generate a constructor for model_01 type of network
19+
Generate a constructor for model_02 type of network
1820
"""
1921

2022
def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
@@ -29,7 +31,7 @@ def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
2931
super().__init__()
3032
self.hidden_layers = len(network_size) - 2
3133

32-
print('\n{:-^80}'.format(' Building model '))
34+
print('\n{:-^80}'.format(' Building model Model02 '))
3335
print('Hidden layers:', self.hidden_layers)
3436
print('Net sizing:', network_size)
3537
print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size))
@@ -91,21 +93,114 @@ def forward(self, x, state):
9193
return (x, state), (x_mean, video_index)
9294

9395

94-
def _test_model():
96+
class Model02RG(nn.Module):
    """
    Generate a constructor for model_02_rg type of network

    Same discriminative (D) pyramid as Model02, but the generative (G)
    blocks are recurrent RG modules that carry their own hidden state.
    """

    def __init__(self, network_size: tuple, input_spatial_size: tuple) -> None:
        """
        Initialise Model02RG constructor

        :param network_size: (n, h1, h2, ..., emb_size, nb_videos)
        :type network_size: tuple
        :param input_spatial_size: (height, width)
        :type input_spatial_size: tuple
        """
        super().__init__()
        # two entries of network_size are not hidden: input channels and nb_videos
        self.hidden_layers = len(network_size) - 2

        print('\n{:-^80}'.format(' Building model Model02RG '))
        print('Hidden layers:', self.hidden_layers)
        print('Net sizing:', network_size)
        print('Input spatial size: {} x {}'.format(network_size[0], input_spatial_size))

        # main auto-encoder blocks
        # activation_size[k] is the (h, w) at the bottom of layer k (halved per layer)
        self.activation_size = [input_spatial_size]
        for layer in range(0, self.hidden_layers):
            # print some annotation when building model
            print('{:-<80}'.format('Layer ' + str(layer + 1) + ' '))
            print('Bottom size: {} x {}'.format(network_size[layer], self.activation_size[-1]))
            self.activation_size.append(tuple(ceil(s / 2) for s in self.activation_size[layer]))
            print('Top size: {} x {}'.format(network_size[layer + 1], self.activation_size[-1]))

            # init D (discriminative) blocks
            # `layer and 2 or 1` == 2 for layer > 0, else 1
            multiplier = layer and 2 or 1  # D_n, n > 1, has intra-layer feedback
            setattr(self, 'D_' + str(layer + 1), nn.Conv2d(
                in_channels=network_size[layer] * multiplier, out_channels=network_size[layer + 1],
                kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
            ))
            setattr(self, 'BN_D_' + str(layer + 1), nn.BatchNorm2d(network_size[layer + 1]))

            # init G (generative) blocks: recurrent, mirror of D_n
            setattr(self, 'G_' + str(layer + 1), RG(
                in_channels=network_size[layer + 1], out_channels=network_size[layer],
                kernel_size=KERNEL_SIZE, stride=KERNEL_STRIDE, padding=PADDING
            ))
            setattr(self, 'BN_G_' + str(layer + 1), nn.BatchNorm2d(network_size[layer]))

        # init auxiliary classifier (embedding -> video index)
        print('{:-<80}'.format('Classifier '))
        print(network_size[-2], '-->', network_size[-1])
        self.average = nn.AvgPool2d(self.activation_size[-1])
        self.stabiliser = nn.Linear(network_size[-2], network_size[-1])
        print(80 * '-', end='\n\n')

    def forward(self, x, state):
        """
        Run one time step: D pyramid down, recurrent G pyramid up.

        :param x: input frame batch
        :param state: [inter-layer states, G hidden states] from the
            previous call, or None/falsy on the first time step
        :return: ((reconstruction, state), (embedding, video_index))
        """
        activation_sizes = [x.size()]  # start from the input
        residuals = list()
        # state[0] --> network layer state; state[1] --> generative state
        state = state or [[None] * (self.hidden_layers - 1), [None] * self.hidden_layers]
        for layer in range(0, self.hidden_layers):  # connect discriminative blocks
            if layer:  # concat the input with the state for D_n, n > 1
                # NOTE(review): `or` evaluates Variable truthiness; on
                # multi-element tensors this raises in current PyTorch —
                # an explicit `is None` check would be safer. Confirm
                # against the targeted torch version.
                s = state[0][layer - 1] or V(x.data.clone().zero_())
                x = torch.cat((x, s), 1)
            x = getattr(self, 'D_' + str(layer + 1))(x)
            residuals.append(x)  # pre-activation output, reused as skip connection
            x = f.relu(x)
            x = getattr(self, 'BN_D_' + str(layer + 1))(x)
            activation_sizes.append(x.size())  # cache output size for later retrieval
        for layer in reversed(range(0, self.hidden_layers)):  # connect generative blocks
            x = getattr(self, 'G_' + str(layer + 1))((x, activation_sizes[layer]), state[1][layer])
            state[1][layer] = x  # h[t - 1] <- h[t]
            if layer:
                state[0][layer - 1] = x
                # NOTE(review): in-place += also mutates the tensor just
                # stored into state[1][layer] / state[0][layer - 1] (they
                # alias x) — verify this residual-in-the-state behaviour
                # is intentional.
                x += residuals[layer - 1]
            x = f.relu(x)
            x = getattr(self, 'BN_G_' + str(layer + 1))(x)
        x_mean = self.average(residuals[-1])  # embedding: pooled top-layer pre-activation
        video_index = self.stabiliser(x_mean.view(x_mean.size(0), -1))

        return (x, state), (x_mean, video_index)
175+
176+
177+
def _test_models():
    """Run the forward-pass smoke test on both generator variants."""
    for model_class in (Model02, Model02RG):
        _test_model(model_class)
180+
181+
182+
def _test_model(Model):
95183
big_t = 2
96184
x = torch.rand(big_t + 1, 1, 3, 4 * 2**3 + 3, 6 * 2**3 + 5)
97185
big_k = 10
98186
y = torch.LongTensor(big_t, 1).random_(big_k)
99-
model_01 = Model02(network_size=(3, 6, 12, 18, big_k), input_spatial_size=x[0].size()[2:])
187+
model = Model(network_size=(3, 6, 12, 18, big_k), input_spatial_size=x[0].size()[2:])
100188

101189
state = None
102-
(x_hat, state), (emb, idx) = model_01(V(x[0]), state)
190+
(x_hat, state), (emb, idx) = model(V(x[0]), state)
103191

104192
print('Input size:', tuple(x.size()))
105193
print('Output size:', tuple(x_hat.data.size()))
106194
print('Video index size:', tuple(idx.size()))
107195
for i, s in enumerate(state):
108-
print('State', i + 1, 'has size:', tuple(s.size()))
196+
if isinstance(s, list):
197+
for i, s in enumerate(state[0]):
198+
print('Net state', i + 1, 'has size:', tuple(s.size()))
199+
for i, s in enumerate(state[1]):
200+
print('G', i + 1, 'state has size:', tuple(s.size()))
201+
break
202+
else:
203+
print('State', i + 1, 'has size:', tuple(s.size()))
109204
print('Embedding has size:', emb.data.numel())
110205

111206
mse = nn.MSELoss()
@@ -118,7 +213,7 @@ def _test_model():
118213
show_graph(loss_t1)
119214

120215
# run one more time
121-
(x_hat, _), (_, idx) = model_01(V(x[1]), state)
216+
(x_hat, _), (_, idx) = model(V(x[1]), state)
122217

123218
x_next = V(x[2])
124219
y_var = V(y[1])
@@ -128,7 +223,12 @@ def _test_model():
128223
show_graph(loss_tot)
129224

130225

131-
def _test_training():
226+
def _test_training_models():
    """Run the training smoke test on both generator variants."""
    for model_class in (Model02, Model02RG):
        _test_training(model_class)
229+
230+
231+
def _test_training(Model):
132232
big_k = 10 # number of training videos
133233
network_size = (3, 6, 12, 18, big_k)
134234
big_t = 6 # sequence length
@@ -147,7 +247,7 @@ def _test_training():
147247
print('Target index has size', tuple(y.size()))
148248

149249
print('Define model')
150-
model = Model02(network_size=network_size, input_spatial_size=x[0].size()[2:])
250+
model = Model(network_size=network_size, input_spatial_size=x[0].size()[2:])
151251

152252
print('Create a MSE and NLL criterions')
153253
mse = nn.MSELoss()
@@ -175,13 +275,13 @@ def _test_training():
175275

176276

177277
if __name__ == '__main__':
    # Script entry point: smoke-test both generator variants
    # (forward pass, then a short training loop).
    _test_models()
    _test_training_models()
180280

181281

182282
__author__ = "Alfredo Canziani"
183283
__credits__ = ["Alfredo Canziani"]
184284
__maintainer__ = "Alfredo Canziani"
185285
__email__ = "alfredo.canziani@gmail.com"
186286
__status__ = "Production" # "Prototype", "Development", or "Production"
187-
__date__ = "Feb 17"
287+
__date__ = "Feb, Mar 17"

model/RG.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from torch import nn
2+
3+
4+
class RG(nn.Module):
    """Recurrent Generative Module

    Sums a transposed convolution of the input with a plain convolution
    of the previous output (the recurrent state).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        """ Initialise RG Module (parameters as nn.ConvTranspose2d)"""
        super().__init__()
        # input -> output path (upsamples to the requested output_size)
        self.from_input = nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding
        )
        # state -> output path; bias-free so an absent state adds nothing
        self.from_state = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=kernel_size, padding=padding, bias=False
        )

    def forward(self, x, state):
        """
        Calling signature

        :param x: (input, output_size)
        :type x: tuple
        :param state: previous output, or None at the first time step
        :type state: torch.Tensor
        :return: current state
        :rtype: torch.Tensor
        """
        x = self.from_input(*x)  # the very first x is a tuple (input, expected_output_size)
        # Bug fix: `if state:` evaluates tensor truthiness, which raises
        # for multi-element tensors — test for presence explicitly.
        # Out-of-place add also avoids mutating from_input's output.
        if state is not None:
            x = x + self.from_state(state)
        return x

0 commit comments

Comments
 (0)