@@ -117,31 +117,16 @@ def forward(self, v, q):
 
 
 def apply_attention(input, attention):
-    """ Apply any number of attention maps over the input.
-    The attention map has to have the same size in all dimensions except dim=1.
-    """
+    """ Apply any number of attention maps over the input. """
     n, c = input.size()[:2]
     glimpses = attention.size(1)
 
     # flatten the spatial dims into the third dim, since we don't need to care about how they are arranged
-    input = input.view(n, c, -1)
+    input = input.view(n, 1, c, -1)  # [n, 1, c, s]
     attention = attention.view(n, glimpses, -1)
-    s = input.size(2)
-
-    # apply a softmax to each attention map separately
-    # since softmax only takes 2d inputs, we have to collapse the first two dimensions together
-    # so that each glimpse is normalized separately
-    attention = attention.view(n * glimpses, -1)
-    attention = F.softmax(attention)
-
-    # apply the weighting by creating a new dim to tile both tensors over
-    target_size = [n, glimpses, c, s]
-    input = input.view(n, 1, c, s).expand(*target_size)
-    attention = attention.view(n, glimpses, 1, s).expand(*target_size)
-    weighted = input * attention
-    # sum over only the spatial dimension
-    weighted_mean = weighted.sum(dim=3)
-    # the shape at this point is (n, glimpses, c, 1)
+    attention = F.softmax(attention, dim=-1).unsqueeze(2)  # [n, g, 1, s]
+    weighted = attention * input  # [n, g, v, s]
+    weighted_mean = weighted.sum(dim=-1)  # [n, g, v]
     return weighted_mean.view(n, -1)
 
 
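For reference, a minimal runnable sketch of the function as it reads after this change, assuming PyTorch with F imported as torch.nn.functional; the tensor shapes in the usage lines are illustrative only, not taken from the repository.

import torch
import torch.nn.functional as F


def apply_attention(input, attention):
    """ Apply any number of attention maps over the input. """
    n, c = input.size()[:2]
    glimpses = attention.size(1)
    # flatten the spatial dims; broadcasting replaces the explicit expand() of the old version
    input = input.view(n, 1, c, -1)                        # [n, 1, c, s]
    attention = attention.view(n, glimpses, -1)
    attention = F.softmax(attention, dim=-1).unsqueeze(2)  # [n, g, 1, s]
    weighted = attention * input                           # [n, g, c, s] via broadcasting
    weighted_mean = weighted.sum(dim=-1)                   # [n, g, c]
    return weighted_mean.view(n, -1)                       # [n, g * c]


# illustrative shapes: 2 images, 512 feature channels on a 14x14 grid, 2 glimpses
features = torch.randn(2, 512, 14, 14)
maps = torch.randn(2, 2, 14, 14)
print(apply_attention(features, maps).shape)  # torch.Size([2, 1024])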