Better docstring. Image size and device found in forward. #12

Merged: 1 commit, Feb 27, 2020
Changes from all commits
94 changes: 58 additions & 36 deletions convlstm.py
@@ -1,18 +1,15 @@
 import torch.nn as nn
-from torch.autograd import Variable
 import torch
 
 
 class ConvLSTMCell(nn.Module):
 
-    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):
+    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
         """
         Initialize ConvLSTM cell.
 
         Parameters
         ----------
-        input_size: (int, int)
-            Height and width of input tensor as (height, width).
         input_dim: int
             Number of channels of input tensor.
         hidden_dim: int
@@ -25,60 +22,83 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias):

         super(ConvLSTMCell, self).__init__()
 
-        self.height, self.width = input_size
-        self.input_dim = input_dim
+        self.input_dim = input_dim
         self.hidden_dim = hidden_dim
 
         self.kernel_size = kernel_size
-        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
-        self.bias = bias
+        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+        self.bias = bias
 
         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
                               out_channels=4 * self.hidden_dim,
                               kernel_size=self.kernel_size,
                               padding=self.padding,
                               bias=self.bias)
 
     def forward(self, input_tensor, cur_state):
-
         h_cur, c_cur = cur_state
 
         combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
 
         combined_conv = self.conv(combined)
-        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
+        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
         i = torch.sigmoid(cc_i)
         f = torch.sigmoid(cc_f)
         o = torch.sigmoid(cc_o)
         g = torch.tanh(cc_g)
 
         c_next = f * c_cur + i * g
         h_next = o * torch.tanh(c_next)
 
         return h_next, c_next
 
-    def init_hidden(self, batch_size):
-        return (Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda(),
-                Variable(torch.zeros(batch_size, self.hidden_dim, self.height, self.width)).cuda())
+    def init_hidden(self, batch_size, image_size):
+        height, width = image_size
+        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
+                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
 
 
 class ConvLSTM(nn.Module):
 
-    def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,
+    """
+
+    Parameters:
+        input_dim: Number of channels in input
+        hidden_dim: Number of hidden channels
+        kernel_size: Size of kernel in convolutions
+        num_layers: Number of LSTM layers stacked on each other
+        batch_first: Whether or not dimension 0 is the batch dimension
+        bias: Bias or no bias in Convolution
+        return_all_layers: Return the list of computations for all layers
+        Note: Will do same padding.
+
+    Input:
+        A tensor of size B, T, C, H, W or T, B, C, H, W
+    Output:
+        A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
+            0 - layer_output_list is the list of lists of length T of each output
+            1 - last_state_list is the list of last states
+                each element of the list is a tuple (h, c) for hidden state and memory
+    Example:
+        >> x = torch.rand((32, 10, 64, 128, 128))
+        >> convlstm = ConvLSTM(64, 16, (3, 3), 1, True, True, False)
+        >> _, last_states = convlstm(x)
+        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
+    """
+
+    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
                  batch_first=False, bias=True, return_all_layers=False):
         super(ConvLSTM, self).__init__()
 
         self._check_kernel_size_consistency(kernel_size)
 
         # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
         kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
-        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
+        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
         if not len(kernel_size) == len(hidden_dim) == num_layers:
             raise ValueError('Inconsistent list length.')
 
-        self.height, self.width = input_size
-
-        self.input_dim = input_dim
+        self.input_dim = input_dim
         self.hidden_dim = hidden_dim
         self.kernel_size = kernel_size
         self.num_layers = num_layers
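
Example (not part of this diff): with this change a ConvLSTMCell no longer needs the spatial size at construction time, and init_hidden allocates its zero states on whatever device the cell's conv weights live on. A minimal single-step sketch, assuming the patched convlstm.py is importable:

    import torch
    from convlstm import ConvLSTMCell

    cell = ConvLSTMCell(input_dim=3, hidden_dim=16, kernel_size=(3, 3), bias=True)
    # cell = cell.to('cuda')  # if moved, init_hidden would allocate on the GPU too

    x = torch.rand(4, 3, 32, 48)                   # (batch, channels, height, width)
    h, c = cell.init_hidden(batch_size=4, image_size=(32, 48))
    h, c = cell(input_tensor=x, cur_state=[h, c])  # one recurrent step
    print(h.shape)                                 # torch.Size([4, 16, 32, 48])
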
@@ -88,10 +108,9 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,

         cell_list = []
         for i in range(0, self.num_layers):
-            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i-1]
+            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
 
-            cell_list.append(ConvLSTMCell(input_size=(self.height, self.width),
-                                          input_dim=cur_input_dim,
+            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
                                           hidden_dim=self.hidden_dim[i],
                                           kernel_size=self.kernel_size[i],
                                           bias=self.bias))
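
Example (not part of this diff): the loop above chains the stacked cells by feeding each layer the hidden channels of the layer below it. A quick sketch of the resulting channel wiring, for an assumed config of input_dim=3 and hidden_dim=[16, 32, 32]:

    input_dim, hidden_dim = 3, [16, 32, 32]
    for i in range(len(hidden_dim)):
        cur_input_dim = input_dim if i == 0 else hidden_dim[i - 1]
        # each cell's conv maps cur_input_dim + hidden_dim[i] channels to 4 * hidden_dim[i]
        print(i, ':', cur_input_dim, '->', hidden_dim[i])
    # 0 : 3 -> 16
    # 1 : 16 -> 32
    # 2 : 32 -> 32
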
@@ -100,14 +119,14 @@ def __init__(self, input_size, input_dim, hidden_dim, kernel_size, num_layers,

     def forward(self, input_tensor, hidden_state=None):
         """
 
         Parameters
         ----------
-        input_tensor: todo
+        input_tensor: todo
             5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
         hidden_state: todo
             None. todo implement stateful
 
         Returns
         -------
         last_state_list, layer_output
@@ -116,14 +135,18 @@ def forward(self, input_tensor, hidden_state=None):
             # (t, b, c, h, w) -> (b, t, c, h, w)
             input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
 
+        b, _, _, h, w = input_tensor.size()
+
         # Implement stateful ConvLSTM
         if hidden_state is not None:
             raise NotImplementedError()
         else:
-            hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
+            # Since the init is done in forward, we can pass the image size here
+            hidden_state = self._init_hidden(batch_size=b,
+                                             image_size=(h, w))
 
         layer_output_list = []
-        last_state_list = []
+        last_state_list = []
 
         seq_len = input_tensor.size(1)
         cur_layer_input = input_tensor
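
Example (not part of this diff): because the batch size and image size are now read off the input inside forward(), one ConvLSTM instance can consume sequences with different spatial sizes between calls. A minimal sketch, assuming the patched convlstm.py is importable:

    import torch
    from convlstm import ConvLSTM

    model = ConvLSTM(input_dim=3, hidden_dim=16, kernel_size=(3, 3),
                     num_layers=1, batch_first=True)
    _, states_a = model(torch.rand(2, 5, 3, 32, 32))  # (b, t, c, h, w)
    _, states_b = model(torch.rand(2, 5, 3, 64, 48))  # same model, new spatial size
    print(states_a[0][0].shape)                       # torch.Size([2, 16, 32, 32])
    print(states_b[0][0].shape)                       # torch.Size([2, 16, 64, 48])
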
@@ -133,7 +156,6 @@ def forward(self, input_tensor, hidden_state=None):
             h, c = hidden_state[layer_idx]
             output_inner = []
             for t in range(seq_len):
-
                 h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :],
                                                  cur_state=[h, c])
                 output_inner.append(h)
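
Sketch (standalone, not part of this diff): the loop above unrolls one layer over time and collects the hidden state at every step; stacking those T states restores the time axis. With an identity stand-in for the cell update, the shape bookkeeping looks like this:

    import torch

    cur_layer_input = torch.rand(2, 5, 16, 32, 32)     # (b, t, c, h, w)
    output_inner = []
    for t in range(cur_layer_input.size(1)):
        h = cur_layer_input[:, t, :, :, :]             # stand-in for the ConvLSTMCell update
        output_inner.append(h)
    layer_output = torch.stack(output_inner, dim=1)    # back to (b, t, c, h, w)
    print(layer_output.shape)                          # torch.Size([2, 5, 16, 32, 32])
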
@@ -146,20 +168,20 @@ def forward(self, input_tensor, hidden_state=None):

         if not self.return_all_layers:
             layer_output_list = layer_output_list[-1:]
-            last_state_list = last_state_list[-1:]
+            last_state_list = last_state_list[-1:]
 
         return layer_output_list, last_state_list
 
-    def _init_hidden(self, batch_size):
+    def _init_hidden(self, batch_size, image_size):
         init_states = []
         for i in range(self.num_layers):
-            init_states.append(self.cell_list[i].init_hidden(batch_size))
+            init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))
         return init_states
 
     @staticmethod
     def _check_kernel_size_consistency(kernel_size):
         if not (isinstance(kernel_size, tuple) or
-                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
+                (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))):
             raise ValueError('`kernel_size` must be tuple or list of tuples')
 
     @staticmethod
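
Example (not part of this diff): _check_kernel_size_consistency only accepts a tuple or a list of tuples, and _extend_for_multilayer, which is defined below the visible hunk (its behavior here is an assumption based on its call sites), broadcasts a single value to one entry per layer:

    # assumed behavior of ConvLSTM._extend_for_multilayer, mirrored standalone
    kernel_size, num_layers = (3, 3), 3
    if not isinstance(kernel_size, list):
        kernel_size = [kernel_size] * num_layers
    print(kernel_size)  # [(3, 3), (3, 3), (3, 3)]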