Skip to content

Commit 5a0365b

Browse files
committed
Add keepdim=True for PyTorch version upgrade to 0.2.0
1 parent 3193768 commit 5a0365b

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, embedding_tokens):
4848
def forward(self, v, q, q_len):
4949
q = self.text(q, list(q_len.data))
5050

51-
v = v / (v.norm(p=2, dim=1).expand_as(v) + 1e-8)
51+
v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8)
5252
a = self.attention(v, q)
5353
v = apply_attention(v, a)
5454

utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def batch_accuracy(predicted, true):
1212
""" Compute the accuracies for a batch of predictions and answers """
13-
_, predicted_index = predicted.max(dim=1)
13+
_, predicted_index = predicted.max(dim=1, keepdim=True)
1414
agreeing = true.gather(dim=1, index=predicted_index)
1515
'''
1616
Acc needs to be averaged over all 10 choose 9 subsets of human answers.

0 commit comments

Comments
 (0)