-
Notifications
You must be signed in to change notification settings - Fork 13
/
cenet_model.py
417 lines (349 loc) · 19.9 KB
/
cenet_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
# Name: cenet_model
# Author: Reacubeth
# Time: 2021/6/25 17:28
# Mail: noverfitting@gmail.com
# Site: www.omegaxyz.com
# *_*coding:utf-8 *_*
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import *
import math
import copy
"""
class Oracle(nn.Module):
def __init__(self, input_dim, out_dim):
super(Oracle, self).__init__()
self.linear = nn.Sequential(nn.Linear(input_dim, 2 * input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(2 * input_dim, 2 * input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(2 * input_dim, 2 * input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(2 * input_dim, input_dim),
nn.Dropout(0.2),
nn.LeakyReLU(0.2),
nn.Linear(input_dim, out_dim),
)
def forward(self, x):
return self.linear(x)
"""
class Oracle(nn.Module):
    """Two-layer MLP used as CENET's binary "answer is in history" classifier.

    Layout: Linear(in, in) -> BatchNorm1d(in) -> Dropout(0.4) -> LeakyReLU(0.2)
    -> Linear(in, out). The stack is stored as ``self.linear`` so that
    state_dict keys (``linear.0.*``, ``linear.1.*``, ``linear.4.*``) are stable.

    Args:
        input_dim: width of the input features.
        out_dim: number of output logits (CENET uses 1).
    """

    def __init__(self, input_dim, out_dim):
        super().__init__()
        hidden = nn.Linear(input_dim, input_dim)
        norm = nn.BatchNorm1d(input_dim)
        head = nn.Linear(input_dim, out_dim)
        # Order matters: it fixes the Sequential indices and the init RNG order.
        self.linear = nn.Sequential(
            hidden,
            norm,
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2),
            head,
        )

    def forward(self, x):
        """Return raw (pre-sigmoid) logits for a batch ``x``."""
        logits = self.linear(x)
        return logits
