import random
import torch


class DataLoader(object):
    """Batches cloze-style reading-comprehension instances.

    Each training instance is a 7-tuple:
    (doc, query, doc_char, query_char, cand, ans, cloze).
    """

    def __init__(self, training_data, batch_size):
        self.training_data = training_data
        self.batch_size = batch_size

    def __load_next__(self):
        # Sample a random batch (with replacement) from the training data.
        data = random.choices(self.training_data, k=self.batch_size)

        # First pass: find the maximum lengths needed for padding and collect
        # the raw answer word ids and cloze positions.
        max_query_len, max_doc_len, max_cand_len, max_word_len = 0, 0, 0, 0
        ans = []
        clozes = []
        word_types = {}  # word (as a tuple of char ids) -> list of occurrences
        for i, instance in enumerate(data):
            doc, query, doc_char, query_char, cand, ans_, cloze_ = instance
            max_doc_len = max(max_doc_len, len(doc))
            max_query_len = max(max_query_len, len(query))
            max_cand_len = max(max_cand_len, len(cand))
            ans.append(ans_[0])
            clozes.append(cloze_[0])

            # Record every distinct word by its character ids and where it
            # occurs: (1, batch index, token index) for document tokens,
            # (0, batch index, token index) for query tokens.
            for index, word in enumerate(doc_char):
                max_word_len = max(max_word_len, len(word))
                if tuple(word) not in word_types:
                    word_types[tuple(word)] = []
                word_types[tuple(word)].append((1, i, index))
            for index, word in enumerate(query_char):
                max_word_len = max(max_word_len, len(word))
                if tuple(word) not in word_types:
                    word_types[tuple(word)] = []
                word_types[tuple(word)].append((0, i, index))

        # Allocate zero-padded batch tensors.
        docs = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        queries = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        cands = torch.zeros(self.batch_size, max_doc_len, max_cand_len, dtype=torch.long)
        docs_mask = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        queries_mask = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        cand_mask = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        qe_comm = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        answers = torch.tensor(ans, dtype=torch.long)
        clozes = torch.tensor(clozes, dtype=torch.long)

        # Second pass: copy the word ids and build the masks and features.
        for i, instance in enumerate(data):
            doc, query, doc_char, query_char, cand, ans_, cloze_ = instance
            docs[i, :len(doc)] = torch.tensor(doc)
            queries[i, :len(query)] = torch.tensor(query)
            docs_mask[i, :len(doc)] = 1
            queries_mask[i, :len(query)] = 1

            for k, index in enumerate(doc):
                # Mark document positions whose word id matches a candidate answer.
                for j, index_c in enumerate(cand):
                    if index == index_c:
                        cands[i][k][j] = 1
                        cand_mask[i][k] = 1

                # qe_comm feature: 1 if this document token also appears in the query.
                for y in query:
                    if y == index:
                        qe_comm[i][k] = 1
                        break

            # Convert the answer from a word id into its index in the candidate list.
            for x, cl in enumerate(cand):
                if cl == answers[i]:
                    answers[i] = x
                    break

        # Character-level tensors: each token stores an index into char_type,
        # which holds one row of character ids per distinct word.
        doc_char = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        query_char = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        char_type = torch.zeros(len(word_types), max_word_len, dtype=torch.long)
        char_type_mask = torch.zeros(len(word_types), max_word_len, dtype=torch.long)

        # Write each distinct word's character ids into char_type and point the
        # corresponding document/query positions at that row.
        index = 0
        for word, word_list in word_types.items():
            char_type[index, :len(word)] = torch.tensor(list(word))
            char_type_mask[index, :len(word)] = 1
            for (i, j, k) in word_list:
                if i == 1:
                    doc_char[j, k] = index
                else:
                    query_char[j, k] = index
            index += 1

        return docs, doc_char, docs_mask, queries, query_char, queries_mask, \
            char_type, char_type_mask, answers, clozes, cands, cand_mask, qe_comm
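

# Minimal usage sketch (not part of the original file): the instance layout
# below is an assumption inferred from the unpacking in __load_next__, i.e.
# (doc, query, doc_char, query_char, cand, ans, cloze) with word-id lists,
# per-token character-id lists, and single-element answer/cloze lists.
if __name__ == "__main__":
    toy_instance = (
        [4, 7, 9, 7],                 # doc word ids
        [7, 2],                       # query word ids
        [[1, 2], [3], [4, 5], [3]],   # doc characters, one list per token
        [[3], [6]],                   # query characters, one list per token
        [7, 9],                       # candidate answer word ids
        [9],                          # answer word id
        [1],                          # cloze position in the query
    )
    loader = DataLoader([toy_instance, toy_instance], batch_size=2)
    batch = loader.__load_next__()
    names = ["docs", "doc_char", "docs_mask", "queries", "query_char",
             "queries_mask", "char_type", "char_type_mask", "answers",
             "clozes", "cands", "cand_mask", "qe_comm"]
    for name, tensor in zip(names, batch):
        print(name, tuple(tensor.shape))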