We have added some essential normalization and fixed bugs in TransR.
韩旭 committed Apr 15, 2018
1 parent 42baded commit 1f2ce7e
Showing 4 changed files with 47 additions and 17 deletions.
transD.py (5 changes: 3 additions & 2 deletions)
@@ -16,6 +16,7 @@ def __init__(self):
lib.setInPath("./data/FB15K/")
test_lib.setInPath("./data/FB15K/")
lib.setBernFlag(0)
+ self.learning_rate = 0.001
self.testFlag = False
self.loadFromData = False
self.L1_flag = True
@@ -29,7 +30,7 @@ def __init__(self):
class TransDModel(object):

def calc(self, e, t, r):
- return e + tf.reduce_sum(e * t, 1, keep_dims = True) * r
+ return tf.nn.l2_normalize(e + tf.reduce_sum(e * t, 1, keep_dims = True) * r, 1)

def __init__(self, config):

@@ -106,7 +107,7 @@ def main(_):
trainModel = TransDModel(config = config)

global_step = tf.Variable(0, name="global_step", trainable=False)
- optimizer = tf.train.GradientDescentOptimizer(0.001)
+ optimizer = tf.train.GradientDescentOptimizer(config.learning_rate)
grads_and_vars = optimizer.compute_gradients(trainModel.loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
saver = tf.train.Saver()
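Note: the calc() change above L2-normalizes the TransD projection e + (e·t)r along the embedding dimension. A minimal NumPy sketch of the same computation, outside TensorFlow (function and variable names are illustrative):

import numpy as np

def transd_project(e, t, r):
    # e: entity embedding, t: entity transfer vector, r: relation transfer vector,
    # each of shape [batch, dim]
    proj = e + np.sum(e * t, axis=1, keepdims=True) * r
    # row-wise L2 normalization, matching tf.nn.l2_normalize(..., 1)
    return proj / np.linalg.norm(proj, axis=1, keepdims=True)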
transE.py (4 changes: 3 additions & 1 deletion)
@@ -16,6 +16,7 @@ def __init__(self):
lib.setInPath("./data/FB15K/")
test_lib.setInPath("./data/FB15K/")
lib.setBernFlag(0)
+ self.learning_rate = 0.001
self.testFlag = False
self.loadFromData = False
self.L1_flag = True
@@ -66,6 +67,7 @@ def __init__(self, config):
with tf.name_scope("output"):
self.loss = tf.reduce_sum(tf.maximum(pos - neg + margin, 0))


def main(_):
config = Config()
if (config.testFlag):
@@ -88,7 +90,7 @@ def main(_):
trainModel = TransEModel(config = config)

global_step = tf.Variable(0, name="global_step", trainable=False)
- optimizer = tf.train.GradientDescentOptimizer(0.001)
+ optimizer = tf.train.GradientDescentOptimizer(config.learning_rate)
grads_and_vars = optimizer.compute_gradients(trainModel.loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
saver = tf.train.Saver()
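Note: the loss in the hunk above is the margin-based ranking objective shared by all four models, where pos and neg are the dissimilarity scores of a golden triple and its corrupted counterpart. A standalone NumPy sketch (names illustrative):

import numpy as np

def margin_ranking_loss(pos, neg, margin=1.0):
    # pos, neg: per-triple scores of shape [batch, 1]; lower is better
    return np.sum(np.maximum(pos - neg + margin, 0.0))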
transH.py (3 changes: 2 additions & 1 deletion)
@@ -16,6 +16,7 @@ def __init__(self):
lib.setInPath("./data/FB15K/")
test_lib.setInPath("./data/FB15K/")
lib.setBernFlag(0)
+ self.learning_rate = 0.001
self.testFlag = False
self.loadFromData = False
self.L1_flag = True
@@ -103,7 +104,7 @@ def main(_):
trainModel = TransHModel(config = config)

global_step = tf.Variable(0, name="global_step", trainable=False)
- optimizer = tf.train.GradientDescentOptimizer(0.001)
+ optimizer = tf.train.GradientDescentOptimizer(config.learning_rate)
grads_and_vars = optimizer.compute_gradients(trainModel.loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
saver = tf.train.Saver()
transR.py (52 changes: 39 additions & 13 deletions)
@@ -16,6 +16,7 @@ def __init__(self):
lib.setInPath("./data/FB15K/")
test_lib.setInPath("./data/FB15K/")
lib.setBernFlag(0)
+ self.learning_rate = 0.0001
self.testFlag = False
self.loadFromData = False
self.L1_flag = True
@@ -24,12 +25,23 @@ def __init__(self):
self.nbatches = 100
self.entity = 0
self.relation = 0
- self.trainTimes = 3000
+ self.trainTimes = 1000
self.margin = 1.0

+ '''
+ In the original paper, TransR is trained with pre-trained embeddings as the parameter initialization.
+ If you want to train with pre-trained embeddings, use the following settings instead of the defaults below.
+ Use np.savetxt() to store the pre-trained embeddings in the corresponding files and pass the file names here.
+ e.g.
+ self.ent_init = "ent_embeddings.txt"
+ self.rel_init = "rel_embeddings.txt"
+ where the entity and relation embeddings are stored in "ent_embeddings.txt" and "rel_embeddings.txt" respectively.
+ '''
+ self.rel_init = None
+ self.ent_init = None

class TransRModel(object):

- def __init__(self, config):
+ def __init__(self, config, ent_init = None, rel_init = None):

entity_total = config.entity
relation_total = config.relation
@@ -47,9 +59,22 @@ def __init__(self, config):
self.neg_r = tf.placeholder(tf.int32, [batch_size])

with tf.name_scope("embedding"):
- self.ent_embeddings = tf.get_variable(name = "ent_embedding", shape = [entity_total, sizeE], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
- self.rel_embeddings = tf.get_variable(name = "rel_embedding", shape = [relation_total, sizeR], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
- self.rel_matrix = tf.get_variable(name = "rel_matrix", shape = [relation_total, sizeE * sizeR], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
+ if ent_init != None:
+     self.ent_embeddings = tf.Variable(np.loadtxt(ent_init), name = "ent_embedding", dtype = np.float32)
+ else:
+     self.ent_embeddings = tf.get_variable(name = "ent_embedding", shape = [entity_total, sizeE], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
+ if rel_init != None:
+     self.rel_embeddings = tf.Variable(np.loadtxt(rel_init), name = "rel_embedding", dtype = np.float32)
+ else:
+     self.rel_embeddings = tf.get_variable(name = "rel_embedding", shape = [relation_total, sizeR], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
+
+ rel_matrix = np.zeros([relation_total, sizeR * sizeE], dtype = np.float32)
+ for i in range(relation_total):
+     for j in range(sizeR):
+         for k in range(sizeE):
+             if j == k:
+                 rel_matrix[i][j * sizeE + k] = 1.0
+ self.rel_matrix = tf.Variable(rel_matrix, name = "rel_matrix")

with tf.name_scope('lookup_embeddings'):
pos_h_e = tf.reshape(tf.nn.embedding_lookup(self.ent_embeddings, self.pos_h), [-1, sizeE, 1])
@@ -58,12 +83,13 @@ def __init__(self, config):
neg_h_e = tf.reshape(tf.nn.embedding_lookup(self.ent_embeddings, self.neg_h), [-1, sizeE, 1])
neg_t_e = tf.reshape(tf.nn.embedding_lookup(self.ent_embeddings, self.neg_t), [-1, sizeE, 1])
neg_r_e = tf.reshape(tf.nn.embedding_lookup(self.rel_embeddings, self.neg_r), [-1, sizeR])
- matrix = tf.reshape(tf.nn.embedding_lookup(self.rel_matrix, self.neg_r), [-1, sizeR, sizeE])
+ pos_matrix = tf.reshape(tf.nn.embedding_lookup(self.rel_matrix, self.pos_r), [-1, sizeR, sizeE])
+ neg_matrix = tf.reshape(tf.nn.embedding_lookup(self.rel_matrix, self.neg_r), [-1, sizeR, sizeE])

- pos_h_e = tf.reshape(tf.batch_matmul(matrix, pos_h_e), [-1, sizeR])
- pos_t_e = tf.reshape(tf.batch_matmul(matrix, pos_t_e), [-1, sizeR])
- neg_h_e = tf.reshape(tf.batch_matmul(matrix, neg_h_e), [-1, sizeR])
- neg_t_e = tf.reshape(tf.batch_matmul(matrix, neg_t_e), [-1, sizeR])
+ pos_h_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(pos_matrix, pos_h_e), [-1, sizeR]), 1)
+ pos_t_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(pos_matrix, pos_t_e), [-1, sizeR]), 1)
+ neg_h_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(neg_matrix, neg_h_e), [-1, sizeR]), 1)
+ neg_t_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(neg_matrix, neg_t_e), [-1, sizeR]), 1)

if config.L1_flag:
pos = tf.reduce_sum(abs(pos_h_e + pos_r_e - pos_t_e), 1, keep_dims = True)
@@ -96,10 +122,10 @@ def main(_):
with sess.as_default():
initializer = tf.contrib.layers.xavier_initializer(uniform = False)
with tf.variable_scope("model", reuse=None, initializer = initializer):
- trainModel = TransRModel(config = config)
+ trainModel = TransRModel(config = config, ent_init = config.ent_init, rel_init = config.rel_init)

global_step = tf.Variable(0, name="global_step", trainable=False)
- optimizer = tf.train.AdamOptimizer(0.001)
+ optimizer = tf.train.AdamOptimizer(config.learning_rate)
grads_and_vars = optimizer.compute_gradients(trainModel.loss)
train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
saver = tf.train.Saver()
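Note: the new rel_matrix initialization above fills each relation's flattened sizeR x sizeE projection matrix with an identity pattern instead of a Xavier draw, so every projection starts out as (a truncated) identity map. A vectorized NumPy equivalent of the triple loop, with illustrative sizes:

import numpy as np

relation_total, sizeR, sizeE = 1345, 100, 100  # illustrative FB15K-like sizes
identity = np.eye(sizeR, sizeE, dtype=np.float32).reshape(-1)  # flattened identity
rel_matrix = np.tile(identity, (relation_total, 1))  # one copy per relation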

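As the docstring added to transR.py suggests, pre-trained embeddings (e.g. from a TransE run) can be dumped with np.savetxt() and loaded back through config.ent_init and config.rel_init. A self-contained round-trip sketch (the array here is random and purely illustrative):

import numpy as np

ent_embeddings = np.random.randn(14951, 100).astype(np.float32)  # FB15K has 14,951 entities
np.savetxt("ent_embeddings.txt", ent_embeddings)
restored = np.loadtxt("ent_embeddings.txt").astype(np.float32)  # what TransRModel does internally
assert restored.shape == ent_embeddings.shape

With config.ent_init and config.rel_init pointing at such files, TransRModel wraps the loaded arrays in tf.Variable instead of using the Xavier initializer.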