Skip to content

Commit aa280e6

Browse files
corect attention
1 parent 4d7d921 commit aa280e6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def forward(self,doc,doc_char,doc_mask,query,query_char,query_mask,
123123
def attention(self,D,Q,doc_mask,query_mask):
124124
mask_Q=query_mask.unsqueeze(1).expand(-1,D.shape[1],-1)
125125
mask_D=doc_mask.unsqueeze(-1).expand(-1,-1,Q.shape[1])
126-
attn=F.softmax(torch.bmm(D,Q.transpose(-1,-2)),dim=-1)*mask_Q*mask_D
126+
attn_temp=torch.bmm(D,Q.transpose(-1,-2))
127+
attn_temp=attn_temp+(1-mask_Q)*1e-9+(1-mask_D)*1e-9
128+
attn=F.softmax(attn_temp,dim=-1)
127129

128130
weights=torch.bmm(attn,Q)
129131
output=weights*D

0 commit comments

Comments
 (0)