Skip to content

Commit

Permalink
HMY
Browse files Browse the repository at this point in the history
  • Loading branch information
394481125 committed Apr 26, 2021
1 parent e94bdd4 commit 77456cb
Show file tree
Hide file tree
Showing 11 changed files with 8,210 additions and 0 deletions.
12 changes: 12 additions & 0 deletions .idea/attention-network-master.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/encodings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

746 changes: 746 additions & 0 deletions .idea/workspace.xml

Large diffs are not rendered by default.

77 changes: 77 additions & 0 deletions Attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import torch.nn as nn
import torch
from torch.autograd import Variable
import torch.nn.functional as F

class Attention_average(nn.Module):
    """Collapse each feature map to its spatial mean.

    An input of shape (N, img_dim, H, W) is reduced to per-channel spatial
    averages and regrouped into shape (-1, sequence, img_dim).
    """

    def __init__(self, sequence, img_dim, kernel_size):
        super(Attention_average, self).__init__()
        self.sequence = sequence
        self.img_dim = img_dim
        self.kernel_size = kernel_size

    def forward(self, x):
        pooled = self.pooling(x)
        return pooled.view(-1, self.sequence, self.img_dim)

    def pooling(self, x):
        # Average over width (dim 3) first, then height (dim 2).
        return torch.mean(torch.mean(x, dim=3), dim=2)

class Attentnion_auto(nn.Module):
    """Spatial attention whose weights come from the feature energy map.

    The per-location mean of squared activations is passed through a 1x1
    convolution and softmax-normalised over the spatial grid; the weights
    then produce a weighted spatial sum per channel, reshaped to
    (-1, sequence, img_dim).
    """

    def __init__(self, sequence, img_dim, kernel_size,):
        super(Attentnion_auto, self).__init__()
        self.sequence = sequence
        self.img_dim = img_dim
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(1, 1, kernel_size=1)

    def forward(self, x):
        k = self.kernel_size
        # Energy map: channel-wise mean of squared activations.
        energy = torch.mean(torch.pow(x, 2), dim=1).view(-1, 1, k, k)
        scores = self.conv(energy).view(-1, k ** 2)
        # Normalise over all k*k spatial positions, then broadcast to x.
        weights = F.softmax(scores, dim=-1).view(-1, 1, k, k).expand_as(x)
        pooled = torch.sum(torch.sum(x * weights, dim=3), dim=2)
        return pooled.view(-1, self.sequence, self.img_dim)

class Attention_learned(nn.Module):
    """Spatial attention with an MLP-scored softmax over locations.

    Each spatial location's feature vector is scored by a small
    Linear-Tanh-Linear-Tanh network; softmax over locations yields weights
    for a weighted sum, reshaped to (-1, sequence, img_dim).
    """

    def __init__(self, sequence, img_dim, kernel_size, bottle_neck=128):
        super(Attention_learned, self).__init__()
        self.kernel_size = kernel_size
        self.im_dim = img_dim
        self.sequence = sequence

        # Scores one spatial location from its im_dim feature vector.
        self.linear = nn.Sequential(
            nn.Linear(self.im_dim, bottle_neck),
            nn.Tanh(),
            nn.Linear(bottle_neck, 1),
            nn.Tanh(),
        )
        # NOTE(review): registered but never used in forward; kept so the
        # module's parameter layout (state_dict) stays unchanged.
        self.conv = nn.Sequential(
            nn.Conv1d(self.kernel_size ** 2, self.kernel_size ** 2, 1),
        )


    def forward(self, outhigh):
        k2 = self.kernel_size * self.kernel_size
        # (N, C, H, W) -> (N, H*W, C): one feature vector per location.
        features = outhigh.view(-1, self.im_dim, k2).transpose(1, 2)
        scores = self.linear(features).squeeze(-1)
        weights = F.softmax(scores, dim=-1).unsqueeze(-1)
        descriptor = torch.sum(features * weights, dim=1)

        return descriptor.view(-1, self.sequence, self.im_dim)


if __name__ == '__main__':
    # Smoke check: every variant should print torch.Size([2, 12, 512]).
    fake_data = Variable(torch.randn(24, 512, 7, 7))
    nets = [
        Attention_average(sequence=12, img_dim=512, kernel_size=7),
        Attentnion_auto(sequence=12, img_dim=512, kernel_size=7),
        Attention_learned(sequence=12, img_dim=512, kernel_size=7),
    ]
    for net in nets:
        print(net(fake_data).size())

