-
Notifications
You must be signed in to change notification settings - Fork 4
/
sampler.py
186 lines (157 loc) · 7.19 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import random
from multiprocessing.dummy import Pool
import multiprocessing
class BatchNegTypeSampler:
    """Sample mini-batches of positive facts together with corrupted negatives.

    Positives are the ``(rel, subj, obj)`` triples of ``kb`` whose split equals
    ``which_set``. For every positive, ``neg_per_pos`` negatives are produced by
    replacing either the subject or the object with another symbol that is
    neither the original argument nor part of a known "train" fact.
    Negative sampling for a batch is fanned out over a thread pool
    (``multiprocessing.dummy.Pool``).

    Runs on both Python 2 and Python 3.
    """

    def __init__(self, kb, pos_per_batch, neg_per_pos=200, which_set="train", type_constraint=True):
        """Create the sampler and shuffle the first epoch.

        :param kb: knowledge base; must provide ``get_all_facts()``,
            ``get_symbols(dim)``, ``compatible_args_of(dim, rel)`` and
            ``contains_fact(truth, split, rel, subj, obj)``. Each entry returned
            by ``get_all_facts()`` is indexed as ``(triple, _, split)``.
        :param pos_per_batch: number of positive facts per batch.
        :param neg_per_pos: number of negatives sampled per positive.
        :param which_set: split whose facts are used as positives.
        :param type_constraint: if true, negative candidates come from
            ``kb.compatible_args_of``; otherwise from all subject/object symbols.
        """
        self.kb = kb
        self.pos_per_batch = pos_per_batch
        self.neg_per_pos = neg_per_pos
        self.type_constraint = type_constraint
        self.facts = [f[0] for f in self.kb.get_all_facts() if f[2] == which_set]
        self.num_facts = len(self.facts)
        # Number of full batches per epoch (remainder facts are dropped).
        # Floor division keeps this an int on Python 3 so end_of_epoch() can match it.
        self.epoch_size = self.num_facts // self.pos_per_batch
        self.reset()
        # Worker threads used to sample negatives for the facts of one batch.
        self.__pool = Pool()
        # Single-worker pool for get_batch_async(); created on first use.
        self.__batch_pool = None
        self._objs = list(self.kb.get_symbols(2))
        self._subjs = list(self.kb.get_symbols(1))
        # init_types() is not called automatically; no other method of this
        # class reads the rel_args / rel_types it builds.

    def init_types(self):
        """Derive argument "types" from the positive facts.

        A concept's types are the relation roles (``rel + "_s"`` or
        ``rel + "_o"``) it fills in any fact. Sets:

        * ``self.rel_args[role]``: set of concepts that fill ``role``.
        * ``self.rel_types[role]``: types of the concepts filling ``role``,
          ordered by how often they occur there (most frequent first).
        """
        concept_types = {}
        self.rel_args = {}
        self.rel_types = {}
        for rel, subj, obj in self.facts:
            for role, arg in ((rel + "_s", subj), (rel + "_o", obj)):
                # Record the object itself (not the subject) for "_o" roles.
                self.rel_args.setdefault(role, set()).add(arg)
                concept_types.setdefault(arg, set()).add(role)
        # Count, per role, how often each type occurs among its arguments.
        role_type_counts = {}
        for rel, subj, obj in self.facts:
            for role, arg in ((rel + "_s", subj), (rel + "_o", obj)):
                counts = role_type_counts.setdefault(role, {})
                for t in concept_types[arg]:
                    counts[t] = counts.get(t, 0) + 1
        # Rank types per role by descending count (stable for ties).
        for role, counts in role_type_counts.items():
            ranked = [t for t, _ in sorted(counts.items(), key=lambda kv: -kv[1])]
            self.rel_types.setdefault(role, []).extend(ranked)

    def reset(self):
        """Start a new epoch: reshuffle fact indices and rewind the batch counter.

        Only ``epoch_size * pos_per_batch`` indices are kept so that every batch
        of the epoch is full; the rest of the shuffled order is dropped.
        """
        order = list(range(self.num_facts))
        random.shuffle(order)
        # Slice by an explicit length: ``[:-(n % b)]`` would be ``[:0]`` (empty)
        # whenever n is a multiple of b.
        self.todo_facts = order[:self.epoch_size * self.pos_per_batch]
        self.count = 0

    def end_of_epoch(self):
        """Return True once ``epoch_size`` batches were drawn since the last reset."""
        return self.count == self.epoch_size

    def __iter__(self):
        return self

    def next(self):
        """Return the next ``get_batch()`` result (default ``position="both"``).

        Resets and raises StopIteration once ``count`` reaches ``num_facts``.
        """
        # NOTE(review): ``count`` counts batches but is compared with the number
        # of facts, and get_batch() rewinds ``count`` after every epoch, so with
        # pos_per_batch > 1 (and at least one fact) iteration never stops.
        # ``epoch_size`` may have been intended; left unchanged because callers
        # may rely on endless iteration.
        if self.count >= self.num_facts:
            self.reset()
            raise StopIteration
        return self.get_batch()

    # Python 3 iterator protocol.
    __next__ = next

    def __get_neg_examples(self, triple, position):
        """Return ``neg_per_pos`` corrupted copies of ``triple``.

        :param triple: positive ``(rel, subj, obj)`` fact.
        :param position: ``"obj"`` replaces the object; anything else replaces
            the subject.
        :returns: list of ``(rel, subj, obj)`` negatives. A candidate is
            rejected if it equals the replaced argument or if the corrupted
            triple is a known fact of the "train" split (checked regardless of
            ``which_set``).

        Candidates drawn from the type-constrained list are removed after each
        draw (swap-with-last, O(1)), so they are not repeated. When that list is
        empty or exhausted, sampling falls back to all symbols of the position.
        That shared list is never modified, so rare duplicates are possible.
        If every fallback symbol is rejected, this loops forever.
        Raises ValueError (from random.randint) if no symbols exist at all.
        """
        rel, subj, obj = triple
        corrupt_obj = position == "obj"
        fallback = self._objs if corrupt_obj else self._subjs
        disallowed = obj if corrupt_obj else subj
        if self.type_constraint:
            candidates = list(self.kb.compatible_args_of(2 if corrupt_obj else 1, rel))
        else:
            candidates = fallback
        if not candidates:
            # randint(0, -1) would raise on an empty candidate list.
            candidates = fallback
        last = len(candidates) - 1  # index of the last still-unused candidate
        neg_triples = []
        for _ in range(self.neg_per_pos):
            while True:
                i = random.randint(0, last)
                x = candidates[i]
                if candidates is not fallback:
                    # Remove the drawn candidate by moving the last one into its slot.
                    candidates[i] = candidates[last]
                    last -= 1
                    if last == -1:
                        candidates = fallback
                        last = len(candidates) - 1
                neg = (rel, subj, x) if corrupt_obj else (rel, x, obj)
                if (x is not None and x != disallowed
                        and not self.kb.contains_fact(True, "train", *neg)):
                    break
            neg_triples.append(neg)
        return neg_triples

    def get_batch(self, position="both"):
        """Return the next batch ``(pos, negs)``; ``negs[k]`` belongs to ``pos[k]``.

        :param position: ``"obj"`` or ``"subj"`` corrupts that argument of each
            positive. ``"both"`` returns every positive twice: the first half
            with object-corrupted negatives, the second half with
            subject-corrupted ones.
        :raises ValueError: for any other ``position`` (checked before a batch
            is consumed).

        Starts a new epoch automatically once the current one is exhausted.
        Not synchronized: do not call concurrently with itself.
        """
        if position not in ("both", "subj", "obj"):
            raise ValueError("position must be 'both', 'subj' or 'obj', got %r" % (position,))
        if self.end_of_epoch():
            self.reset()
        # Read the batch by offset instead of re-slicing todo_facts each time,
        # which copied the remaining list on every batch.
        start = self.count * self.pos_per_batch
        pos_idx = self.todo_facts[start:start + self.pos_per_batch]
        self.count += 1
        batch = [self.facts[i] for i in pos_idx]
        if position == "both":
            pos = batch + batch
            half = len(batch)
            negs = self.__pool.map(
                lambda i: self.__get_neg_examples(pos[i], "obj" if i < half else "subj"),
                range(len(pos)))
        else:
            pos = batch
            negs = self.__pool.map(lambda fact: self.__get_neg_examples(fact, position), pos)
        return pos, negs

    def get_batch_async(self, position="both"):
        """Schedule ``get_batch(position)`` in the background; return an AsyncResult.

        Uses a dedicated single-worker pool. Running get_batch on a worker of
        ``self.__pool`` would block that worker inside ``self.__pool.map`` and
        can deadlock once all of its workers are waiting (e.g. on one CPU).
        The single worker also serializes queued async calls. Synchronous
        get_batch() calls made meanwhile are not synchronized with it.
        """
        if self.__batch_pool is None:
            self.__batch_pool = Pool(1)
        return self.__batch_pool.apply_async(self.get_batch, (position,))

    def get_epoch(self):
        """Return ``count / num_facts`` as a float."""
        # NOTE(review): ``count`` is in batches and ``num_facts`` in facts;
        # ``count / float(epoch_size)`` may have been intended. Left unchanged.
        return self.count / float(self.num_facts)