Skip to content

Commit

Permalink
clean up FastSS & logging
Browse files Browse the repository at this point in the history
  • Loading branch information
piskvorky committed May 17, 2021
1 parent de2ec13 commit 6c4abc5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 46 deletions.
88 changes: 49 additions & 39 deletions gensim/similarities/fastss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,42 @@

import struct
import itertools
import logging

ENCODING = 'utf-8'
DELIMITER = b'\x00'
logger = logging.getLogger(__name__)


def editdist(s1, s2):
"""Return the Levenshtein distance between two strings.
def editdist(s1, s2, maximum=None):
"""Return the Levenshtein distance between two strings, or maximum+1 if the distance is larger than `maximum`."""
# TODO: rewrite in C; big impact on query performance!
if s1 == s2:
return 0

>>> editdist('aiu', 'aie')
1
"""
matrix = {}
if len(s1) > len(s2):
s1, s2 = s2, s1

for i in range(len(s1) + 1):
matrix[(i, 0)] = i
for j in range(len(s2) + 1):
matrix[(0, j)] = j
if maximum is None:
maximum = len(s1)

for i in range(1, len(s1) + 1):
for j in range(1, len(s2) + 1):
if s1[i - 1] == s2[j - 1]:
matrix[(i, j)] = matrix[(i - 1, j - 1)]
else:
matrix[(i, j)] = min(
matrix[(i - 1, j)],
matrix[(i, j - 1)],
matrix[(i - 1, j - 1)]
) + 1
if len(s2) - len(s1) > maximum:
return maximum + 1

return matrix[(i, j)]
distances = range(len(s1) + 1)
for i2, c2 in enumerate(s2):
distances_ = [i2 + 1]
all_bad = i2 > maximum
for i1, c1 in enumerate(s1):
if c1 == c2:
val = distances[i1]
else:
val = 1 + min((distances[i1], distances[i1 + 1], distances_[-1]))
distances_.append(val)
if all_bad and val <= maximum:
all_bad = False
if all_bad:
return maximum + 1
distances = distances_
return distances[-1]


def indexkeys(word, max_dist):
Expand All @@ -61,7 +67,7 @@ def indexkeys(word, max_dist):


def int2byte(i):
"""Encode a positive int (<= 256) into a 8-bit byte.
"""Encode an int <0, 255> into an 8-bit unsigned byte.
>>> int2byte(1)
b'\x01'
Expand All @@ -70,7 +76,7 @@ def int2byte(i):


def byte2int(b):
"""Decode a 8-bit byte into an integer.
"""Decode an 8-bit unsigned byte into an int.
>>> byte2int(b'\x01')
1
Expand All @@ -84,11 +90,7 @@ def set2bytes(s):
>>> set2byte({u'a', u'b', u'c'})
b'a\x00b\x00c'
"""
lis = []
for uword in sorted(s):
bword = uword.encode(ENCODING)
lis.append(bword)
return DELIMITER.join(lis)
return '\x00'.join(s).encode('utf8')


def bytes2set(b):
Expand All @@ -100,31 +102,39 @@ def bytes2set(b):
if not b:
return set()

lis = b.split(DELIMITER)
return set(bword.decode(ENCODING) for bword in lis)
return set(b.decode('utf8').split('\x00'))


class FastSS:
"""Open a FastSS index."""

def __init__(self, max_dist=2):
"""max_dist: the upper threshold of edit distance of works from the index."""
def __init__(self, words=None, max_dist=2):
"""
Create a FastSS index. The index will contain encoded variants of all
indexed words.
max_dist: maximum allowed edit distance of an indexed word to a query word. Keep
max_dist<=3 for sane performance.
"""
self.db = {}
self.max_dist = max_dist
if words:
for word in words:
self.add(word)

def __str__(self):
return "%s<max_dist=%s, db_size=%i>" % (self.__class__.__name__, self.max_dist, len(self.db), )

def __contains__(self, word):
bkey = word.encode(ENCODING)
bkey = word.encode('utf8')
if bkey in self.db:
return word in bytes2set(self.db[bkey])
return False

def add(self, word):
"""Add a string to the index."""
for key in indexkeys(word, self.max_dist):
bkey = key.encode(ENCODING)
bkey = key.encode('utf8')
wordset = {word}

if bkey in self.db:
Expand All @@ -146,13 +156,13 @@ def query(self, word, max_dist=None):
cands = set()

for key in indexkeys(word, max_dist):
bkey = key.encode(ENCODING)
bkey = key.encode('utf8')

if bkey in self.db:
cands.update(bytes2set(self.db[bkey]))

for cand in cands:
dist = editdist(word, cand)
dist = editdist(word, cand, max_dist)
if dist <= max_dist:
res[dist].append(cand)

Expand Down
7 changes: 2 additions & 5 deletions gensim/similarities/levenshtein.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,8 @@ def __init__(self, dictionary, alpha=1.8, beta=5.0, max_distance=1):
self.alpha = alpha
self.beta = beta
self.max_distance = max_distance

self.index = FastSS(self.max_distance)
for term in self.dictionary.values():
self.index.add(term)

logger.info("creating FastSS index from %s", dictionary)
self.index = FastSS(words=self.dictionary.values(), max_dist=max_distance)
super(LevenshteinSimilarityIndex, self).__init__()

def levsim(self, t1, t2, distance):
Expand Down
4 changes: 2 additions & 2 deletions gensim/similarities/termsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ def tfidf_sort_key(term_index):
return (-term_idf, term_index)

if tfidf is None:
logger.info("iterating over columns in dictionary order")
columns = sorted(dictionary.keys())
logger.info("iterating over %i columns in dictionary order", len(columns))
else:
assert max(tfidf.idfs) == matrix_order - 1
logger.info("iterating over columns in tf-idf order")
columns = sorted(tfidf.idfs.keys(), key=tfidf_sort_key)
logger.info("iterating over %i columns in tf-idf order", len(columns))

nonzero_counter_dtype = _shortest_uint_dtype(nonzero_limit)

Expand Down

0 comments on commit 6c4abc5

Please sign in to comment.