[WIP] Add new restrict_vocab functionality, most_similar_among #1229
@@ -274,6 +274,13 @@ def word_vec(self, word, use_norm=False):
        else:
            raise KeyError("word '%s' not in vocabulary" % word)

    def get_ordered_keys(self):
        """
        Return the keys in the current KeyedVectors instance as a list.
        If the model is not trained yet, an empty list is returned.
        """
        return self.index2word

    def most_similar(self, positive=[], negative=[], topn=10, restrict_vocab=None, indexer=None):
        """
        Find the top-N most similar words. Positive words contribute positively towards the
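For orientation, a minimal usage sketch of the get_ordered_keys accessor added above, assuming a model built from this branch (the vectors file is hypothetical):

    from gensim.models import KeyedVectors

    # Hypothetical file; any word2vec-format vectors would do.
    kv = KeyedVectors.load_word2vec_format('vectors.bin', binary=True)
    words = kv.get_ordered_keys()   # same list as kv.index2word
    print(words[:5])                # most frequent words come first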

@@ -338,6 +345,129 @@ def most_similar(self, positive=[], negative=[], topn=10, restrict_vocab=None, indexer=None):
        result = [(self.index2word[sim], float(dists[sim])) for sim in best if sim not in all_words]
        return result[:topn]

    def most_similar_among(self, positive=[], negative=[],
                           topn=10, words_list=None, indexer=None,
                           suppress_warnings=False):
        """
        Find the top-N words in words_list most similar to the given words.

        Positive words contribute positively towards the similarity,
        negative words negatively.

        Refer to the docstring of the most_similar function for details.

        If topn is False, most_similar_among returns the vector of
        similarity scores for all words in the model's vocabulary,
        restricted to the supplied words_list.

        'words_list' should be a list/set of words. The returned word
        similarities will only contain scores for the words that are in
        words_list (and in the trained vocabulary).

        If some words in words_list are not in the vocabulary, a warning
        is issued to the user.

        Warnings can be suppressed by setting the suppress_warnings flag.

        Example::

          >>> trained_model.most_similar_among(positive=['man'], topn=1, words_list=['woman', 'random_word'])
          [('woman', 0.75882536)]

        """

        if isinstance(words_list, int):
            raise ValueError("words_list must be a set/list of words. "
                             "Maybe you wanted the most_similar function.")
        elif isinstance(words_list, (list, set)):
            pass
        else:  # triggered when words_list is left as None (the default)
            raise ValueError("words_list must be a set/list of words. "
                             "Maybe you wanted the most_similar function. "
                             "Please read the docstring.")

        if type(topn) is not int:
            if topn is False:
                pass
            else:
                if suppress_warnings is False:
                    logger.warning("topn needs to either be a number or False. "
                                   "Please read the docstring. "
                                   "Displaying all similarities!")
            topn = len(self.index2word)

[Review comment] This must be an exception, not a warning. Incorrect input can't be suppressed.
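A sketch of the check the reviewer is asking for, raising instead of warning (an interpretation, not code from this PR):

    # Reviewer's suggestion, sketched: invalid topn is an error, not a warning.
    if topn is not False and not isinstance(topn, int):
        raise ValueError("topn must be an int, or False to rank the whole words_list")
    if topn is False:
        topn = len(self.index2word)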

        self.init_sims()

        if isinstance(positive, string_types) and not negative:
            # allow calls like most_similar('dog'),
            # as a shorthand for most_similar(['dog'])
            positive = [positive]

        # add weights for each word, if not already present;
        # default to 1.0 for positive and -1.0 for negative words
        positive = [
            (word, 1.0) if isinstance(word, string_types + (ndarray,)) else word
            for word in positive
        ]
        negative = [
            (word, -1.0) if isinstance(word, string_types + (ndarray,)) else word
            for word in negative
        ]

        # compute the weighted average of all words
        all_words, mean = set(), []
        for word, weight in positive + negative:
            if isinstance(word, ndarray):
                mean.append(weight * word)
            else:
                mean.append(weight * self.word_vec(word, use_norm=True))
                if word in self.vocab:
                    all_words.add(self.vocab[word].index)
        if not mean:
            raise ValueError("cannot compute similarity with no input")
        mean = matutils.unitvec(array(mean).mean(axis=0)).astype(REAL)
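To make the query-vector step concrete, a toy illustration with made-up vectors (matutils.unitvec is equivalent to dividing by the L2 norm):

    import numpy as np

    # Made-up unit vectors standing in for normalized word embeddings.
    man = np.array([0.6, 0.8, 0.0], dtype=np.float32)
    woman = np.array([0.0, 0.6, 0.8], dtype=np.float32)

    # Weighted average (weights 1.0 here), re-normalized to unit length.
    mean = np.mean([1.0 * man, 1.0 * woman], axis=0)
    query = mean / np.linalg.norm(mean)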

        if indexer is not None:
            return indexer.most_similar(mean, topn)

        words_list = set(words_list)
        vocabulary_words = set(self.index2word)

        words_to_use = vocabulary_words.intersection(words_list)

        if not words_to_use:
            raise ValueError("None of the words in words_list "
                             "exist in the current vocabulary")

        if suppress_warnings is False:
            missing_words = words_list.difference(vocabulary_words)
            if missing_words:
                logger.warning("The following words are not in the "
                               "trained vocabulary: %s", str(missing_words))
                logger.info("This warning is expensive to calculate, "
                            "especially for a large words_list. "
                            "If you would rather not remove the missing words "
                            "from words_list, please set the "
                            "suppress_warnings flag.")

[Review comment] A better message is: "Please intersect with the vocabulary, words_to_use = vocabulary_words.intersection(words_list), prior to calling most_similar_among."

        words_list_indices = [self.vocab[word].index for word in words_to_use]
        # limited = self.syn0norm[words_list_indices]
        # Storing 'limited' might add a huge memory overhead, so we avoid doing that

[Review comment] Please memory-profile this code to provide a foundation for this statement.
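One cheap way to put a number on that overhead (a sketch; model stands for the KeyedVectors instance): NumPy fancy indexing copies the selected rows, so the extra memory is exactly the size of the resulting array.

    limited = model.syn0norm[words_list_indices]  # fancy indexing -> row copy
    print("storing 'limited' costs ~%.1f MB" % (limited.nbytes / 1024.0 ** 2))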

        dists = dot(self.syn0norm[words_list_indices], mean)
        result = []

        best = matutils.argsort(dists, topn=topn + len(all_words), reverse=True)
        # ignore (don't return) words from the input
        for sim in best:
            index_to_return = words_list_indices[sim]
            if index_to_return not in all_words:
                result.append((self.index2word[index_to_return], float(dists[sim])))

        return result[:topn]
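For orientation, a hedged end-to-end sketch of how the new method would be called (the model and candidate words are hypothetical):

    candidates = ['woman', 'queen', 'throne', 'word_not_in_vocab']

    # Top-2 candidates most similar to 'king' - 'man'; warnings about
    # out-of-vocabulary candidates are silenced.
    trained_model.most_similar_among(positive=['king'], negative=['man'],
                                     topn=2, words_list=candidates,
                                     suppress_warnings=True)

    # topn=False ranks every in-vocabulary candidate.
    trained_model.most_similar_among(positive=['king'], topn=False,
                                     words_list=candidates)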

    def wmdistance(self, document1, document2):
        """
        Compute the Word Mover's Distance between two documents. When using this

[Review comment] Please use a single check and a single raise ValueError. The most_similar function doesn't take a list of ints, so it should not be mentioned here.
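A sketch of the consolidated validation the reviewer is asking for (an interpretation, not the reviewer's exact code):

    # One check, one exception.
    if not isinstance(words_list, (list, set)):
        raise ValueError("words_list must be a list or set of words, "
                         "got %r" % type(words_list))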