diff --git a/transD.py b/transD.py
index 9e955d2..e6063c3 100755
--- a/transD.py
+++ b/transD.py
@@ -16,6 +16,7 @@ def __init__(self):
 		lib.setInPath("./data/FB15K/")
 		test_lib.setInPath("./data/FB15K/")
 		lib.setBernFlag(0)
+		self.learning_rate = 0.001
 		self.testFlag = False
 		self.loadFromData = False
 		self.L1_flag = True
@@ -29,7 +30,7 @@ def __init__(self):
 class TransDModel(object):
 
 	def calc(self, e, t, r):
-		return e + tf.reduce_sum(e * t, 1, keep_dims = True) * r
+		return tf.nn.l2_normalize(e + tf.reduce_sum(e * t, 1, keep_dims = True) * r, 1)
 
 	def __init__(self, config):
@@ -106,7 +107,7 @@ def main(_):
 			trainModel = TransDModel(config = config)
 
 		global_step = tf.Variable(0, name="global_step", trainable=False)
-		optimizer = tf.train.GradientDescentOptimizer(0.001)
+		optimizer = tf.train.GradientDescentOptimizer(config.learning_rate)
 		grads_and_vars = optimizer.compute_gradients(trainModel.loss)
 		train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
 		saver = tf.train.Saver()
diff --git a/transE.py b/transE.py
index f4a1065..5739049 100755
--- a/transE.py
+++ b/transE.py
@@ -16,6 +16,7 @@ def __init__(self):
 		lib.setInPath("./data/FB15K/")
 		test_lib.setInPath("./data/FB15K/")
 		lib.setBernFlag(0)
+		self.learning_rate = 0.001
 		self.testFlag = False
 		self.loadFromData = False
 		self.L1_flag = True
@@ -66,6 +67,7 @@ def __init__(self, config):
 		with tf.name_scope("output"):
 			self.loss = tf.reduce_sum(tf.maximum(pos - neg + margin, 0))
 
+
 def main(_):
 	config = Config()
 	if (config.testFlag):
@@ -88,7 +90,7 @@ def main(_):
 			trainModel = TransEModel(config = config)
 
 		global_step = tf.Variable(0, name="global_step", trainable=False)
-		optimizer = tf.train.GradientDescentOptimizer(0.001)
+		optimizer = tf.train.GradientDescentOptimizer(config.learning_rate)
 		grads_and_vars = optimizer.compute_gradients(trainModel.loss)
 		train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
 		saver = tf.train.Saver()
diff --git a/transH.py b/transH.py
index 684a5da..9398431 100755
--- a/transH.py
+++ b/transH.py
@@ -16,6 +16,7 @@ def __init__(self):
 		lib.setInPath("./data/FB15K/")
 		test_lib.setInPath("./data/FB15K/")
 		lib.setBernFlag(0)
+		self.learning_rate = 0.001
 		self.testFlag = False
 		self.loadFromData = False
 		self.L1_flag = True
@@ -103,7 +104,7 @@ def main(_):
 			trainModel = TransHModel(config = config)
 
 		global_step = tf.Variable(0, name="global_step", trainable=False)
-		optimizer = tf.train.GradientDescentOptimizer(0.001)
+		optimizer = tf.train.GradientDescentOptimizer(config.learning_rate)
 		grads_and_vars = optimizer.compute_gradients(trainModel.loss)
 		train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
 		saver = tf.train.Saver()
diff --git a/transR.py b/transR.py
index 366bb9a..7495242 100755
--- a/transR.py
+++ b/transR.py
@@ -16,6 +16,7 @@ def __init__(self):
 		lib.setInPath("./data/FB15K/")
 		test_lib.setInPath("./data/FB15K/")
 		lib.setBernFlag(0)
+		self.learning_rate = 0.0001
 		self.testFlag = False
 		self.loadFromData = False
 		self.L1_flag = True
@@ -24,12 +25,23 @@ def __init__(self):
 		self.nbatches = 100
 		self.entity = 0
 		self.relation = 0
-		self.trainTimes = 3000
+		self.trainTimes = 1000
 		self.margin = 1.0
-
+		'''
+		In the original paper, TransR is trained with pre-trained embeddings as its parameter initialization.
+		If you want to train with pre-trained embeddings, use the following settings instead of the default version.
+		You need to use np.savetxt() to store the pre-trained embeddings in the corresponding files and pass the file names, e.g.
+		self.ent_init = "ent_embeddings.txt"
+		self.rel_init = "rel_embeddings.txt"
+		where the entity and relation embeddings are stored in "ent_embeddings.txt" and "rel_embeddings.txt" respectively.
+		'''
+		self.rel_init = None
+		self.ent_init = None
 
 class TransRModel(object):
 
-	def __init__(self, config):
+	def __init__(self, config, ent_init = None, rel_init = None):
 
 		entity_total = config.entity
 		relation_total = config.relation
@@ -47,9 +59,22 @@ def __init__(self, config):
 			self.neg_r = tf.placeholder(tf.int32, [batch_size])
 
 		with tf.name_scope("embedding"):
-			self.ent_embeddings = tf.get_variable(name = "ent_embedding", shape = [entity_total, sizeE], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
-			self.rel_embeddings = tf.get_variable(name = "rel_embedding", shape = [relation_total, sizeR], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
-			self.rel_matrix = tf.get_variable(name = "rel_matrix", shape = [relation_total, sizeE * sizeR], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
+			if ent_init != None:
+				self.ent_embeddings = tf.Variable(np.loadtxt(ent_init), name = "ent_embedding", dtype = np.float32)
+			else:
+				self.ent_embeddings = tf.get_variable(name = "ent_embedding", shape = [entity_total, sizeE], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
+			if rel_init != None:
+				self.rel_embeddings = tf.Variable(np.loadtxt(rel_init), name = "rel_embedding", dtype = np.float32)
+			else:
+				self.rel_embeddings = tf.get_variable(name = "rel_embedding", shape = [relation_total, sizeR], initializer = tf.contrib.layers.xavier_initializer(uniform = False))
+
+			rel_matrix = np.zeros([relation_total, sizeR * sizeE], dtype = np.float32)
+			for i in range(relation_total):
+				for j in range(sizeR):
+					for k in range(sizeE):
+						if j == k:
+							rel_matrix[i][j * sizeE + k] = 1.0
+			self.rel_matrix = tf.Variable(rel_matrix, name = "rel_matrix")
 
 		with tf.name_scope('lookup_embeddings'):
 			pos_h_e = tf.reshape(tf.nn.embedding_lookup(self.ent_embeddings, self.pos_h), [-1, sizeE, 1])
@@ -58,12 +83,13 @@ def __init__(self, config):
 			neg_h_e = tf.reshape(tf.nn.embedding_lookup(self.ent_embeddings, self.neg_h), [-1, sizeE, 1])
 			neg_t_e = tf.reshape(tf.nn.embedding_lookup(self.ent_embeddings, self.neg_t), [-1, sizeE, 1])
 			neg_r_e = tf.reshape(tf.nn.embedding_lookup(self.rel_embeddings, self.neg_r), [-1, sizeR])
-			matrix = tf.reshape(tf.nn.embedding_lookup(self.rel_matrix, self.neg_r), [-1, sizeR, sizeE])
+			pos_matrix = tf.reshape(tf.nn.embedding_lookup(self.rel_matrix, self.pos_r), [-1, sizeR, sizeE])
+			neg_matrix = tf.reshape(tf.nn.embedding_lookup(self.rel_matrix, self.neg_r), [-1, sizeR, sizeE])
 
-			pos_h_e = tf.reshape(tf.batch_matmul(matrix, pos_h_e), [-1, sizeR])
-			pos_t_e = tf.reshape(tf.batch_matmul(matrix, pos_t_e), [-1, sizeR])
-			neg_h_e = tf.reshape(tf.batch_matmul(matrix, neg_h_e), [-1, sizeR])
-			neg_t_e = tf.reshape(tf.batch_matmul(matrix, neg_t_e), [-1, sizeR])
+			pos_h_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(pos_matrix, pos_h_e), [-1, sizeR]), 1)
+			pos_t_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(pos_matrix, pos_t_e), [-1, sizeR]), 1)
+			neg_h_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(neg_matrix, neg_h_e), [-1, sizeR]), 1)
+			neg_t_e = tf.nn.l2_normalize(tf.reshape(tf.matmul(neg_matrix, neg_t_e), [-1, sizeR]), 1)
 
 		if config.L1_flag:
 			pos = tf.reduce_sum(abs(pos_h_e + pos_r_e - pos_t_e), 1, keep_dims = True)
@@ -96,10 +122,10 @@ def main(_):
 	with sess.as_default():
 		initializer = tf.contrib.layers.xavier_initializer(uniform = False)
 		with tf.variable_scope("model", reuse=None, initializer = initializer):
-			trainModel = TransRModel(config = config)
+			trainModel = TransRModel(config = config, ent_init = config.ent_init, rel_init = config.rel_init)
 
 		global_step = tf.Variable(0, name="global_step", trainable=False)
-		optimizer = tf.train.AdamOptimizer(0.001)
+		optimizer = tf.train.AdamOptimizer(config.learning_rate)
 		grads_and_vars = optimizer.compute_gradients(trainModel.loss)
 		train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
 		saver = tf.train.Saver()
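
Note on producing the pre-trained embedding files: config.ent_init and config.rel_init expect plain-text matrices that np.loadtxt() can read back, with shapes matching [entity_total, sizeE] and [relation_total, sizeR]. A minimal sketch of how the files named in the transR.py comment might be written, assuming a finished TransE run in which sess and trainModel are still in scope and the TransE model exposes its embedding tables as trainModel.ent_embeddings / trainModel.rel_embeddings (as TransRModel does above):

	import numpy as np

	# Evaluate the trained embedding variables and dump them as
	# whitespace-separated text, one row per entity / relation,
	# in the format np.loadtxt() reloads for TransR initialization.
	ent_emb = sess.run(trainModel.ent_embeddings)
	rel_emb = sess.run(trainModel.rel_embeddings)
	np.savetxt("ent_embeddings.txt", ent_emb)
	np.savetxt("rel_embeddings.txt", rel_emb)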
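
Design note: the nested loop in the transR.py embedding hunk initializes every relation's sizeR x sizeE projection matrix to the identity (1.0 wherever the row and column indices match) rather than Xavier noise, matching the identity initialization the TransR paper describes for its projection matrices. For reference, a vectorized equivalent of that loop:

	import numpy as np

	# One flattened sizeR x sizeE identity block, tiled once per relation;
	# this yields the same rel_matrix as the triple loop.
	rel_matrix = np.tile(np.eye(sizeR, sizeE, dtype = np.float32).reshape(1, -1), (relation_total, 1))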