88 changes: 88 additions & 0 deletions Dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import torch.utils.data as data
import numpy as np
import os
from PIL import Image
import torch

# Parse one annotation line describing a video.
def read_data(lines):
    """Split an annotation line into ``[path, num_frame, label]``.

    The line format is ``<path> <num_frame> <label>``; the path itself may
    contain spaces, so only the last two space-separated fields are peeled
    off and everything before them is kept verbatim as the path.

    Args:
        lines: one stripped line of the video list file.

    Returns:
        ``[path, num_frame, label]`` with all three items as strings.
    """
    # rsplit keeps embedded spaces in the path without manual re-joining.
    path, num_frame, label = lines.rsplit(' ', 2)
    return [path, num_frame, label]

# Typed view over one parsed annotation record.
class Videolist_Parse(object):
    """Expose a parsed ``[path, num_frames, label]`` row as properties."""

    def __init__(self, row):
        self.row = row

    @property
    def path(self):
        return self.row[0]

    @property
    def num_frames(self):
        return int(self.row[1])

    @property
    def label(self):
        return int(self.row[2])

# Dataset yielding (frames, label) pairs sampled from the listed videos.
class VideoDataset(data.Dataset):
    """Video dataset driven by a text annotation file.

    Each line of ``root + list`` describes one video as
    ``<path> <num_frames> <label>``.  A sample is built by splitting a video
    into ``num_segments`` equal chunks and drawing ``num_frames`` random
    frames (without replacement) from each chunk.

    Args:
        root: directory prefix for both the list file and the frame folders.
        list: annotation file name (appended directly to ``root``).
        transform: callable mapping a PIL image to a (3, 224, 224) tensor.
        num_segments: number of equal temporal chunks per video.
        num_frames: frames sampled per chunk.
        test_mode: when True, ``__getitem__`` stacks 10 independent samplings.
    """
    def __init__(self, root, list, transform, num_segments, num_frames, test_mode=False):
        self.transform = transform
        self.list = list
        self.root = root
        self.num_segments = num_segments
        self.num_frames = num_frames
        self.test_mode = test_mode

        self._parse_videolist()

    def __len__(self):
        return len(self.videolist)

    # Return the frame tensor and label for one video.
    def __getitem__(self, idx):
        record = self.videolist[idx]

        if not self.test_mode:
            indices = self.get_indices(record)
            image_tensor = self.get_img(indices, record)
        else:
            # Test-time augmentation: 10 independent random samplings,
            # stacked along a new leading dimension.
            image_tensor = []
            for count in range(10):
                indices = self.get_indices(record)
                image_tensor.append(self.get_img(indices, record))
            image_tensor = torch.stack(image_tensor, dim=0)
        return image_tensor, record.label

    # Read the annotation file into self.videolist.
    def _parse_videolist(self):
        '''Parse every line into a Videolist_Parse record holding
        [path, num_frames, label] and store them in self.videolist.
        '''
        # Fix: close the list file deterministically instead of leaking
        # the handle returned by open().
        with open(self.root + self.list) as f:
            lines = [read_data(x.strip()) for x in f]
        self.videolist = [Videolist_Parse(item) for item in lines]

    def get_indices(self, record):
        """Sample sorted frame indices: num_frames per segment, no repeats."""
        # Length of each of the num_segments equal chunks of the video.
        average_duration = record.num_frames // self.num_segments
        # Draw num_frames distinct offsets inside each chunk, shifted to
        # that chunk's position in the video.
        # NOTE(review): np.random.choice(..., replace=False) raises
        # ValueError when average_duration < num_frames (video too short)
        # — confirm upstream guarantees enough frames per segment.
        choices = [np.random.choice(average_duration, self.num_frames, replace=False) + i * average_duration
                   for i in range(self.num_segments)]
        offsets = np.sort(np.concatenate(choices))
        return offsets

    # Load the frame images selected by `indices` for this video.
    def get_img(self, indices, record):
        """Return a (num_segments * num_frames, 3, 224, 224) tensor of frames."""
        frames = torch.zeros(self.num_segments * self.num_frames, 3, 224, 224)
        for idx, idx_img in enumerate(indices):
            # Frames on disk are 1-based: '1.jpg', '2.jpg', ...
            dir_img = os.path.join(self.root, record.path, str(idx_img + 1) + '.jpg')
            image = Image.open(dir_img).convert('RGB')
            frames[idx] = self.transform(image)
        return frames
Loading

0 comments on commit 77456cb

Please sign in to comment.