
Commit ff69738

added attention sum
1 parent 7eedef7 commit ff69738

File tree

1 file changed: +142 -0 lines changed


model.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class GRU(nn.Module):
    """Bidirectional GRU that handles variable-length, padded batches."""

    def __init__(self, input_size, hidden_size):
        super(GRU, self).__init__()

        self.gru = nn.GRU(input_size, hidden_size, bidirectional=True, batch_first=True)

    def forward(self, input, input_mask):
        # Recover each sequence's true length from its mask, then sort the
        # batch by length (descending), as pack_padded_sequence requires.
        seq_len = torch.sum(input_mask, dim=-1)
        sorted_len, sorted_index = seq_len.sort(0, descending=True)
        i_sorted_index = sorted_index.view(-1, 1, 1).expand_as(input)
        sorted_input = input.gather(0, i_sorted_index.long())

        # Lengths must live on the CPU for pack_padded_sequence.
        packed_seq = pack_padded_sequence(sorted_input, sorted_len.cpu(), batch_first=True)
        # nn.GRU returns (output, hidden); there is no cell state (that is LSTM-only).
        output, hidden = self.gru(packed_seq)
        unpacked_seq, unpacked_len = pad_packed_sequence(output, batch_first=True)

        # Restore the original batch order.
        _, original_index = sorted_index.sort(0, descending=False)
        unsorted_index = original_index.view(-1, 1, 1).expand_as(unpacked_seq)
        output_final = unpacked_seq.gather(0, unsorted_index.long())

        return output_final, seq_len


class Char_Embeds(nn.Module):
    """Character-level word embeddings: run a bidirectional GRU over the
    characters of each word type and project its last hidden states."""

    def __init__(self, n_chars, char_size, embed_size, hidden_size):
        super(Char_Embeds, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size

        self.char_embedding = nn.Embedding(n_chars, char_size)
        self.forward_project = nn.Linear(hidden_size, embed_size // 2)
        self.backward_project = nn.Linear(hidden_size, embed_size // 2)

        self.gru = GRU(char_size, hidden_size)

    def forward(self, input, mask, doc_char, query_char):
        # input: (n_word_types, max_word_len) character ids for each word type.
        input = self.char_embedding(input)
        input, seq_len = self.gru(input, mask)

        # Pick the GRU output at the last valid character of each word.
        final_index = (seq_len - 1).view(-1, 1).expand(input.size(0), input.size(2)).unsqueeze(1)
        output = input.gather(1, final_index.long()).squeeze()

        # Split the bidirectional output, project each direction, and sum.
        forward_output = output[:, :self.hidden_size]
        backward_output = output[:, self.hidden_size:]
        forward_output = self.forward_project(forward_output)
        backward_output = self.backward_project(backward_output)
        final = forward_output + backward_output

        # Look up the per-word-type embedding for every document/query token.
        doc_embed = final.index_select(0, doc_char.view(-1)).view(
            doc_char.shape[0], doc_char.shape[1], self.embed_size // 2)
        query_embed = final.index_select(0, query_char.view(-1)).view(
            query_char.shape[0], query_char.shape[1], self.embed_size // 2)

        return doc_embed, query_embed


class GA_Reader(nn.Module):
    """Gated-attention reader: multi-hop gated attention between document and
    query encodings, followed by an attention-sum answer layer."""

    def __init__(self, n_chars, char_size, embed_size, hidden_size_char, hidden_size,
                 vocab_size, pretrained_weights, gru_layers, use_features, use_chars):
        super(GA_Reader, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained_weights)
        self.use_chars = use_chars
        self.use_features = use_features
        self.gru_layers = gru_layers

        # One document/query encoder pair per intermediate (gated-attention) hop.
        self.grus_docs = nn.ModuleList()
        self.grus_query = nn.ModuleList()
        for i in range(gru_layers - 1):
            if i == 0:
                if self.use_chars:
                    G1 = GRU(3 * embed_size // 2, hidden_size)
                else:
                    G1 = GRU(embed_size, hidden_size)
            else:
                G1 = GRU(2 * hidden_size, hidden_size)
            if self.use_chars:
                G2 = GRU(3 * embed_size // 2, hidden_size)
            else:
                G2 = GRU(embed_size, hidden_size)
            self.grus_docs.append(G1)
            self.grus_query.append(G2)

        # Optional query-evidence common-word feature appended to the document.
        if use_features:
            self.features = nn.Embedding(2, 2)
        self.finalgru_doc = GRU(2 * hidden_size + use_features * 2, hidden_size)
        # Note: the final query encoder's input size assumes use_chars is enabled.
        self.finalgru_query = GRU(3 * embed_size // 2, hidden_size)

        if use_chars:
            self.char_embeds = Char_Embeds(n_chars, char_size, embed_size, hidden_size_char)

    def forward(self, doc, doc_char, doc_mask, query, query_char, query_mask,
                char_type, char_type_mask, ans, cloze, cands, cand_mask, qe_comm):
        doc_embed = self.embedding(doc)
        query_embed = self.embedding(query)

        if self.use_chars:
            # Concatenate word embeddings with character-level embeddings.
            doc_char_embed, query_char_embed = self.char_embeds(
                char_type, char_type_mask, doc_char, query_char)
            doc_embed = torch.cat([doc_embed, doc_char_embed], dim=-1)
            query_embed = torch.cat([query_embed, query_char_embed], dim=-1)

        # Multi-hop gated attention: re-encode the document at every hop and
        # gate it with the query encoding.
        for i in range(self.gru_layers - 1):
            doc_D, _ = self.grus_docs[i](doc_embed, doc_mask)
            Q, _ = self.grus_query[i](query_embed, query_mask)

            doc_embed = self.attention(doc_D, Q, doc_mask, query_mask)

        if self.use_features:
            features = self.features(qe_comm)
            D = torch.cat([doc_embed, features], dim=-1)
        else:
            # Without the extra feature, the final encoder reads the gated
            # document representation directly (D was undefined here before).
            D = doc_embed

        final_doc, _ = self.finalgru_doc(D, doc_mask)
        final_query, _ = self.finalgru_query(query_embed, query_mask)
        output = self.attention_sum(final_doc, final_query, cloze, cands, cand_mask)

        return output

    def attention(self, D, Q, doc_mask, query_mask):
        # Token-level attention of every document position over the query,
        # with padded positions masked out.
        mask_Q = query_mask.unsqueeze(1).expand(-1, D.shape[1], -1)
        mask_D = doc_mask.unsqueeze(-1).expand(-1, -1, Q.shape[1])
        attn = F.softmax(torch.bmm(D, Q.transpose(-1, -2)), dim=-1) * mask_Q * mask_D

        # Gated attention: multiply (rather than add) the query summary into
        # the document representation.
        weights = torch.bmm(attn, Q)
        output = weights * D

        return output

    def attention_sum(self, doc, query, cloze, cand, cand_mask):
        # Take the query encoding at the cloze position as the query vector.
        mask = cloze.view(-1, 1).unsqueeze(-1).expand(-1, query.shape[1], query.shape[-1])
        q = query.gather(1, mask)
        q = q[:, 0, :].view(query.shape[0], -1, 1)
        # Score every document position against the cloze query vector.
        distribution = torch.bmm(doc, q).squeeze()

        # Attention-sum: pool the position probabilities onto the candidates.
        probs = F.softmax(distribution, dim=-1) * cand_mask
        output = torch.bmm(probs.unsqueeze(1), cand.float()).squeeze()

        return output
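
For a quick sanity check, here is a minimal sketch of how this model might be instantiated and run end to end. Every size, tensor shape, and random input below is an illustrative assumption (none of these values are specified in this commit), and the masks are all ones so every sequence counts as full length.

# Illustrative smoke test; all sizes and tensors here are assumptions, not
# values defined by this commit.
import torch

vocab, n_types, batch, d_len, q_len, w_len = 1000, 50, 2, 30, 10, 8

model = GA_Reader(n_chars=60, char_size=16, embed_size=64, hidden_size_char=32,
                  hidden_size=48, vocab_size=vocab,
                  pretrained_weights=torch.randn(vocab, 64),
                  gru_layers=3, use_features=1, use_chars=True)

doc = torch.randint(0, vocab, (batch, d_len))
query = torch.randint(0, vocab, (batch, q_len))
doc_mask = torch.ones(batch, d_len)                     # all-ones masks: no padding
query_mask = torch.ones(batch, q_len)
doc_char = torch.randint(0, n_types, (batch, d_len))    # word-type id per doc token
query_char = torch.randint(0, n_types, (batch, q_len))  # word-type id per query token
char_type = torch.randint(0, 60, (n_types, w_len))      # character ids per word type
char_type_mask = torch.ones(n_types, w_len)
ans = torch.zeros(batch, dtype=torch.long)              # unused by forward()
cloze = torch.zeros(batch, dtype=torch.long)            # cloze position in each query
cands = torch.randint(0, 2, (batch, d_len, 5))          # candidate indicators per doc position
cand_mask = torch.ones(batch, d_len)
qe_comm = torch.randint(0, 2, (batch, d_len))           # query/document common-word feature

scores = model(doc, doc_char, doc_mask, query, query_char, query_mask,
               char_type, char_type_mask, ans, cloze, cands, cand_mask, qe_comm)
print(scores.shape)  # expected: torch.Size([2, 5]), one score per candidate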
