-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
1,175 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
# Spatio-Temporal LSTM | ||
# Demo code for (ECCV 2016) Spatio-Temporal LSTM with Trust Gates for 3D Human Action Recognition | ||
# Unfinished ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
|
||
local STLSTM = {}

--- Build a single spatio-temporal LSTM step as an nngraph module.
--
-- Each layer receives two recurrent contexts instead of one: the (c, h) pair
-- tagged "j" and the (c, h) pair tagged "t" (see the prev_cj/prev_hj and
-- prev_ct/prev_ht inputs below). Each context has its own forget gate; the
-- input gate, output gate and candidate transform are shared.
--
-- Graph inputs, in order (4*n + 1 nodes):
--   1                : x
--   2*L, 2*L+1       : prev_cj[L], prev_hj[L]   for L = 1..n
--   2*n+2*L, +1      : prev_ct[L], prev_ht[L]   for L = 1..n
--
-- Graph outputs, in order (2*n + 1 nodes):
--   2*L-1, 2*L       : next_c[L], next_h[L]     for L = 1..n
--   2*n+1            : log-softmax over the decoder projection of the top h
--
-- NOTE(review): relies on `nn` (with nngraph loaded) being available as a
-- global; this file does not require it itself.
--
-- @param input_size  width of x fed to the first layer
-- @param output_size width of the decoder projection (LogSoftMax input)
-- @param rnn_size    width of every c/h state
-- @param n           number of stacked layers; must be a positive integer
-- @param dropout     optional dropout probability, defaults to 0 (disabled)
-- @return an nn.gModule mapping the inputs above to the outputs above
function STLSTM.stlstm(input_size, output_size, rnn_size, n, dropout) -- n: num_layers
  dropout = dropout or 0

  -- Fail fast with a readable message. Without this check, n = 0 (or nil, or
  -- a fractional n) only surfaces later as an unrelated error when a nil node
  -- is indexed or passed to a module.
  if type(n) ~= 'number' or n < 1 or n % 1 ~= 0 then
    error('STLSTM.stlstm: n must be a positive integer, got ' .. tostring(n), 2)
  end

  -- there will be 4*n+1 inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- x
  for L = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_cj[L]
    table.insert(inputs, nn.Identity()()) -- prev_hj[L]
  end
  for L = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_ct[L]
    table.insert(inputs, nn.Identity()()) -- prev_ht[L]
  end

  local x, input_size_L
  local outputs = {}

  for L = 1, n do
    -- c,h from previous steps: the "j" context first, then the "t" context,
    -- which is stored after all n "j" pairs (hence the n*2 offset)
    local prev_cj = inputs[L*2]
    local prev_hj = inputs[L*2+1]

    local prev_ct = inputs[n*2+L*2]
    local prev_ht = inputs[n*2+L*2+1]

    -- the input to this layer
    if (L == 1) then
      x = inputs[1]
      input_size_L = input_size
    else
      -- outputs holds {c, h} pairs, so the previous layer's h is at (L-1)*2
      x = outputs[(L-1)*2]
      if dropout > 0 then x = nn.Dropout(dropout)(x) end -- apply dropout, if any
      input_size_L = rnn_size
    end

    -- evaluate the input sums at once for efficiency: each Linear emits the
    -- pre-activations of all five gates/transforms side by side
    local i2h = nn.Linear(input_size_L, 5 * rnn_size)(x):annotate{ name = 'i2h_' .. L}
    local h2hj = nn.Linear(rnn_size, 5 * rnn_size)(prev_hj):annotate{name = 'h2hj_' .. L}
    local h2ht = nn.Linear(rnn_size, 5 * rnn_size)(prev_ht):annotate{name = 'h2ht_' .. L}
    local all_input_sums = nn.CAddTable()({i2h, h2hj, h2ht})

    -- split the summed pre-activations into five rnn_size-wide chunks
    -- NOTE(review): SplitTable(2) presumably expects batched (2-D) inputs so
    -- that dimension 2 is the gate axis after the reshape -- confirm with caller
    local reshaped = nn.Reshape(5, rnn_size)(all_input_sums)
    local n1, n2, n3, n4, n5 = nn.SplitTable(2)(reshaped):split(5)

    -- decode the gates
    local in_gate = nn.Sigmoid()(n1)
    local forget_gate_j = nn.Sigmoid()(n2)
    local forget_gate_t = nn.Sigmoid()(n3)
    local out_gate = nn.Sigmoid()(n4)
    -- decode the write inputs
    local in_transform = nn.Tanh()(n5)
    -- perform the STLSTM update:
    --   next_c = f_j .* prev_cj + f_t .* prev_ct + i .* g
    local next_c = nn.CAddTable()({
      nn.CMulTable()({forget_gate_j, prev_cj}),
      nn.CMulTable()({forget_gate_t, prev_ct}),
      nn.CMulTable()({in_gate, in_transform}) })

    -- gated cells form the output
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end

  -- set up the decoder on the top layer's h (the last entry pushed above)
  local top_h = outputs[#outputs]
  if dropout > 0 then top_h = nn.Dropout(dropout)(top_h) end
  local proj = nn.Linear(rnn_size, output_size)(top_h):annotate{name='decoder'}
  local logsoft = nn.LogSoftMax()(proj)
  table.insert(outputs, logsoft)

  return nn.gModule(inputs, outputs)
end
|
||
return STLSTM | ||
|
Oops, something went wrong.