class CENET(nn.Module):
    """CENET model for link prediction on quadruples (subject, relation, object, ...).

    Candidate entities are scored by two heads per query direction: one biased
    towards entities that occur in the query's history (``history_tag``) and one
    biased away from them (``non_history_tag``). An ``Oracle`` classifier over
    [entity, relation, frequency embedding] predicts whether the answer lies in
    the history; at Valid/Test time its prediction masks the candidate set.

    Args:
        num_e: number of entities.
        num_rel: number of relations; ``2 * num_rel`` relation embeddings are
            allocated, ``[:num_rel]`` for subject queries and ``[num_rel:]`` for
            object queries.
        num_t: number of timestamps (stored only).
        args: namespace read for ``embedding_dim``, ``dropout``, ``oracle_mode``,
            ``lambdax``, ``alpha`` and ``filtering``.
    """

    def __init__(self, num_e, num_rel, num_t, args):
        super(CENET, self).__init__()
        # stats
        self.num_e = num_e
        self.num_t = num_t
        self.num_rel = num_rel
        self.args = args

        # entity / relation embeddings
        self.rel_embeds = nn.Parameter(torch.zeros(2 * num_rel, args.embedding_dim))
        nn.init.xavier_uniform_(self.rel_embeds, gain=nn.init.calculate_gain('relu'))

        self.entity_embeds = nn.Parameter(torch.zeros(self.num_e, args.embedding_dim))
        nn.init.xavier_uniform_(self.entity_embeds, gain=nn.init.calculate_gain('relu'))

        self.linear_frequency = nn.Linear(self.num_e, args.embedding_dim)

        self.contrastive_hidden_layer = nn.Linear(3 * args.embedding_dim, args.embedding_dim)
        self.contrastive_output_layer = nn.Linear(args.embedding_dim, args.embedding_dim)
        self.oracle_layer = Oracle(3 * args.embedding_dim, 1)
        self.oracle_layer.apply(self.weights_init)

        self.linear_pred_layer_s1 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)
        self.linear_pred_layer_o1 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)
        self.linear_pred_layer_s2 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)
        self.linear_pred_layer_o2 = nn.Linear(2 * args.embedding_dim, args.embedding_dim)

        # Init order is kept fixed so the random initialisation is reproducible.
        self.weights_init(self.linear_frequency)
        self.weights_init(self.linear_pred_layer_s1)
        self.weights_init(self.linear_pred_layer_o1)
        self.weights_init(self.linear_pred_layer_s2)
        self.weights_init(self.linear_pred_layer_o2)

        # A sinusoidal time encoding (400 positions x 3 * embedding_dim, registered
        # as buffer 'pe') was prototyped here and is currently disabled.

        self.dropout = nn.Dropout(args.dropout)
        self.logSoftmax = nn.LogSoftmax()
        self.softmax = nn.Softmax()
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.crossEntropy = nn.BCELoss()
        self.oracle_mode = args.oracle_mode
        print('CENET Initiated')

    @staticmethod
    def weights_init(m):
        """Xavier-uniform (ReLU gain) init for ``nn.Linear`` weights; other modules untouched."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))

    @staticmethod
    def _history_tags(frequency, lambdax):
        """Build the additive score biases for the history / non-history heads.

        Entities with a non-zero ``frequency`` entry (seen in the query's history)
        get ``+lambdax`` in the history tag and ``-lambdax`` in the non-history
        tag; all other entities get the opposite signs. ``frequency`` is not
        modified.

        Returns:
            ``(history_tag, non_history_tag)``, both shaped/typed like ``frequency``.
        """
        history_tag = torch.full_like(frequency, lambdax)
        history_tag[frequency == 0] = -lambdax
        return history_tag, -history_tag

    @staticmethod
    def _history_ids(history_events, num_rows):
        """Flatten, for each of the first ``num_rows`` queries, column 1 of all its history event arrays."""
        return [
            [eid for events in history_events[i] for eid in events[:, 1].tolist()]
            for i in range(num_rows)
        ]

    @staticmethod
    def _oracle_mask(use_history, history_ids, num_e):
        """Build a ``(len(history_ids), num_e)`` float 0/1 candidate mask on the CPU.

        Row ``i`` keeps only ``history_ids[i]`` when ``use_history[i]`` is true,
        otherwise it keeps every entity except ``history_ids[i]``.
        """
        mask = torch.zeros(len(history_ids), num_e)
        for i, (flag, ids) in enumerate(zip(use_history, history_ids)):
            if flag:
                mask[i, ids] = 1
            else:
                mask[i, :] = 1
                mask[i, ids] = 0
        return mask

    def forward(self, batch_block, mode_lk, total_data=None):
        """Run one batch in the requested mode.

        Args:
            batch_block: ``(quadruples, s_history_event_o, o_history_event_s,
                s_history_label_true, o_history_label_true, s_frequency,
                o_frequency)``; quadruple columns 0/1/2 are subject/relation/object.
            mode_lk: ``'Training'``, ``'Valid'``, ``'Test'`` or ``'Oracle'``.
            total_data: known triples used for filtered ranking (Valid/Test only).

        Returns:
            'Training': ``alpha * nce_loss + (1 - alpha) * spc_loss``.
            'Valid'/'Test': ``(sub_rank, obj_rank, loss)`` for (1) the oracle-
                predicted mask, (2) no mask and (3) the ground-truth mask, followed
                by the oracle's correct-prediction count averaged over directions.
            'Oracle': oracle BCE loss plus an L1 penalty on the oracle layer.
            When either history list is empty: ``None``, or ``(None, None)`` for
            Valid/Test (NOTE(review): differs in arity from the normal Valid/Test
            return — confirm callers handle it).
        """
        quadruples, s_history_event_o, o_history_event_s, \
            s_history_label_true, o_history_label_true, s_frequency, o_frequency = batch_block

        if isListEmpty(s_history_event_o) or isListEmpty(o_history_event_s):
            sub_rank, obj_rank, batch_loss = [None] * 3
            if mode_lk == 'Training':
                return batch_loss
            elif mode_lk in ['Valid', 'Test']:
                return sub_rank, batch_loss
            else:
                return None

        s = quadruples[:, 0]
        r = quadruples[:, 1]
        o = quadruples[:, 2]
        # Disabled: t = (quadruples[:, 3] / 24.0).long(); time_embedding = self.pe[t]

        s_history_tag, s_non_history_tag = self._history_tags(s_frequency, self.args.lambdax)
        o_history_tag, o_non_history_tag = self._history_tags(o_frequency, self.args.lambdax)

        s_frequency = F.softmax(s_frequency, dim=1)
        o_frequency = F.softmax(o_frequency, dim=1)
        s_frequency_hidden = self.tanh(self.linear_frequency(s_frequency))
        o_frequency_hidden = self.tanh(self.linear_frequency(o_frequency))

        s_rel = self.rel_embeds[:self.num_rel]
        o_rel = self.rel_embeds[self.num_rel:]

        if mode_lk == 'Training':
            s_nce_loss, _ = self.calculate_nce_loss(s, o, r, s_rel,
                                                    self.linear_pred_layer_s1, self.linear_pred_layer_s2,
                                                    s_history_tag, s_non_history_tag)
            o_nce_loss, _ = self.calculate_nce_loss(o, s, r, o_rel,
                                                    self.linear_pred_layer_o1, self.linear_pred_layer_o2,
                                                    o_history_tag, o_non_history_tag)
            s_spc_loss = self.calculate_spc_loss(s, r, s_rel, s_history_label_true, s_frequency_hidden)
            o_spc_loss = self.calculate_spc_loss(o, r, o_rel, o_history_label_true, o_frequency_hidden)
            nce_loss = (s_nce_loss + o_nce_loss) / 2.0
            spc_loss = (s_spc_loss + o_spc_loss) / 2.0
            return self.args.alpha * nce_loss + (1 - self.args.alpha) * spc_loss

        elif mode_lk in ['Valid', 'Test']:
            num_rows = quadruples.shape[0]
            s_history_oid = self._history_ids(s_history_event_o, num_rows)
            o_history_sid = self._history_ids(o_history_event_s, num_rows)

            s_nce_loss, s_preds = self.calculate_nce_loss(s, o, r, s_rel,
                                                          self.linear_pred_layer_s1, self.linear_pred_layer_s2,
                                                          s_history_tag, s_non_history_tag)
            o_nce_loss, o_preds = self.calculate_nce_loss(o, s, r, o_rel,
                                                          self.linear_pred_layer_o1, self.linear_pred_layer_o2,
                                                          o_history_tag, o_non_history_tag)

            s_ce_loss, s_pred_history_label, s_ce_all_acc = self.oracle_loss(s, r, s_rel,
                                                                             s_history_label_true,
                                                                             s_frequency_hidden)
            o_ce_loss, o_pred_history_label, o_ce_all_acc = self.oracle_loss(o, r, o_rel,
                                                                             o_history_label_true,
                                                                             o_frequency_hidden)

            # (1) Mask chosen by the oracle's prediction.
            s_mask = to_device(self._oracle_mask(
                [s_pred_history_label[i].item() > 0.5 for i in range(num_rows)], s_history_oid, self.num_e))
            o_mask = to_device(self._oracle_mask(
                [o_pred_history_label[i].item() > 0.5 for i in range(num_rows)], o_history_sid, self.num_e))
            if self.oracle_mode == 'soft':
                # Soft mode turns the 0/1 masks into row-wise softmax weights.
                s_mask = F.softmax(s_mask, dim=1)
                o_mask = F.softmax(o_mask, dim=1)

            s_total_loss1, sub_rank1 = self.link_predict(s_nce_loss, s_preds, s_ce_loss, s, o, r,
                                                         s_mask, total_data, 's', True)
            o_total_loss1, obj_rank1 = self.link_predict(o_nce_loss, o_preds, o_ce_loss, o, s, r,
                                                         o_mask, total_data, 'o', True)
            batch_loss1 = (s_total_loss1 + o_total_loss1) / 2.0

            # (2) No mask.
            s_total_loss2, sub_rank2 = self.link_predict(s_nce_loss, s_preds, s_ce_loss, s, o, r,
                                                         s_mask, total_data, 's', False)
            o_total_loss2, obj_rank2 = self.link_predict(o_nce_loss, o_preds, o_ce_loss, o, s, r,
                                                         o_mask, total_data, 'o', False)
            batch_loss2 = (s_total_loss2 + o_total_loss2) / 2.0

            # (3) Ground-truth mask: history is used iff the true answer occurs in it.
            s_mask_gt = to_device(self._oracle_mask(
                [o[i] in s_history_oid[i] for i in range(num_rows)], s_history_oid, self.num_e))
            o_mask_gt = to_device(self._oracle_mask(
                [s[i] in o_history_sid[i] for i in range(num_rows)], o_history_sid, self.num_e))

            s_total_loss3, sub_rank3 = self.link_predict(s_nce_loss, s_preds, s_ce_loss, s, o, r,
                                                         s_mask_gt, total_data, 's', True)
            o_total_loss3, obj_rank3 = self.link_predict(o_nce_loss, o_preds, o_ce_loss, o, s, r,
                                                         o_mask_gt, total_data, 'o', True)
            batch_loss3 = (s_total_loss3 + o_total_loss3) / 2.0

            return sub_rank1, obj_rank1, batch_loss1, \
                sub_rank2, obj_rank2, batch_loss2, \
                sub_rank3, obj_rank3, batch_loss3, \
                (s_ce_all_acc + o_ce_all_acc) / 2

        elif mode_lk == 'Oracle':
            print('Oracle Training')
            s_ce_loss, _, _ = self.oracle_loss(s, r, s_rel, s_history_label_true, s_frequency_hidden)
            o_ce_loss, _, _ = self.oracle_loss(o, r, o_rel, o_history_label_true, o_frequency_hidden)
            return (s_ce_loss + o_ce_loss) / 2.0 + self.oracle_l1(0.01)

    def oracle_loss(self, actor1, r, rel_embeds, history_label, frequency_hidden):
        """Binary cross-entropy of the oracle's "answer is in history" prediction.

        Returns:
            ``(ce_loss, history_label_pred, n_correct)`` where ``history_label_pred``
            is the un-squeezed sigmoid output and ``n_correct`` is the number of
            0.5-thresholded predictions equal to ``history_label``.
        """
        history_label_pred = torch.sigmoid(
            self.oracle_layer(torch.cat((self.entity_embeds[actor1], rel_embeds[r], frequency_hidden), dim=1)))
        # Single threshold (p > 0.5 -> 1, else 0), matching the decision rule in forward();
        # a two-step >/< assignment would leave p == 0.5 unlabelled.
        tmp_label = (torch.squeeze(history_label_pred) > 0.5).to(history_label_pred.dtype)
        ce_correct = torch.sum(torch.eq(tmp_label, torch.squeeze(history_label)))
        ce_accuracy = 1. * ce_correct.item() / tmp_label.shape[0]
        print('# CE Accuracy', ce_accuracy)
        ce_loss = self.crossEntropy(torch.squeeze(history_label_pred), torch.squeeze(history_label))
        return ce_loss, history_label_pred, ce_accuracy * tmp_label.shape[0]

    def calculate_nce_loss(self, actor1, actor2, r, rel_embeds, linear1, linear2, history_tag, non_history_tag):
        """Score all entities as answers for ``(actor1, r, ?)`` and compute the NLL of ``actor2``.

        Returns:
            ``(nce, preds)``: the batch-mean negative log of the summed head
            probabilities at ``actor2``, and that summed probability matrix.
        """
        query = torch.cat((self.entity_embeds[actor1], rel_embeds[r]), dim=1)
        candidates = self.entity_embeds.transpose(0, 1)
        # Each head draws its own dropout mask.
        preds_raw1 = self.tanh(linear1(self.dropout(query)))
        preds1 = F.softmax(preds_raw1.mm(candidates) + history_tag, dim=1)
        preds_raw2 = self.tanh(linear2(self.dropout(query)))
        preds2 = F.softmax(preds_raw2.mm(candidates) + non_history_tag, dim=1)
        preds = preds1 + preds2

        nce = torch.sum(torch.gather(torch.log(preds), 1, actor2.view(-1, 1)))
        nce /= -1. * actor2.shape[0]

        pred_actor2 = torch.argmax(preds, dim=1)  # predicted result
        correct = torch.sum(torch.eq(pred_actor2, actor2))
        accuracy = 1. * correct.item() / actor2.shape[0]
        print('# Batch accuracy', accuracy)
        return nce, preds

    @staticmethod
    def _write_case(case_file, history_row, pred_known, cur_s, cur_r, cur_o, prediction):
        """Append one query's history entities and prediction to ``case_file`` as raw integer ids."""
        print('---------------------------', file=case_file)
        for hist_id in torch.where(history_row > 0)[0]:
            print('his:', hist_id.item(), file=case_file)
        print(pred_known,
              'Truth:', cur_s.item(), '--', cur_r.item(), '--', cur_o.item(),
              'Prediction:', prediction.item(), file=case_file)

    def link_predict(self, nce_loss, preds, ce_loss, actor1, actor2, r, trust_musk, all_triples, pred_known, oracle,
                     history_tag=None, case_study=False):
        """Compute (optionally filtered) ranks of the true answers.

        Args:
            preds: ``(batch, num_e)`` candidate scores; never modified in place.
            trust_musk: mask multiplied into ``preds`` when ``oracle`` is true.
            all_triples: known triples (columns 0/1/2 = subject/relation/object)
                used when ``self.args.filtering`` is set.
            pred_known: ``'s'`` if ``actor1`` are subjects, otherwise objects.
            history_tag: per-row history biases; required when ``case_study``.
            case_study: append per-query details to ``case_study.txt``.

        Returns:
            ``(nce_loss + ce_loss, ranks)`` with one rank per row; ties share
            the mean rank.

        Raises:
            ValueError: ``case_study`` is set but ``history_tag`` is missing.
        """
        if case_study and history_tag is None:
            raise ValueError('case_study=True requires history_tag')

        if oracle:
            preds = torch.mul(preds, trust_musk)
            print('$Batch After Oracle accuracy:', end=' ')
        else:
            # Copy so the filtering below does not write into the caller's tensor.
            preds = preds.clone()
            print('$Batch No Oracle accuracy:', end=' ')

        pred_actor2 = torch.argmax(preds, dim=1)  # predicted result
        correct = torch.sum(torch.eq(pred_actor2, actor2))
        accuracy = 1. * correct.item() / actor2.shape[0]
        print(accuracy)

        total_loss = nce_loss + ce_loss
        known_col, answer_col = (0, 2) if pred_known == 's' else (2, 0)

        ranks = []
        case_file = open('case_study.txt', 'a+') if case_study else None
        try:
            for i in range(preds.shape[0]):
                cur_s = actor1[i]
                cur_r = r[i]
                cur_o = actor2[i]
                if case_file is not None:
                    self._write_case(case_file, history_tag[i], pred_known, cur_s, cur_r, cur_o, pred_actor2[i])

                ground = preds[i, cur_o].item()
                if self.args.filtering:
                    # Zero every other known answer of (cur_s, cur_r), then restore the truth.
                    rows = torch.nonzero(all_triples[:, known_col] == cur_s).view(-1)
                    rows = rows[torch.nonzero(all_triples[rows, 1] == cur_r).view(-1)]
                    preds[i, all_triples[rows, answer_col]] = 0
                    preds[i, cur_o] = ground

                above = (preds[i, :] > ground).data.cpu().numpy()
                tied = (preds[i, :] == ground).data.cpu().numpy()
                # The truth's own entry is removed from the tie count.
                ranks.append(np.sum(above) + ((np.sum(tied) - 1.0) / 2) + 1)
        finally:
            if case_file is not None:
                case_file.close()
        return total_loss, ranks

    def regularization_loss(self, reg_param):
        """Return ``reg_param`` times the mean squared relation and entity embedding values."""
        regularization_loss = torch.mean(self.rel_embeds.pow(2)) + torch.mean(self.entity_embeds.pow(2))
        return regularization_loss * reg_param

    def oracle_l1(self, reg_param):
        """Return ``reg_param`` times the L1 norm of all oracle-layer parameters."""
        return sum(torch.sum(torch.abs(param)) for param in self.oracle_layer.parameters()) * reg_param

    # contrastive
    def freeze_parameter(self):
        """Stop gradients for embeddings, prediction, frequency and contrastive layers (the oracle stays trainable)."""
        self.rel_embeds.requires_grad_(False)
        self.entity_embeds.requires_grad_(False)
        self.linear_pred_layer_s1.requires_grad_(False)
        self.linear_pred_layer_o1.requires_grad_(False)
        self.linear_pred_layer_s2.requires_grad_(False)
        self.linear_pred_layer_o2.requires_grad_(False)
        self.linear_frequency.requires_grad_(False)
        self.contrastive_hidden_layer.requires_grad_(False)
        self.contrastive_output_layer.requires_grad_(False)

    def contrastive_layer(self, x):
        """Project features for the contrastive loss (currently a single linear layer).

        Normalisation and the second (output) layer are disabled.
        """
        x = self.contrastive_hidden_layer(x)
        return x

    def calculate_spc_loss(self, actor1, r, rel_embeds, targets, frequency_hidden):
        """Supervised contrastive loss where positives are samples sharing the same history label.

        Returns ``0`` (a Python int) when the loss is NaN, e.g. when some sample
        has no other sample with the same label in the batch.
        """
        projections = self.contrastive_layer(
            torch.cat((self.entity_embeds[actor1], rel_embeds[r], frequency_hidden), dim=1))
        targets = torch.squeeze(targets)

        dot_product_tempered = torch.mm(projections, projections.T) / 1.0  # temperature 1.0
        # Subtract the row max for numerical stability; epsilon avoids log(0).
        exp_dot_tempered = (
            torch.exp(dot_product_tempered - torch.max(dot_product_tempered, dim=1, keepdim=True)[0]) + 1e-5
        )
        mask_similar_class = to_device(targets.unsqueeze(1).repeat(1, targets.shape[0]) == targets)
        mask_anchor_out = to_device(1 - torch.eye(exp_dot_tempered.shape[0]))
        mask_combined = mask_similar_class * mask_anchor_out
        cardinality_per_samples = torch.sum(mask_combined, dim=1)

        log_prob = -torch.log(exp_dot_tempered / (torch.sum(exp_dot_tempered * mask_anchor_out, dim=1, keepdim=True)))
        supervised_contrastive_loss_per_sample = torch.sum(log_prob * mask_combined, dim=1) / cardinality_per_samples
        supervised_contrastive_loss = torch.mean(supervised_contrastive_loss_per_sample)

        if torch.any(torch.isnan(supervised_contrastive_loss)):
            return 0
        return supervised_contrastive_loss