Skip to content

Commit 53071ca

Browse files
committed
Add ConvLSTM Module
1 parent af784db commit 53071ca

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

convlstm.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import torch.nn as nn
2+
from torch.autograd import Variable
3+
import torch
4+
5+
6+
class ConvLSTMCell(nn.Module):
7+
8+
def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
9+
"""
10+
Initialize ConvLSTM cell.
11+
12+
Parameters
13+
----------
14+
input_size: (int, int)
15+
Height and width of input tensor as (height, width).
16+
input_dim: int
17+
Number of channels of input tensor.
18+
hidden_dim: int
19+
Number of channels of hidden state.
20+
kernel_size: (int, int)
21+
Size of the convolutional kernel.
22+
bias: bool
23+
Whether or not to add the bias.
24+
"""
25+
26+
super(ConvLSTMCell, self).__init__()
27+
28+
self.height, self.width = input_size
29+
self.input_dim = input_dim
30+
self.hidden_dim = hidden_dim
31+
32+
self.kernel_size = kernel_size
33+
self.padding = kernel_size[0] // 2, kernel_size[1] // 2
34+
self.bias = bias
35+
36+
self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
37+
out_channels=4 * self.hidden_dim,
38+
kernel_size=self.kernel_size,
39+
padding=self.padding,
40+
bias=self.bias)
41+
42+
def forward(self, input_tensor, cur_state):
43+
44+
h_cur, c_cur = cur_state
45+
46+
combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
47+
48+
combined_conv = self.conv(combined)
49+
cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
50+
i = torch.sigmoid(cc_i)
51+
f = torch.sigmoid(cc_f)
52+
o = torch.sigmoid(cc_o)
53+
g = torch.tanh(cc_g)
54+
55+
c_next = f * c_cur + i * g
56+
h_next = o * torch.tanh(c_next)
57+
58+
return h_next, c_next
59+
60+
def init_hidden(self, batch_size):
61+
return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(),
62+
Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda())
63+
64+
65+
class ConvLSTM(nn.Module):
66+
67+
def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
68+
batch_first=False, bias=True, return_all_layers=False):
69+
super(ConvLSTM, self).__init__()
70+
71+
self._check_kernel_size_consistency(kernel_size)
72+
73+
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
74+
kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
75+
hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
76+
if not len(kernel_size) == len(hidden_dim) == num_layers:
77+
raise ValueError('Inconsistent list length.')
78+
79+
self.height, self.width = input_size
80+
81+
self.input_dim = input_dim
82+
self.hidden_dim = hidden_dim
83+
self.kernel_size = kernel_size
84+
self.num_layers = num_layers
85+
self.batch_first = batch_first
86+
self.bias = bias
87+
self.return_all_layers = return_all_layers
88+
89+
cell_list = []
90+
for i in range(0, self.num_layers):
91+
cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]
92+
93+
cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
94+
input_dim=cur_input_dim,
95+
hidden_dim=self.hidden_dim[i],
96+
kernel_size=self.kernel_size[i],
97+
bias=self.bias))
98+
99+
self.cell_list = nn.ModuleList(cell_list)
100+
101+
def forward(self, input_tensor, hidden_state=None):
102+
"""
103+
104+
Parameters
105+
----------
106+
input_tensor: todo
107+
5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
108+
hidden_state: todo
109+
None. todo implement stateful
110+
111+
Returns
112+
-------
113+
last_state_list, layer_output
114+
"""
115+
if not self.batch_first:
116+
# (t, b, c, h, w) -> (b, t, c, h, w)
117+
input_tensor.permute(1, 0, 2, 3, 4)
118+
119+
# Implement stateful ConvLSTM
120+
if hidden_state is not None:
121+
raise NotImplementedError()
122+
else:
123+
hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
124+
125+
layer_output_list = []
126+
last_state_list = []
127+
128+
seq_len = input_tensor.size(1)
129+
cur_layer_input = input_tensor
130+
131+
for layer_idx in range(self.num_layers):
132+
133+
h, c = hidden_state[layer_idx]
134+
output_inner = []
135+
for t in range(seq_len):
136+
137+
h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
138+
cur_state=[h, c])
139+
output_inner.append(h)
140+
141+
layer_output = torch.stack(output_inner, dim=1)
142+
cur_layer_input = layer_output
143+
144+
layer_output_list.append(layer_output)
145+
last_state_list.append([h, c])
146+
147+
if not self.return_all_layers:
148+
layer_output_list = layer_output_list[-1:]
149+
last_state_list = last_state_list[-1:]
150+
151+
return layer_output_list, last_state_list
152+
153+
def _init_hidden(self, batch_size):
154+
init_states = []
155+
for i in range(self.num_layers):
156+
init_states.append(self.cell_list[i].init_hidden(batch_size))
157+
return init_states
158+
159+
@staticmethod
160+
def _check_kernel_size_consistency(kernel_size):
161+
if not (isinstance(kernel_size, tuple) or
162+
(isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
163+
raise ValueError('`kernel_size` must be tuple or list of tuples')
164+
165+
@staticmethod
166+
def _extend_for_multilayer(param, num_layers):
167+
if not isinstance(param, list):
168+
param = [param] * num_layers
169+
return param

0 commit comments

Comments
 (0)