Skip to content

Commit 7eedef7

Browse files
add valid loader
1 parent c3e6d73 commit 7eedef7

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

data_loader.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import random
2+
import torch
3+
4+
5+
class DataLoader(object):
    """Sample random mini-batches and pad them into ``torch.long`` tensors.

    Each element of ``training_data`` is unpacked as a 7-item sequence
    ``(doc, query, doc_char, query_char, cand, ans, cloze)``:

    * ``doc`` / ``query`` -- sequences of token ids.
    * ``doc_char`` / ``query_char`` -- one sequence of character ids per
      token position of the document / query.
    * ``cand`` -- sequence of candidate token ids.
    * ``ans`` / ``cloze`` -- sequences of which only element ``0`` is read.
    """

    def __init__(self, training_data, batch_size):
        """Store the instance pool and the number of instances per batch."""
        self.training_data = training_data
        self.batch_size = batch_size

    def __load_next__(self):
        """Draw one random batch and return it as zero-padded tensors.

        Instances are drawn with ``random.choices``, i.e. *with*
        replacement, so the same instance may occur more than once in a
        batch.

        Returns:
            A 13-tuple of ``torch.long`` tensors, in this order
            (``B`` is ``batch_size``):

            * ``docs`` ``(B, max_doc_len)`` -- document token ids.
            * ``doc_char`` ``(B, max_doc_len)`` -- for every document
              position, the row of ``char_type`` holding that word's
              characters.
            * ``docs_mask`` ``(B, max_doc_len)`` -- 1 on real tokens.
            * ``queries`` ``(B, max_query_len)`` -- query token ids.
            * ``query_char`` ``(B, max_query_len)`` -- like ``doc_char``,
              for the query.
            * ``queries_mask`` ``(B, max_query_len)`` -- 1 on real tokens.
            * ``char_type`` ``(n_types, max_word_len)`` -- one row per
              distinct character sequence seen in the batch.
            * ``char_type_mask`` ``(n_types, max_word_len)`` -- 1 on real
              characters.
            * ``answers`` ``(B,)`` -- position of the answer inside the
              instance's candidate list; the raw ``ans[0]`` value is kept
              when it does not occur among the candidates.
            * ``clozes`` ``(B,)`` -- ``cloze[0]`` of every instance.
            * ``cands`` ``(B, max_doc_len, max_cand_len)`` --
              ``cands[i, k, j]`` is 1 when document token ``k`` equals
              candidate ``j``.
            * ``cand_mask`` ``(B, max_doc_len)`` -- 1 where the document
              token equals at least one candidate.
            * ``qe_comm`` ``(B, max_doc_len)`` -- 1 where the document
              token also occurs in the query.
        """
        batch = random.choices(self.training_data, k=self.batch_size)

        # ---- pass 1: padded sizes, targets and the character-type table ----
        max_doc_len = max_query_len = max_cand_len = max_word_len = 0
        answer_values = []
        cloze_values = []
        # Maps a word's character tuple to every place it occurs, stored as
        # (is_doc, batch_row, position) with is_doc 1 = document, 0 = query.
        # Dict insertion order defines the row order of ``char_type``.
        word_types = {}
        for row, instance in enumerate(batch):
            doc, query, doc_words, query_words, cand, ans_, cloze_ = instance
            max_doc_len = max(max_doc_len, len(doc))
            max_query_len = max(max_query_len, len(query))
            max_cand_len = max(max_cand_len, len(cand))

            # Re-express the answer as its first position in the candidate
            # list; it is left untouched when no candidate matches.
            answer = ans_[0]
            for position, candidate in enumerate(cand):
                if candidate == answer:
                    answer = position
                    break
            answer_values.append(answer)
            cloze_values.append(cloze_[0])

            for is_doc, words in ((1, doc_words), (0, query_words)):
                for position, word in enumerate(words):
                    max_word_len = max(max_word_len, len(word))
                    word_types.setdefault(tuple(word), []).append(
                        (is_doc, row, position)
                    )

        # ---- allocate the zero-padded outputs ----
        docs = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        queries = torch.zeros(self.batch_size, max_query_len, dtype=torch.long)
        cands = torch.zeros(
            self.batch_size, max_doc_len, max_cand_len, dtype=torch.long
        )
        docs_mask = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        queries_mask = torch.zeros(
            self.batch_size, max_query_len, dtype=torch.long
        )
        cand_mask = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        qe_comm = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        answers = torch.tensor(answer_values, dtype=torch.long)
        clozes = torch.tensor(cloze_values, dtype=torch.long)

        # ---- pass 2: token ids, masks and token-level match features ----
        for row, instance in enumerate(batch):
            doc, query, cand = instance[0], instance[1], instance[4]
            # Explicit dtype keeps empty sequences from becoming float tensors.
            docs[row, :len(doc)] = torch.tensor(doc, dtype=torch.long)
            queries[row, :len(query)] = torch.tensor(query, dtype=torch.long)
            docs_mask[row, :len(doc)] = 1
            queries_mask[row, :len(query)] = 1

            for k, token in enumerate(doc):
                for j, candidate in enumerate(cand):
                    if token == candidate:
                        cands[row, k, j] = 1
                        cand_mask[row, k] = 1
                if token in query:
                    qe_comm[row, k] = 1

        # ---- character-level lookup tables ----
        # NOTE: padded positions of doc_char / query_char stay 0, which is
        # also the row of the first character type; use the token masks to
        # tell them apart.
        doc_char = torch.zeros(self.batch_size, max_doc_len, dtype=torch.long)
        query_char = torch.zeros(
            self.batch_size, max_query_len, dtype=torch.long
        )
        char_type = torch.zeros(
            len(word_types), max_word_len, dtype=torch.long
        )
        char_type_mask = torch.zeros(
            len(word_types), max_word_len, dtype=torch.long
        )

        for type_id, (word, occurrences) in enumerate(word_types.items()):
            char_type[type_id, :len(word)] = torch.tensor(
                list(word), dtype=torch.long
            )
            char_type_mask[type_id, :len(word)] = 1
            for is_doc, row, position in occurrences:
                if is_doc == 1:
                    doc_char[row, position] = type_id
                else:
                    query_char[row, position] = type_id

        return (
            docs, doc_char, docs_mask, queries, query_char, queries_mask,
            char_type, char_type_mask, answers, clozes, cands, cand_mask,
            qe_comm,
        )

0 commit comments

Comments
 (0)