Showing 11 changed files with 8,210 additions and 0 deletions.
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention_average(nn.Module):
    """Average pooling over the spatial dimensions of each frame-level feature map."""
    def __init__(self, sequence, img_dim, kernel_size):
        super(Attention_average, self).__init__()
        self.sequence = sequence        # frames per clip
        self.img_dim = img_dim          # channel dimension of the feature map
        self.kernel_size = kernel_size  # spatial size of the feature map

    def forward(self, x):
        # x: (batch * sequence, img_dim, kernel_size, kernel_size)
        output = self.pooling(x).view(-1, self.sequence, self.img_dim)
        return output

    def pooling(self, x):
        # Mean over both spatial dimensions -> (batch * sequence, img_dim)
        output = torch.mean(torch.mean(x, dim=3), dim=2)
        return output


class Attentnion_auto(nn.Module):
    """Spatial attention computed from the feature energy map (channel-wise mean of
    squared activations), refined by a 1x1 convolution and softmax-normalized."""
    def __init__(self, sequence, img_dim, kernel_size):
        super(Attentnion_auto, self).__init__()
        self.sequence = sequence
        self.img_dim = img_dim
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(1, 1, kernel_size=1)

    def forward(self, x):
        # Energy map: mean over channels of the squared activations
        feature_pow = torch.pow(x, 2)
        feature_map = torch.mean(feature_pow, dim=1).view(-1, 1, self.kernel_size, self.kernel_size)
        feature_map = self.conv(feature_map).view(-1, self.kernel_size ** 2)
        # Softmax over all spatial positions, broadcast to every channel
        feature_weight = F.softmax(feature_map, dim=-1).view(-1, 1, self.kernel_size, self.kernel_size).expand_as(x)
        out_map = feature_weight * x
        # Attention-weighted sum over the spatial dimensions
        output = torch.sum(torch.sum(out_map, dim=3), dim=2)
        return output.view(-1, self.sequence, self.img_dim)


class Attention_learned(nn.Module):
    """Learned spatial attention: a small MLP scores every spatial position and
    the scores are softmax-normalized into attention weights."""
    def __init__(self, sequence, img_dim, kernel_size, bottle_neck=128):
        super(Attention_learned, self).__init__()
        self.kernel_size = kernel_size
        self.im_dim = img_dim
        self.sequence = sequence
        # self.alpha = torch.nn.Parameter(torch.zeros(1), requires_grad=True)

        self.linear = nn.Sequential(
            nn.Linear(self.im_dim, bottle_neck),
            nn.Tanh(),
            nn.Linear(bottle_neck, 1),
            nn.Tanh(),
        )
        # Defined but not used in forward(); kept from the original implementation
        self.conv = nn.Sequential(
            nn.Conv1d(self.kernel_size ** 2, self.kernel_size ** 2, 1),
            # nn.Sigmoid(),
        )

    def forward(self, outhigh):
        # (batch * sequence, img_dim, H, W) -> (batch * sequence, H * W, img_dim)
        outhigh = outhigh.view(-1, self.im_dim, self.kernel_size * self.kernel_size).transpose(1, 2)
        # One scalar score per spatial position
        weight = self.linear(outhigh).squeeze(-1)
        attention = F.softmax(weight, dim=-1).unsqueeze(-1)
        # Weighted sum over spatial positions -> one descriptor per frame
        attention_data = outhigh * attention
        descriptor = torch.sum(attention_data, dim=1)
        return descriptor.view(-1, self.sequence, self.im_dim)


if __name__ == '__main__':
    # 24 frame-level feature maps, e.g. 2 clips of 12 frames each
    fake_data = torch.randn(24, 512, 7, 7)
    net1 = Attention_average(sequence=12, img_dim=512, kernel_size=7)
    net2 = Attentnion_auto(sequence=12, img_dim=512, kernel_size=7)
    net3 = Attention_learned(sequence=12, img_dim=512, kernel_size=7)
    print(net1(fake_data).size())  # torch.Size([2, 12, 512])
    print(net2(fake_data).size())  # torch.Size([2, 12, 512])
    print(net3(fake_data).size())  # torch.Size([2, 12, 512])
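The three modules are interchangeable spatial poolings over frame-level CNN feature maps. As a rough illustration of how one of them might sit behind a backbone, here is a minimal sketch; the ResNet-18 backbone, the FrameEncoder wrapper, and the clip shape are illustrative assumptions, not part of this commit (torchvision >= 0.13 is assumed for the weights argument).

import torch
import torch.nn as nn
from torchvision import models

class FrameEncoder(nn.Module):
    """Hypothetical wrapper: a CNN backbone followed by one of the attention poolings."""
    def __init__(self, sequence=12, img_dim=512, kernel_size=7):
        super(FrameEncoder, self).__init__()
        resnet = models.resnet18(weights=None)
        # Drop the global average pool and fc head; for 224x224 inputs the
        # remaining stack outputs (N, 512, 7, 7) feature maps.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = Attention_learned(sequence=sequence, img_dim=img_dim, kernel_size=kernel_size)

    def forward(self, clip):
        # clip: (batch, sequence, 3, 224, 224); fold frames into the batch dim
        b, t = clip.size(0), clip.size(1)
        feats = self.backbone(clip.view(b * t, 3, 224, 224))
        return self.pool(feats)  # (batch, sequence, img_dim)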
@@ -0,0 +1,88 @@
import os

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image


# Parse one line of the list file into [path, num_frames, label].
# A line looks like "path/to/video num_frames label"; the path itself may
# contain spaces, so everything before the last two fields is rejoined.
def read_data(lines):
    list_strs = lines.split(' ')
    num_frame = list_strs[-2]
    label = list_strs[-1]
    path = list_strs[0]
    for idx in range(1, len(list_strs) - 2):
        path = path + ' ' + list_strs[idx]
    return [path, num_frame, label]


# Holds the parsed parameters of a single video
class Videolist_Parse(object):
    '''Keeps each video's parameters: [path, num_frames, label].'''
    def __init__(self, row):
        self.row = row

    @property
    def path(self):
        return self.row[0]

    @property
    def num_frames(self):
        return int(self.row[1])

    @property
    def label(self):
        return int(self.row[2])


# Wrap the video list as a torch Dataset
class VideoDataset(data.Dataset):
    def __init__(self, root, list_file, transform, num_segments, num_frames, test_mode=False):
        self.transform = transform
        self.list_file = list_file
        self.root = root
        self.num_segments = num_segments  # segments each video is split into
        self.num_frames = num_frames      # frames sampled from each segment
        self.test_mode = test_mode

        self._parse_videolist()

    def __len__(self):
        return len(self.videolist)

    # Return the frame tensor and label for one video
    def __getitem__(self, idx):
        record = self.videolist[idx]

        if not self.test_mode:
            indices = self.get_indices(record)
            image_tensor = self.get_img(indices, record)
        else:
            # At test time, draw 10 independent random samplings and stack them
            image_tensor = []
            for count in range(10):
                indices = self.get_indices(record)
                image_tensor.append(self.get_img(indices, record))
            image_tensor = torch.stack(image_tensor, dim=0)
        return image_tensor, record.label

    # Build the list of per-video parameter records
    def _parse_videolist(self):
        '''Parses the list file into records of [path, num_frames, label]
        and stores them in self.videolist.'''
        with open(os.path.join(self.root, self.list_file)) as f:
            lines = [read_data(x.strip()) for x in f]
        self.videolist = [Videolist_Parse(item) for item in lines]

    # Sample sorted frame indices: split the video into num_segments equal
    # segments and draw num_frames random frames from each one
    def get_indices(self, record):
        # Number of frames each of the num_segments segments contains
        average_duration = record.num_frames // self.num_segments
        # Draw num_frames distinct offsets in [0, average_duration) per segment,
        # then shift each by its segment start (i * average_duration). E.g. with
        # 120 frames, num_segments=12 and num_frames=1, one random frame is taken
        # from each block of 10. Assumes average_duration >= num_frames.
        choices = [np.random.choice(average_duration, self.num_frames, replace=False) + i * average_duration
                   for i in range(self.num_segments)]
        choices = np.concatenate(choices)
        offsets = np.sort(choices)
        return offsets

    # Load and transform the frames selected by indices
    def get_img(self, indices, record):
        frames = torch.zeros(self.num_segments * self.num_frames, 3, 224, 224)
        for idx, idx_img in enumerate(indices):
            # Frames are stored as 1.jpg, 2.jpg, ... inside each video folder
            dir_img = os.path.join(self.root, record.path, str(idx_img + 1) + '.jpg')
            image = Image.open(dir_img).convert('RGB')
            frames[idx] = self.transform(image)
        return frames
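As a usage sketch, the dataset plugs into a standard DataLoader; the transform pipeline, the data/ root, the train_list.txt name, and the batch size below are illustrative assumptions, and the list file is expected to hold one "path num_frames label" entry per line, as read_data implies.

import torch
from torchvision import transforms

# Hypothetical paths and hyperparameters, for illustration only
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = VideoDataset(root='data/', list_file='train_list.txt', transform=transform,
                       num_segments=12, num_frames=1)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

for frames, labels in loader:
    # frames: (batch, num_segments * num_frames, 3, 224, 224)
    print(frames.size(), labels)
    break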