File tree Expand file tree Collapse file tree 2 files changed +2
-2
lines changed
Expand file tree Collapse file tree 2 files changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -48,7 +48,7 @@ def __init__(self, embedding_tokens):
4848 def forward (self , v , q , q_len ):
4949 q = self .text (q , list (q_len .data ))
5050
51- v = v / (v .norm (p = 2 , dim = 1 ).expand_as (v ) + 1e-8 )
51+ v = v / (v .norm (p = 2 , dim = 1 , keepdim = True ).expand_as (v ) + 1e-8 )
5252 a = self .attention (v , q )
5353 v = apply_attention (v , a )
5454
Original file line number Diff line number Diff line change 1010
1111def batch_accuracy (predicted , true ):
1212 """ Compute the accuracies for a batch of predictions and answers """
13- _ , predicted_index = predicted .max (dim = 1 )
13+ _ , predicted_index = predicted .max (dim = 1 , keepdim = True )
1414 agreeing = true .gather (dim = 1 , index = predicted_index )
1515 '''
1616 Acc needs to be averaged over all 10 choose 9 subsets of human answers.
You can’t perform that action at this time.
0 commit comments