Attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention_average(nn.Module):
    """Pools each spatial feature map by global average pooling."""

    def __init__(self, sequence, img_dim, kernel_size):
        super(Attention_average, self).__init__()
        self.sequence = sequence
        self.img_dim = img_dim
        self.kernel_size = kernel_size

    def forward(self, x):
        # x: (batch * sequence, img_dim, kernel_size, kernel_size)
        output = self.pooling(x).view(-1, self.sequence, self.img_dim)
        return output

    def pooling(self, x):
        # Average over the two spatial dimensions -> (batch * sequence, img_dim)
        output = torch.mean(torch.mean(x, dim=3), dim=2)
        return output
class Attention_auto(nn.Module):
    """Weights spatial locations by a softmax over a 1x1 convolution of the
    channel-wise energy map (mean of squared activations)."""

    def __init__(self, sequence, img_dim, kernel_size):
        super(Attention_auto, self).__init__()
        self.sequence = sequence
        self.img_dim = img_dim
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(1, 1, kernel_size=1)

    def forward(self, x):
        # Energy map: mean of squared activations over the channel dimension.
        feature_pow = torch.pow(x, 2)
        feature_map = torch.mean(feature_pow, dim=1).view(-1, 1, self.kernel_size, self.kernel_size)
        feature_map = self.conv(feature_map).view(-1, self.kernel_size ** 2)
        # Softmax over spatial positions, broadcast back to the input shape.
        feature_weight = F.softmax(feature_map, dim=-1).view(-1, 1, self.kernel_size, self.kernel_size).expand_as(x)
        out_map = feature_weight * x
        # Weighted sum over the spatial dimensions -> (batch * sequence, img_dim)
        output = torch.sum(torch.sum(out_map, dim=3), dim=2)
        return output.view(-1, self.sequence, self.img_dim)
class Attention_learned(nn.Module):
    """Learns per-location attention weights with a small MLP applied to each
    img_dim-dimensional feature vector."""

    def __init__(self, sequence, img_dim, kernel_size, bottle_neck=128):
        super(Attention_learned, self).__init__()
        self.kernel_size = kernel_size
        self.im_dim = img_dim
        self.sequence = sequence
        # self.alpha = torch.nn.Parameter(torch.zeros(1), requires_grad=True)
        self.linear = nn.Sequential(
            nn.Linear(self.im_dim, bottle_neck),
            nn.Tanh(),
            nn.Linear(bottle_neck, 1),
            nn.Tanh(),
        )
        # Defined but not used in forward().
        self.conv = nn.Sequential(
            nn.Conv1d(self.kernel_size ** 2, self.kernel_size ** 2, 1),
            # nn.Sigmoid(),
        )

    def forward(self, outhigh):
        # (batch * sequence, img_dim, k, k) -> (batch * sequence, k * k, img_dim)
        outhigh = outhigh.view(-1, self.im_dim, self.kernel_size * self.kernel_size).transpose(1, 2)
        weight = self.linear(outhigh).squeeze(-1)
        attention = F.softmax(weight, dim=-1).unsqueeze(-1)
        attention_data = outhigh * attention
        # Attention-weighted sum over the k * k spatial positions.
        descriptor = torch.sum(attention_data, dim=1)
        return descriptor.view(-1, self.sequence, self.im_dim)
if __name__ == '__main__':
    fake_data = torch.randn(24, 512, 7, 7)  # e.g. 2 clips x 12 frames of 512x7x7 features
    net1 = Attention_average(sequence=12, img_dim=512, kernel_size=7)
    net2 = Attention_auto(sequence=12, img_dim=512, kernel_size=7)
    net3 = Attention_learned(sequence=12, img_dim=512, kernel_size=7)
    print(net1(fake_data).size())  # torch.Size([2, 12, 512])
    print(net2(fake_data).size())  # torch.Size([2, 12, 512])
    print(net3(fake_data).size())  # torch.Size([2, 12, 512])
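
A minimal usage sketch, not part of the original file: given the 512x7x7 inputs in the self-test above, these modules look intended to pool per-frame CNN feature maps into per-frame descriptors. The ResNet-18 backbone, the torchvision import, and the (clips x frames) layout below are assumptions for illustration only.

    # Hypothetical usage: pool ResNet-18 feature maps with Attention_learned.
    # Assumes torchvision >= 0.13 (for the `weights=` argument).
    import torch
    import torch.nn as nn
    from torchvision.models import resnet18

    # Drop avgpool and fc so the backbone emits (N, 512, 7, 7) for 224x224 input.
    backbone = nn.Sequential(*list(resnet18(weights=None).children())[:-2])
    attention = Attention_learned(sequence=12, img_dim=512, kernel_size=7)

    frames = torch.randn(2 * 12, 3, 224, 224)  # 2 clips of 12 frames, flattened
    with torch.no_grad():
        features = backbone(frames)        # (24, 512, 7, 7)
        descriptors = attention(features)  # (2, 12, 512): one descriptor per frame
    print(descriptors.shape)

Since all three modules return the same (clips, sequence, img_dim) shape, they can be swapped for one another in a pipeline like this without changing downstream code.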