* compositional models working for DistMult and ModelE
* supported composition functions: 'LSTM', 'GRU', 'RNN', 'BoW', 'BiLSTM', 'BiGRU', 'BiRNN'
dirkweissenborn committed Feb 10, 2016
1 parent 2e305aa commit 4cc1faf
Showing 6 changed files with 559 additions and 164 deletions.
23 changes: 23 additions & 0 deletions data/load_fb15k237.py
@@ -46,6 +46,29 @@ def _load_triples(fn, kb, typ="train"):
    return triples


def split_relations(rel):
    if "[XXX]" in rel:
        dep_path_arr = []
        c = 0
        for i in xrange(len(rel)-2):
            if rel[i:i+3] == ":<-":
                if c > 0:  # do not keep [XXX]
                    dep_path_arr.append(rel[c:i])
                dep_path_arr.append(":<-")
                c = i+3
            elif rel[i:i+2] == ":<":
                if c > 0:
                    dep_path_arr.append(rel[c:i])
                dep_path_arr.append(":<")
                c = i+2
            elif rel[i:i+2] == ">:":
                if c > 0:
                    dep_path_arr.append(rel[c:i])
                c = i+2
        return dep_path_arr
    else:
        return rel.split("/")


def _load_dep_paths(fn, kb, typ="train"):
    with open(fn) as f:
        for l in f:
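For orientation, a small usage sketch of the new helper (not part of the diff): the plain-relation case follows directly from rel.split("/"), while the dependency-path string is a hypothetical example whose format is only inferred from the "[XXX]" placeholder and the separators handled above.

from data.load_fb15k237 import split_relations

# Plain KB relation: falls through to rel.split("/"); note the leading
# empty string produced by the initial slash.
kb_tokens = split_relations("/people/person/nationality")
# -> ['', 'people', 'person', 'nationality']

# Hypothetical textual dependency path with "[XXX]" argument placeholders:
# segments between the ":<-" / ":<" / ">:" separators are kept, the "[XXX]"
# placeholders themselves are dropped, and the left-pointing separators are
# emitted as tokens of their own while ">:" is not.
text_tokens = split_relations("[XXX]:<-nsubj:<-visited>:dobj>:[XXX]")
# -> [':<-', 'nsubj', ':<-', 'visited', 'dobj']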
29 changes: 0 additions & 29 deletions kb.py
@@ -78,8 +78,6 @@ def __init__(self):
        self.__maps = list()
        # caches number of dimensions since len(...) is slow
        self.__dims = list()
        # global mapping of symbols to indices independent from dimension
        self.__global_ids = {}
        # lists compatible arguments for each arg position foreach relation
        self.__compatible_args = dict()

@@ -116,21 +114,11 @@ def __add_to_facts(self, fact):
        self.__facts[arity].append(fact)
        self.__all_facts.add(fact)

    def __add_word(self, word):
        if word not in self.__global_ids:
            self.__global_ids[word] = len(self.__global_ids)

    def __add_to_symbols(self, key, dim):
        if len(self.__symbols) <= dim:
            self.__symbols.append(set())
        self.__symbols[dim].add(key)

        words = key
        if isinstance(words, basestring):
            words = [key]
        for word in words:
            self.__add_word(word)

    def __add_to_vocab(self, key, dim):
        if len(self.__vocab) <= dim:
            self.__vocab.append(list())
@@ -241,23 +229,6 @@ def get_ids(self, *keys):
            ids.append(self.get_id(keys[dim], dim))
        return ids

    def get_global_id(self, symbol):
        return self.__global_ids[symbol]

    def get_global_ids(self, *symbols):
        ids = list()
        for symbol in symbols:
            # fixme
            if not isinstance(symbol, basestring):
                for s in symbol:
                    ids.append(self.get_global_id(s))
            else:
                ids.append(self.get_global_id(symbol))
        return ids

    def num_global_ids(self):
        return len(self.__global_ids)

    def get_key(self, id, dim):
        return self.__vocab[dim][id]

58 changes: 58 additions & 0 deletions model/__init__.py
@@ -0,0 +1,58 @@
import tensorflow as tf
from models import *
from comp_models import *
from data.load_fb15k237 import split_relations

def default_init():
    return tf.random_normal_initializer(0.0, 0.1)


def create_model(kb, size, batch_size, is_train=True, num_neg=200, learning_rate=1e-2,
                 l2_lambda=0.0, is_batch_training=False, type="DistMult",
                 observed_sets=["train_text"], composition=None, num_buckets=10):
    '''
    Factory Method for all models
    :param type: any or combination of "ModelF", "DistMult", "ModelE", "ModelO", "ModelN"
    :param composition: "Tanh", "LSTM", "GRU", "BiTanh", "BiLSTM", "BiGRU", "BoW" or None
    :return: Model(s) of type "type"
    '''
    if not isinstance(type, list):
        if composition == "Tanh":
            composition = TanhRNNCompModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        elif composition == "LSTM":
            composition = LSTMCompModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        elif composition == "GRU":
            composition = GRUCompModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        elif composition == "BiTanh":
            composition = BiTanhRNNCompModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        elif composition == "BiLSTM":
            composition = BiLSTMCompModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        elif composition == "BiGRU":
            composition = BiGRUCompModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        elif composition == "BoW":
            composition = CompositionModel(kb, size, num_buckets, split_relations, batch_size/(num_neg+1), learning_rate)
        else:
            composition = None

        if type == "ModelF":
            return ModelF(kb, size, batch_size, is_train, num_neg, learning_rate, l2_lambda, is_batch_training)
        elif type == "DistMult":
            if composition:
                return CompDistMult(kb, size, batch_size, composition, is_train, num_neg, learning_rate)
            else:
                return DistMult(kb, size, batch_size, is_train, num_neg, learning_rate, l2_lambda, is_batch_training)
        elif type == "ModelE":
            if composition:
                return CompModelE(kb, size, batch_size, composition, is_train, num_neg, learning_rate)
            else:
                return ModelE(kb, size, batch_size, is_train, num_neg, learning_rate, l2_lambda, is_batch_training)
        elif type == "ModelO":
            return ModelO(kb, size, batch_size, is_train, num_neg, learning_rate, l2_lambda, is_batch_training, observed_sets)
        elif type == "ModelN":
            return ModelN(kb, size, batch_size, is_train, num_neg, learning_rate, l2_lambda, is_batch_training, observed_sets)
        else:
            raise NameError("There is no model with type %s. "
                            "Possible values are 'ModelF', 'DistMult', 'ModelE', 'ModelO', 'ModelN'." % type)
    else:
        return CombinedModel(type, kb, size, batch_size, is_train, num_neg,
                             learning_rate, l2_lambda, is_batch_training, composition)
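A brief usage sketch of the factory (again not part of the commit): the knowledge-base object and all hyperparameter values below are placeholders, and only the argument names visible in the signature above are taken from the code.

from model import create_model

# kb: a populated knowledge base built by the loaders in data/
# (construction omitted here); all numeric values are placeholders.
model = create_model(kb, size=10, batch_size=8192, is_train=True, num_neg=200,
                     learning_rate=1e-2, type="DistMult", composition="LSTM")

# Passing a list of model names instead of a single string routes to
# CombinedModel; the composition argument is then forwarded as given.
combined = create_model(kb, 10, 8192, type=["DistMult", "ModelE"],
                        composition="GRU")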
