Commit ccba5d3

Merge pull request #18 from guoyang9/acc_update
Fix a bug in accuracy computation and make apply_attention nicer.
2 parents: 0942c8b + ac0f389

File tree

2 files changed (+8 / -21 lines)
model.py

Lines changed: 5 additions & 20 deletions
@@ -117,31 +117,16 @@ def forward(self, v, q):
 def apply_attention(input, attention):
-    """ Apply any number of attention maps over the input.
-    The attention map has to have the same size in all dimensions except dim=1.
-    """
+    """ Apply any number of attention maps over the input. """
     n, c = input.size()[:2]
     glimpses = attention.size(1)

     # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged
-    input = input.view(n, c, -1)
+    input = input.view(n, 1, c, -1)  # [n, 1, c, s]
     attention = attention.view(n, glimpses, -1)
-    s = input.size(2)
-
-    # apply a softmax to each attention map separately
-    # since softmax only takes 2d inputs, we have to collapse the first two dimensions together
-    # so that each glimpse is normalized separately
-    attention = attention.view(n * glimpses, -1)
-    attention = F.softmax(attention)
-
-    # apply the weighting by creating a new dim to tile both tensors over
-    target_size = [n, glimpses, c, s]
-    input = input.view(n, 1, c, s).expand(*target_size)
-    attention = attention.view(n, glimpses, 1, s).expand(*target_size)
-    weighted = input * attention
-    # sum over only the spatial dimension
-    weighted_mean = weighted.sum(dim=3)
-    # the shape at this point is (n, glimpses, c, 1)
+    attention = F.softmax(attention, dim=-1).unsqueeze(2)  # [n, g, 1, s]
+    weighted = attention * input  # [n, g, v, s]
+    weighted_mean = weighted.sum(dim=-1)  # [n, g, v]
     return weighted_mean.view(n, -1)
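For readers skimming the diff, here is the refactored apply_attention as a self-contained sketch: broadcasting over a singleton glimpse dimension replaces the explicit expand/tile of the old version, and the softmax is applied per glimpse over the flattened spatial positions. The tensor sizes in the usage lines (batch of 2, 512 channels, 2 glimpses, a 14x14 feature map) are made up purely for illustration.

    import torch
    import torch.nn.functional as F

    def apply_attention(input, attention):
        """ Apply any number of attention maps over the input. """
        n, c = input.size()[:2]
        glimpses = attention.size(1)

        # flatten the spatial dims into the last dim; the singleton glimpse dim lets broadcasting do the tiling
        input = input.view(n, 1, c, -1)                        # [n, 1, c, s]
        attention = attention.view(n, glimpses, -1)
        attention = F.softmax(attention, dim=-1).unsqueeze(2)  # [n, g, 1, s]
        weighted = attention * input                           # [n, g, c, s]
        weighted_mean = weighted.sum(dim=-1)                   # [n, g, c]
        return weighted_mean.view(n, -1)

    # hypothetical sizes, for illustration only
    v = torch.randn(2, 512, 14, 14)     # image features: [n, c, h, w]
    a = torch.randn(2, 2, 14, 14)       # attention maps: [n, glimpses, h, w]
    print(apply_attention(v, a).shape)  # torch.Size([2, 1024]) == [n, glimpses * c]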
train.py

Lines changed: 3 additions & 1 deletion
@@ -74,7 +74,9 @@ def run(net, loader, optimizer, tracker, train=False, prefix='', epoch=0):
             idxs.append(idx.view(-1).clone())

         loss_tracker.append(loss.data[0])
-        acc_tracker.append(acc.mean())
+        # acc_tracker.append(acc.mean())
+        for a in acc:
+            acc_tracker.append(a.item())
         fmt = '{:.4f}'.format
         tq.set_postfix(loss=fmt(loss_tracker.mean.value), acc=fmt(acc_tracker.mean.value))
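A plausible reading of the accuracy bug this fixes (the commit message itself does not spell it out): appending the per-batch mean and then averaging the tracker gives every batch equal weight, so a smaller final batch skews the reported accuracy, whereas appending each sample's accuracy yields the true per-sample mean. A toy illustration with made-up numbers:

    import torch

    # two "batches" of per-sample accuracies with different sizes, purely illustrative
    batch1 = torch.tensor([1.0, 1.0, 0.0, 0.0])  # 4 samples, mean 0.5
    batch2 = torch.tensor([1.0, 1.0])            # 2 samples, mean 1.0

    # old behaviour: average of per-batch means gives every batch equal weight
    mean_of_means = (batch1.mean() + batch2.mean()) / 2   # 0.75

    # new behaviour: every sample contributes once, as with the per-sample append
    per_sample_mean = torch.cat([batch1, batch2]).mean()  # ~0.6667

    print(mean_of_means.item(), per_sample_mean.item())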

0 commit comments
