1
1
import torch .nn as nn
2
- from torch .autograd import Variable
3
2
import torch
4
3
5
4
6
5
class ConvLSTMCell (nn .Module ):
7
6
8
- def __init__ (self , input_size , input_dim , hidden_dim , kernel_size , bias ):
7
+ def __init__ (self , input_dim , hidden_dim , kernel_size , bias ):
9
8
"""
10
9
Initialize ConvLSTM cell.
11
-
10
+
12
11
Parameters
13
12
----------
14
- input_size: (int, int)
15
- Height and width of input tensor as (height, width).
16
13
input_dim: int
17
14
Number of channels of input tensor.
18
15
hidden_dim: int
@@ -25,60 +22,83 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
25
22
26
23
super (ConvLSTMCell , self ).__init__ ()
27
24
28
- self .height , self .width = input_size
29
- self .input_dim = input_dim
25
+ self .input_dim = input_dim
30
26
self .hidden_dim = hidden_dim
31
27
32
28
self .kernel_size = kernel_size
33
- self .padding = kernel_size [0 ] // 2 , kernel_size [1 ] // 2
34
- self .bias = bias
35
-
29
+ self .padding = kernel_size [0 ] // 2 , kernel_size [1 ] // 2
30
+ self .bias = bias
31
+
36
32
self .conv = nn .Conv2d (in_channels = self .input_dim + self .hidden_dim ,
37
33
out_channels = 4 * self .hidden_dim ,
38
34
kernel_size = self .kernel_size ,
39
35
padding = self .padding ,
40
36
bias = self .bias )
41
37
42
38
def forward (self , input_tensor , cur_state ):
43
-
44
39
h_cur , c_cur = cur_state
45
-
40
+
46
41
combined = torch .cat ([input_tensor , h_cur ], dim = 1 ) # concatenate along channel axis
47
-
42
+
48
43
combined_conv = self .conv (combined )
49
- cc_i , cc_f , cc_o , cc_g = torch .split (combined_conv , self .hidden_dim , dim = 1 )
44
+ cc_i , cc_f , cc_o , cc_g = torch .split (combined_conv , self .hidden_dim , dim = 1 )
50
45
i = torch .sigmoid (cc_i )
51
46
f = torch .sigmoid (cc_f )
52
47
o = torch .sigmoid (cc_o )
53
48
g = torch .tanh (cc_g )
54
49
55
50
c_next = f * c_cur + i * g
56
51
h_next = o * torch .tanh (c_next )
57
-
52
+
58
53
return h_next , c_next
59
54
60
- def init_hidden (self , batch_size ):
61
- return (Variable (torch .zeros (batch_size , self .hidden_dim , self .height , self .width )).cuda (),
62
- Variable (torch .zeros (batch_size , self .hidden_dim , self .height , self .width )).cuda ())
55
+ def init_hidden (self , batch_size , image_size ):
56
+ height , width = image_size
57
+ return (torch .zeros (batch_size , self .hidden_dim , height , width , device = self .conv .weight .device ),
58
+ torch .zeros (batch_size , self .hidden_dim , height , width , device = self .conv .weight .device ))
63
59
64
60
65
61
class ConvLSTM (nn .Module ):
66
62
67
- def __init__ (self , input_size , input_dim , hidden_dim , kernel_size , num_layers ,
63
+ """
64
+
65
+ Parameters:
66
+ input_dim: Number of channels in input
67
+ hidden_dim: Number of hidden channels
68
+ kernel_size: Size of kernel in convolutions
69
+ num_layers: Number of LSTM layers stacked on each other
70
+ batch_first: Whether or not dimension 0 is the batch or not
71
+ bias: Bias or no bias in Convolution
72
+ return_all_layers: Return the list of computations for all layers
73
+ Note: Will do same padding.
74
+
75
+ Input:
76
+ A tensor of size B, T, C, H, W or T, B, C, H, W
77
+ Output:
78
+ A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
79
+ 0 - layer_output_list is the list of lists of length T of each output
80
+ 1 - last_state_list is the list of last states
81
+ each element of the list is a tuple (h, c) for hidden state and memory
82
+ Example:
83
+ >> x = torch.rand((32, 10, 64, 128, 128))
84
+ >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
85
+ >> _, last_states = convlstm(x)
86
+ >> h = last_states[0][0] # 0 for layer index, 0 for h index
87
+ """
88
+
89
+ def __init__ (self , input_dim , hidden_dim , kernel_size , num_layers ,
68
90
batch_first = False , bias = True , return_all_layers = False ):
69
91
super (ConvLSTM , self ).__init__ ()
70
92
71
93
self ._check_kernel_size_consistency (kernel_size )
72
94
73
95
# Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
74
96
kernel_size = self ._extend_for_multilayer (kernel_size , num_layers )
75
- hidden_dim = self ._extend_for_multilayer (hidden_dim , num_layers )
97
+ hidden_dim = self ._extend_for_multilayer (hidden_dim , num_layers )
76
98
if not len (kernel_size ) == len (hidden_dim ) == num_layers :
77
99
raise ValueError ('Inconsistent list length.' )
78
100
79
- self .height , self .width = input_size
80
-
81
- self .input_dim = input_dim
101
+ self .input_dim = input_dim
82
102
self .hidden_dim = hidden_dim
83
103
self .kernel_size = kernel_size
84
104
self .num_layers = num_layers
@@ -88,10 +108,9 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
88
108
89
109
cell_list = []
90
110
for i in range (0 , self .num_layers ):
91
- cur_input_dim = self .input_dim if i == 0 else self .hidden_dim [i - 1 ]
111
+ cur_input_dim = self .input_dim if i == 0 else self .hidden_dim [i - 1 ]
92
112
93
- cell_list .append (ConvLSTMCell (input_size = (self .height , self .width ),
94
- input_dim = cur_input_dim ,
113
+ cell_list .append (ConvLSTMCell (input_dim = cur_input_dim ,
95
114
hidden_dim = self .hidden_dim [i ],
96
115
kernel_size = self .kernel_size [i ],
97
116
bias = self .bias ))
@@ -100,14 +119,14 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
100
119
101
120
def forward (self , input_tensor , hidden_state = None ):
102
121
"""
103
-
122
+
104
123
Parameters
105
124
----------
106
- input_tensor: todo
125
+ input_tensor: todo
107
126
5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
108
127
hidden_state: todo
109
128
None. todo implement stateful
110
-
129
+
111
130
Returns
112
131
-------
113
132
last_state_list, layer_output
@@ -116,14 +135,18 @@ def forward(self, input_tensor, hidden_state=None):
116
135
# (t, b, c, h, w) -> (b, t, c, h, w)
117
136
input_tensor = input_tensor .permute (1 , 0 , 2 , 3 , 4 )
118
137
138
+ b , _ , _ , h , w = input_tensor .size ()
139
+
119
140
# Implement stateful ConvLSTM
120
141
if hidden_state is not None :
121
142
raise NotImplementedError ()
122
143
else :
123
- hidden_state = self ._init_hidden (batch_size = input_tensor .size (0 ))
144
+ # Since the init is done in forward. Can send image size here
145
+ hidden_state = self ._init_hidden (batch_size = b ,
146
+ image_size = (h , w ))
124
147
125
148
layer_output_list = []
126
- last_state_list = []
149
+ last_state_list = []
127
150
128
151
seq_len = input_tensor .size (1 )
129
152
cur_layer_input = input_tensor
@@ -133,7 +156,6 @@ def forward(self, input_tensor, hidden_state=None):
133
156
h , c = hidden_state [layer_idx ]
134
157
output_inner = []
135
158
for t in range (seq_len ):
136
-
137
159
h , c = self .cell_list [layer_idx ](input_tensor = cur_layer_input [:, t , :, :, :],
138
160
cur_state = [h , c ])
139
161
output_inner .append (h )
@@ -146,20 +168,20 @@ def forward(self, input_tensor, hidden_state=None):
146
168
147
169
if not self .return_all_layers :
148
170
layer_output_list = layer_output_list [- 1 :]
149
- last_state_list = last_state_list [- 1 :]
171
+ last_state_list = last_state_list [- 1 :]
150
172
151
173
return layer_output_list , last_state_list
152
174
153
- def _init_hidden (self , batch_size ):
175
+ def _init_hidden (self , batch_size , image_size ):
154
176
init_states = []
155
177
for i in range (self .num_layers ):
156
- init_states .append (self .cell_list [i ].init_hidden (batch_size ))
178
+ init_states .append (self .cell_list [i ].init_hidden (batch_size , image_size ))
157
179
return init_states
158
180
159
181
@staticmethod
160
182
def _check_kernel_size_consistency (kernel_size ):
161
183
if not (isinstance (kernel_size , tuple ) or
162
- (isinstance (kernel_size , list ) and all ([isinstance (elem , tuple ) for elem in kernel_size ]))):
184
+ (isinstance (kernel_size , list ) and all ([isinstance (elem , tuple ) for elem in kernel_size ]))):
163
185
raise ValueError ('`kernel_size` must be tuple or list of tuples' )
164
186
165
187
@staticmethod
0 commit comments