```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyConvLSTACell(nn.Module):
    """Convolutional LSTA cell: a ConvLSTM-style memory augmented with a
    recurrent spatial attention map driven by class activation maps (CAMs)."""

    def __init__(self, input_size, memory_size, c_cam_classes=100, kernel_size=3,
                 stride=1, padding=1, zero_init=False):
        super(MyConvLSTACell, self).__init__()
        self.input_size = input_size
        self.memory_size = memory_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.c_classifier = nn.Linear(memory_size, c_cam_classes, bias=False)
        self.coupling_fc = nn.Linear(memory_size, c_cam_classes, bias=False)
        self.avgpool = nn.AvgPool2d(7)

        # Attention params: single-channel convolutions over the previous
        # attention state s and the current CAM, one pair per gate.
        self.conv_i_s = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_i_cam = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        self.conv_f_s = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_f_cam = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        self.conv_a_s = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_a_cam = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        self.conv_o_s = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_o_cam = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        if zero_init:
            torch.nn.init.constant_(self.conv_i_s.weight, 0)
            torch.nn.init.constant_(self.conv_i_s.bias, 0)
            torch.nn.init.constant_(self.conv_i_cam.weight, 0)

            torch.nn.init.constant_(self.conv_f_s.weight, 0)
            torch.nn.init.constant_(self.conv_f_s.bias, 0)
            torch.nn.init.constant_(self.conv_f_cam.weight, 0)

            torch.nn.init.constant_(self.conv_a_s.weight, 0)
            torch.nn.init.constant_(self.conv_a_s.bias, 0)
            torch.nn.init.constant_(self.conv_a_cam.weight, 0)

            torch.nn.init.constant_(self.conv_o_s.weight, 0)
            torch.nn.init.constant_(self.conv_o_s.bias, 0)
            torch.nn.init.constant_(self.conv_o_cam.weight, 0)
        else:
            torch.nn.init.xavier_normal_(self.conv_i_s.weight)
            torch.nn.init.constant_(self.conv_i_s.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_i_cam.weight)

            torch.nn.init.xavier_normal_(self.conv_f_s.weight)
            torch.nn.init.constant_(self.conv_f_s.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_f_cam.weight)

            torch.nn.init.xavier_normal_(self.conv_a_s.weight)
            torch.nn.init.constant_(self.conv_a_s.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_a_cam.weight)

            torch.nn.init.xavier_normal_(self.conv_o_s.weight)
            torch.nn.init.constant_(self.conv_o_s.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_o_cam.weight)

        # Memory params: ConvLSTM convolutions over the attended input x
        # and the memory-derived tensor, one pair per gate.
        self.conv_i_x = nn.Conv2d(input_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_i_c = nn.Conv2d(memory_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding,
                                  bias=False)

        self.conv_f_x = nn.Conv2d(input_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_f_c = nn.Conv2d(memory_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding,
                                  bias=False)

        self.conv_c_x = nn.Conv2d(input_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_c_c = nn.Conv2d(memory_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding,
                                  bias=False)

        self.conv_o_x = nn.Conv2d(input_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding)
        self.conv_o_c = nn.Conv2d(memory_size, memory_size, kernel_size=kernel_size, stride=stride, padding=padding,
                                  bias=False)

        if zero_init:
            torch.nn.init.constant_(self.conv_i_x.weight, 0)
            torch.nn.init.constant_(self.conv_i_x.bias, 0)
            torch.nn.init.constant_(self.conv_i_c.weight, 0)

            torch.nn.init.constant_(self.conv_f_x.weight, 0)
            torch.nn.init.constant_(self.conv_f_x.bias, 0)
            torch.nn.init.constant_(self.conv_f_c.weight, 0)

            torch.nn.init.constant_(self.conv_c_x.weight, 0)
            torch.nn.init.constant_(self.conv_c_x.bias, 0)
            torch.nn.init.constant_(self.conv_c_c.weight, 0)

            torch.nn.init.constant_(self.conv_o_x.weight, 0)
            torch.nn.init.constant_(self.conv_o_x.bias, 0)
            torch.nn.init.constant_(self.conv_o_c.weight, 0)
        else:
            torch.nn.init.xavier_normal_(self.conv_i_x.weight)
            torch.nn.init.constant_(self.conv_i_x.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_i_c.weight)

            torch.nn.init.xavier_normal_(self.conv_f_x.weight)
            torch.nn.init.constant_(self.conv_f_x.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_f_c.weight)

            torch.nn.init.xavier_normal_(self.conv_c_x.weight)
            torch.nn.init.constant_(self.conv_c_x.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_c_c.weight)

            torch.nn.init.xavier_normal_(self.conv_o_x.weight)
            torch.nn.init.constant_(self.conv_o_x.bias, 0)
            torch.nn.init.xavier_normal_(self.conv_o_c.weight)

    def forward(self, x, cam, state_att, state_inp, x_flow_i=0, x_flow_f=0, x_flow_c=0, x_flow_o=0):
        # state_att = [a, s]
        # state_inp = [atanh(c), o]

        a_t_1 = state_att[0]
        s_t_1 = state_att[1]

        c_t_1 = torch.tanh(state_inp[0])
        o_t_1 = state_inp[1]

        # Attention recurrence: gates computed from the previous attention
        # state s_{t-1} and the current CAM.
        i_s = torch.sigmoid(self.conv_i_s(s_t_1) + self.conv_i_cam(cam))
        f_s = torch.sigmoid(self.conv_f_s(s_t_1) + self.conv_f_cam(cam))
        o_s = torch.sigmoid(self.conv_o_s(s_t_1) + self.conv_o_cam(cam))
        a_tilde = torch.tanh(self.conv_a_s(s_t_1) + self.conv_a_cam(cam))
        a = (f_s * a_t_1) + (i_s * a_tilde)
        s = o_s * torch.tanh(a)
        u = s + cam  # hidden state + cam

        # Normalize the attention map spatially and reshape to 1 x 7 x 7.
        u = F.softmax(u.view(u.size(0), -1), 1)
        u = u.view(u.size(0), 1, 7, 7)

        # Weight the input features with the attention map.
        x_att = x * u.expand_as(x)

        # Memory recurrence; the optional flow terms default to 0 for RGB-only.
        i_x = torch.sigmoid(self.conv_i_c(o_t_1 * c_t_1) + self.conv_i_x(x_att) + x_flow_i)
        f_x = torch.sigmoid(self.conv_f_c(o_t_1 * c_t_1) + self.conv_f_x(x_att) + x_flow_f)
        c_tilde = torch.tanh(self.conv_c_c(o_t_1 * c_t_1) + self.conv_c_x(x_att) + x_flow_c)
        c = (f_x * state_inp[0]) + (i_x * c_tilde)

        # Output pooling: classify the pooled memory, take the top-scoring
        # class, and weight the memory with that class's classifier weights
        # (a CAM over c) before computing the output gate.
        c_vec = self.avgpool(c).view(c.size(0), -1)
        c_logits = self.c_classifier(c_vec) + self.coupling_fc(self.avgpool(x_att).view(x_att.size(0), -1))
        c_probs, c_idxs = c_logits.sort(1, True)
        c_class_idx = c_idxs[:, 0]
        c_cam = self.c_classifier.weight[c_class_idx].unsqueeze(2).unsqueeze(2) * c
        o_x = torch.sigmoid(self.conv_o_x(o_t_1 * c_t_1) + self.conv_o_c(c_cam))

        state_att = [a, s]
        state_inp = [c, o_x]
        return state_att, state_inp, x_att
```
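Below is a minimal smoke test of the cell on dummy tensors; it is a sketch for illustration, not part of the repo. The shapes mirror how the attention model further down drives the cell (512-channel 7×7 features, a single-channel CAM, zero initial states); the batch size and random inputs are made up.

```python
# Sketch (assumed usage, not from the repo): one LSTA step on dummy data.
import torch
from MyConvLSTACell import MyConvLSTACell

batch, mem_size = 4, 512
cell = MyConvLSTACell(input_size=512, memory_size=mem_size, c_cam_classes=100)

x = torch.randn(batch, 512, 7, 7)   # conv features, e.g. ResNet-34 conv5
cam = torch.randn(batch, 1, 7, 7)   # class activation map of the top class
state_att = [torch.zeros(batch, 1, 7, 7),
             torch.zeros(batch, 1, 7, 7)]              # [a, s]
state_inp = [torch.zeros(batch, mem_size, 7, 7),
             torch.zeros(batch, mem_size, 7, 7)]       # [atanh(c), o]

state_att, state_inp, x_att = cell(x, cam, state_att, state_inp)
print(x_att.shape)  # torch.Size([4, 512, 7, 7])
```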
# LSTA: Long Short-Term Attention for Egocentric Action Recognition

We release the PyTorch code of [LSTA](https://arxiv.org/pdf/1811.10698.pdf).

![LSTA](https://drive.google.com/uc?export=view&id=1gf9Ih_mK1xsd4ZVZvP7tsy4QJEkK1Dsz)

#### Reference
Please cite our paper if you find the repo and the paper useful.
```
@InProceedings{Sudhakaran_2019_CVPR,
  author = {Sudhakaran, Swathikiran and Escalera, Sergio and Lanz, Oswald},
  title = {{LSTA: Long Short-Term Attention for Egocentric Action Recognition}},
  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2019}
}
```

#### Prerequisites

* Python 3.5
* PyTorch 0.3.1

#### Training

* ##### RGB
  To train the models, run the script train_rgb.sh, which contains:
  ```
  python main_rgb.py --dataset gtea_61 --root_dir dataset --outDir experiments --stage 1 \
  --seqLen 25 --trainBatchSize 32 --numEpochs 200 --lr 0.001 --stepSize 25 75 150 \
  --decayRate 0.1 --memSize 512 --outPoolSize 100 --evalInterval 5 --split 2
  ```

#### Evaluation
Testing of the trained models can be done by running the script test_rgb.sh.

#### Pretrained models

The pre-trained models can be downloaded from the following [Google Drive link](https://drive.google.com/drive/folders/1KIUuoaa1_ipGFOYZB6Oe3yITBKZlrpWr?usp=sharing).

#### TODO
1. EPIC-KITCHENS code
2. Flow and two-stream code
3. Pre-trained models
```python
import torch
import torch.nn as nn
from torch.autograd import Variable

import resNetNew
from MyConvLSTACell import *


class attentionModel(nn.Module):
    def __init__(self, num_classes=51, mem_size=512, c_cam_classes=1000):
        super(attentionModel, self).__init__()
        self.num_classes = num_classes
        self.resNet = resNetNew.resnet34(True, True)
        self.mem_size = mem_size
        self.lsta_cell = MyConvLSTACell(512, mem_size, c_cam_classes)
        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        self.classifier = nn.Sequential(self.dropout, self.fc)

    def forward(self, inputVariable, device):
        # inputVariable is time-first: (seq_len, batch, channels, H, W).
        state_att = (torch.zeros(inputVariable.size(1), 1, 7, 7).to(device),
                     torch.zeros(inputVariable.size(1), 1, 7, 7).to(device))
        state_inp = (torch.zeros(inputVariable.size(1), self.mem_size, 7, 7).to(device),
                     torch.zeros(inputVariable.size(1), self.mem_size, 7, 7).to(device))
        for t in range(inputVariable.size(0)):
            logit, feature_conv, x = self.resNet(inputVariable[t])
            bz, nc, h, w = feature_conv.size()
            feature_conv1 = feature_conv.view(bz, nc, h * w)
            # CAM of the top ResNet prediction: project the winning class's
            # fc weights onto the conv feature maps.
            probs, idxs = logit.sort(1, True)
            class_idx = idxs[:, 0]
            cam = torch.bmm(self.resNet.fc.weight[class_idx].unsqueeze(1), feature_conv1).view(x.size(0), 1, 7, 7)
            state_att, state_inp, _ = self.lsta_cell(x, cam, state_att, state_inp)
        # Classify the pooled memory of the final time step.
        feats = self.avgpool(state_inp[0]).view(state_inp[0].size(0), -1)
        logits = self.classifier(feats)
        return logits, feats
```
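As a rough usage sketch (assumed, not from the repo): the clip layout is inferred from the forward loop, which is time-first; the 224×224 input size is an assumption based on the 7×7 ResNet-34 feature maps, and `num_classes=61` is chosen to match the gtea_61 setting in the README. Running it requires the repo's `resNetNew` module.

```python
# Hypothetical usage sketch: classify one dummy clip with attentionModel.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = attentionModel(num_classes=61, mem_size=512).to(device)
model.eval()  # disable the 0.7 dropout for inference

# Dummy clip: (seq_len, batch, 3, 224, 224), time-first.
clip = torch.randn(25, 2, 3, 224, 224).to(device)
with torch.no_grad():
    logits, feats = model(clip, device)
print(logits.shape)  # torch.Size([2, 61])
```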