import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class GRU(nn.Module):
    """Bidirectional GRU that handles variable-length, padded batches."""

    def __init__(self, input_size, hidden_size):
        super(GRU, self).__init__()

        self.gru = nn.GRU(input_size, hidden_size, bidirectional=True, batch_first=True)

    def forward(self, input, input_mask):
        # Actual sequence lengths from the binary padding mask.
        seq_len = torch.sum(input_mask, dim=-1)

        # Sort the batch by decreasing length, as required by pack_padded_sequence.
        sorted_len, sorted_index = seq_len.sort(0, descending=True)
        i_sorted_index = sorted_index.view(-1, 1, 1).expand_as(input)
        sorted_input = input.gather(0, i_sorted_index.long())

        # Pack, run the GRU, and pad back; lengths must live on the CPU.
        packed_seq = pack_padded_sequence(sorted_input, sorted_len.cpu(), batch_first=True)
        output, hidden = self.gru(packed_seq)  # nn.GRU returns (output, h_n), not an LSTM-style tuple
        unpacked_seq, unpacked_len = pad_packed_sequence(output, batch_first=True)

        # Restore the original batch order.
        _, original_index = sorted_index.sort(0, descending=False)
        unsorted_index = original_index.view(-1, 1, 1).expand_as(unpacked_seq)
        output_final = unpacked_seq.gather(0, unsorted_index.long())

        return output_final, seq_len


class Char_Embeds(nn.Module):
    """Character-level word embeddings built from the final states of a bidirectional GRU."""

    def __init__(self, n_chars, char_size, embed_size, hidden_size):
        super(Char_Embeds, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size

        self.char_embedding = nn.Embedding(n_chars, char_size)
        self.forward_project = nn.Linear(hidden_size, embed_size // 2)
        self.backward_project = nn.Linear(hidden_size, embed_size // 2)

        self.gru = GRU(char_size, hidden_size)

    def forward(self, input, mask, doc_char, query_char):
        input = self.char_embedding(input)
        input, seq_len = self.gru(input, mask)

        # Take the hidden state at the last valid time step of every character sequence.
        final_index = (seq_len - 1).view(-1, 1).expand(input.size(0), input.size(2)).unsqueeze(1)
        output = input.gather(1, final_index.long()).squeeze(1)

        # Split the bidirectional output, project each direction, and merge.
        forward_output = output[:, :self.hidden_size]
        backward_output = output[:, self.hidden_size:]
        forward_output = self.forward_project(forward_output)
        backward_output = self.backward_project(backward_output)
        final = forward_output + backward_output

        # Look up the character-level embedding for every token in the document and the query.
        doc_embed = final.index_select(0, doc_char.view(-1)).view(
            doc_char.shape[0], doc_char.shape[1], self.embed_size // 2)
        query_embed = final.index_select(0, query_char.view(-1)).view(
            query_char.shape[0], query_char.shape[1], self.embed_size // 2)

        return doc_embed, query_embed


class GA_Reader(nn.Module):
    """Gated-Attention Reader: multi-hop gated attention between document and query."""

    def __init__(self, n_chars, char_size, embed_size, hidden_size_char, hidden_size,
                 vocab_size, pretrained_weights, gru_layers, use_features, use_chars):
        super(GA_Reader, self).__init__()
        self.embedding = nn.Embedding.from_pretrained(pretrained_weights)
        self.use_chars = use_chars
        self.use_features = use_features
        self.gru_layers = gru_layers

        # One bidirectional GRU per intermediate hop for the document and the query.
        self.grus_docs = nn.ModuleList()
        self.grus_query = nn.ModuleList()
        for i in range(gru_layers - 1):
            if i == 0:
                if self.use_chars:
                    G1 = GRU(3 * embed_size // 2, hidden_size)
                else:
                    G1 = GRU(embed_size, hidden_size)
            else:
                G1 = GRU(2 * hidden_size, hidden_size)
            if self.use_chars:
                G2 = GRU(3 * embed_size // 2, hidden_size)
            else:
                G2 = GRU(embed_size, hidden_size)
            self.grus_docs.append(G1)
            self.grus_query.append(G2)

        # Optional question-evidence common-word feature appended to the document representation.
        if use_features:
            self.features = nn.Embedding(2, 2)
        self.finalgru_doc = GRU(2 * hidden_size + (2 if use_features else 0), hidden_size)
        self.finalgru_query = GRU(3 * embed_size // 2 if use_chars else embed_size, hidden_size)

        if use_chars:
            self.char_embeds = Char_Embeds(n_chars, char_size, embed_size, hidden_size_char)

    def forward(self, doc, doc_char, doc_mask, query, query_char, query_mask,
                char_type, char_type_mask, ans, cloze, cands, cand_mask, qe_comm):
        doc_embed = self.embedding(doc)
        query_embed = self.embedding(query)

        # Concatenate character-level embeddings with the pretrained word embeddings.
        if self.use_chars:
            doc_char_embed, query_char_embed = self.char_embeds(
                char_type, char_type_mask, doc_char, query_char)
            doc_embed = torch.cat([doc_embed, doc_char_embed], dim=-1)
            query_embed = torch.cat([query_embed, query_char_embed], dim=-1)

        # Multi-hop gated attention: each hop re-encodes the document and gates it with the query.
        for i in range(self.gru_layers - 1):
            doc_D, _ = self.grus_docs[i](doc_embed, doc_mask)
            Q, _ = self.grus_query[i](query_embed, query_mask)

            doc_embed = self.attention(doc_D, Q, doc_mask, query_mask)

        if self.use_features:
            features = self.features(qe_comm)
            D = torch.cat([doc_embed, features], dim=-1)
        else:
            D = doc_embed

        final_doc, _ = self.finalgru_doc(D, doc_mask)
        final_query, _ = self.finalgru_query(query_embed, query_mask)
        output = self.attention_sum(final_doc, final_query, cloze, cands, cand_mask)

        return output

    def attention(self, D, Q, doc_mask, query_mask):
        # Token-level attention of every document position over the query,
        # with padded positions zeroed out after the softmax.
        mask_Q = query_mask.unsqueeze(1).expand(-1, D.shape[1], -1)
        mask_D = doc_mask.unsqueeze(-1).expand(-1, -1, Q.shape[1])
        attn = F.softmax(torch.bmm(D, Q.transpose(-1, -2)), dim=-1) * mask_Q * mask_D

        # Gated attention: element-wise product of the query summary with the document.
        weights = torch.bmm(attn, Q)
        output = weights * D

        return output

    def attention_sum(self, doc, query, cloze, cand, cand_mask):
        # Query representation at the cloze (placeholder) position.
        index = cloze.view(-1, 1, 1).expand(-1, 1, query.shape[-1])
        q = query.gather(1, index).view(query.shape[0], -1, 1)
        distribution = torch.bmm(doc, q).squeeze(-1)

        # Attention-sum: aggregate probability mass over candidate occurrences in the document.
        probs = F.softmax(distribution, dim=-1) * cand_mask
        output = torch.bmm(probs.unsqueeze(1), cand.float()).squeeze(1)

        return output
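

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch only. All sizes, the dummy tensors,
    # and the hyper-parameters below are assumptions, not part of the original model.
    batch, seq, in_dim, hid_dim = 4, 7, 16, 6
    x = torch.randn(batch, seq, in_dim)
    lengths = torch.tensor([7, 5, 3, 2])
    mask = (torch.arange(seq).unsqueeze(0) < lengths.unsqueeze(1)).long()

    encoder = GRU(in_dim, hid_dim)
    out, seq_len = encoder(x, mask)
    print(out.shape)         # torch.Size([4, 7, 12]): bidirectional, so last dim is 2 * hid_dim
    print(seq_len.tolist())  # [7, 5, 3, 2]

    # Constructing the full reader (forward pass omitted; it needs a prepared batch).
    # The vocabulary/character sizes and the random embedding matrix are placeholders.
    vocab_size, n_chars, embed_size = 100, 50, 32
    reader = GA_Reader(n_chars=n_chars, char_size=10, embed_size=embed_size,
                       hidden_size_char=8, hidden_size=hid_dim, vocab_size=vocab_size,
                       pretrained_weights=torch.randn(vocab_size, embed_size),
                       gru_layers=3, use_features=True, use_chars=True)
    print(sum(p.numel() for p in reader.parameters()))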