Skip to content

Commit 3607865

Browse files
authored
Merge pull request #12 from arnaud-roussel/init_tweaks
Better docstring. Image size and device found in forward.
2 parents e4cc619 + 6d61754 commit 3607865

File tree

1 file changed

+58
-36
lines changed

1 file changed

+58
-36
lines changed

convlstm.py

Lines changed: 58 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import torch.nn as nn
2-
from torch.autograd import Variable
32
import torch
43

54

65
class ConvLSTMCell(nn.Module):
76

8-
def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
7+
def __init__(self, input_dim, hidden_dim, kernel_size, bias):
98
"""
109
Initialize ConvLSTM cell.
11-
10+
1211
Parameters
1312
----------
14-
input_size: (int, int)
15-
Height and width of input tensor as (height, width).
1613
input_dim: int
1714
Number of channels of input tensor.
1815
hidden_dim: int
@@ -25,60 +22,83 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
2522

2623
super(ConvLSTMCell, self).__init__()
2724

28-
self.height, self.width = input_size
29-
self.input_dim = input_dim
25+
self.input_dim = input_dim
3026
self.hidden_dim = hidden_dim
3127

3228
self.kernel_size = kernel_size
33-
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
34-
self.bias = bias
35-
29+
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
30+
self.bias = bias
31+
3632
self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
3733
out_channels=4 * self.hidden_dim,
3834
kernel_size=self.kernel_size,
3935
padding=self.padding,
4036
bias=self.bias)
4137

4238
def forward(self, input_tensor, cur_state):
43-
4439
h_cur, c_cur = cur_state
45-
40+
4641
combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
47-
42+
4843
combined_conv = self.conv(combined)
49-
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
44+
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
5045
i = torch.sigmoid(cc_i)
5146
f = torch.sigmoid(cc_f)
5247
o = torch.sigmoid(cc_o)
5348
g = torch.tanh(cc_g)
5449

5550
c_next = f * c_cur + i * g
5651
h_next = o * torch.tanh(c_next)
57-
52+
5853
return h_next, c_next
5954

60-
def init_hidden(self, batch_size):
61-
return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(),
62-
Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda())
55+
def init_hidden(self, batch_size, image_size):
56+
height, width = image_size
57+
return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
58+
torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
6359

6460

6561
class ConvLSTM(nn.Module):
6662

67-
def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
63+
"""
64+
65+
Parameters:
66+
input_dim: Number of channels in input
67+
hidden_dim: Number of hidden channels
68+
kernel_size: Size of kernel in convolutions
69+
num_layers: Number of LSTM layers stacked on each other
70+
batch_first: Whether or not dimension 0 is the batch or not
71+
bias: Bias or no bias in Convolution
72+
return_all_layers: Return the list of computations for all layers
73+
Note: Will do same padding.
74+
75+
Input:
76+
A tensor of size B, T, C, H, W or T, B, C, H, W
77+
Output:
78+
A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
79+
0 - layer_output_list is the list of lists of length T of each output
80+
1 - last_state_list is the list of last states
81+
each element of the list is a tuple (h, c) for hidden state and memory
82+
Example:
83+
>> x = torch.rand((32, 10, 64, 128, 128))
84+
>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
85+
>> _, last_states = convlstm(x)
86+
>> h = last_states[0][0] # 0 for layer index, 0 for h index
87+
"""
88+
89+
def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
6890
batch_first=False, bias=True, return_all_layers=False):
6991
super(ConvLSTM, self).__init__()
7092

7193
self._check_kernel_size_consistency(kernel_size)
7294

7395
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
7496
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
75-
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
97+
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
7698
if not len(kernel_size) == len(hidden_dim) == num_layers:
7799
raise ValueError('Inconsistent list length.')
78100

79-
self.height, self.width = input_size
80-
81-
self.input_dim = input_dim
101+
self.input_dim = input_dim
82102
self.hidden_dim = hidden_dim
83103
self.kernel_size = kernel_size
84104
self.num_layers = num_layers
@@ -88,10 +108,9 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
88108

89109
cell_list = []
90110
for i in range(0, self.num_layers):
91-
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]
111+
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
92112

93-
cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
94-
input_dim=cur_input_dim,
113+
cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
95114
hidden_dim=self.hidden_dim[i],
96115
kernel_size=self.kernel_size[i],
97116
bias=self.bias))
@@ -100,14 +119,14 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
100119

101120
def forward(self, input_tensor, hidden_state=None):
102121
"""
103-
122+
104123
Parameters
105124
----------
106-
input_tensor: todo
125+
input_tensor: todo
107126
5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
108127
hidden_state: todo
109128
None. todo implement stateful
110-
129+
111130
Returns
112131
-------
113132
last_state_list, layer_output
@@ -116,14 +135,18 @@ def forward(self, input_tensor, hidden_state=None):
116135
# (t, b, c, h, w) -> (b, t, c, h, w)
117136
input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
118137

138+
b, _, _, h, w = input_tensor.size()
139+
119140
# Implement stateful ConvLSTM
120141
if hidden_state is not None:
121142
raise NotImplementedError()
122143
else:
123-
hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
144+
# Since the init is done in forward. Can send image size here
145+
hidden_state = self._init_hidden(batch_size=b,
146+
image_size=(h, w))
124147

125148
layer_output_list = []
126-
last_state_list = []
149+
last_state_list = []
127150

128151
seq_len = input_tensor.size(1)
129152
cur_layer_input = input_tensor
@@ -133,7 +156,6 @@ def forward(self, input_tensor, hidden_state=None):
133156
h, c = hidden_state[layer_idx]
134157
output_inner = []
135158
for t in range(seq_len):
136-
137159
h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
138160
cur_state=[h, c])
139161
output_inner.append(h)
@@ -146,20 +168,20 @@ def forward(self, input_tensor, hidden_state=None):
146168

147169
if not self.return_all_layers:
148170
layer_output_list = layer_output_list[-1:]
149-
last_state_list = last_state_list[-1:]
171+
last_state_list = last_state_list[-1:]
150172

151173
return layer_output_list, last_state_list
152174

153-
def _init_hidden(self, batch_size):
175+
def _init_hidden(self, batch_size, image_size):
154176
init_states = []
155177
for i in range(self.num_layers):
156-
init_states.append(self.cell_list[i].init_hidden(batch_size))
178+
init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
157179
return init_states
158180

159181
@staticmethod
160182
def _check_kernel_size_consistency(kernel_size):
161183
if not (isinstance(kernel_size, tuple) or
162-
(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
184+
(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
163185
raise ValueError('`kernel_size` must be tuple or list of tuples')
164186

165187
@staticmethod

0 commit comments

Comments
 (0)