Skip to content

Commit

Permalink
Return edit ops as well
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed May 22, 2023
1 parent 5f764d9 commit faa7a56
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions supar/structs/fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def mst(scores: torch.Tensor, mask: torch.BoolTensor, multiroot: bool = False) -

def levenshtein(x: Iterable, y: Iterable, costs: Tuple = (1, 1, 1), align: bool = False) -> int:
"""
Calculates the Levenshtein edit-distance between two sequencess,
which refers to the total number of characters that must be
Calculates the Levenshtein edit-distance between two sequences,
which refers to the total number of tokens that must be
substituted, deleted or inserted to transform `x` into `y`.
The code is revised from `nltk`_ and `wiki`_'s implementations.
Expand All @@ -231,53 +231,61 @@ def levenshtein(x: Iterable, y: Iterable, costs: Tuple = (1, 1, 1), align: bool
costs (Tuple):
Edit costs for substitution, deletion or insertion. Default: `(1, 1, 1)`.
align (bool):
Whether to return the alignments based on the minimum Levenshtein edit-distance. Default: ``False``.
Whether to return the alignments based on the minimum Levenshtein edit-distance.
If ``True``, returns a list of tuples representing the alignment position as well as the edit operation.
The edit operations are encoded as `KEEP` (0), `SUBSTITUTION` (1), `DELETION` (2) and `INSERTION` (3).
For example, `(i, j, 0)` means the `i`th token of `x` is kept and aligned to the `j`th position of `y`, and so forth.
Default: ``False``.
Examples:
>>> from supar.structs.fn import levenshtein
>>> levenshtein('intention', 'execution', align=True)
(5, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)])
>>> levenshtein('intention', 'execution')
5
>>> levenshtein('rain', 'brainy', align=True)
(2, [(0, 1, 3), (1, 2, 0), (2, 3, 0), (3, 4, 0), (4, 5, 0), (4, 6, 3)])
.. _nltk:
https://github.com/nltk/nltk/blob/develop/nltk/metrics/distance.py
https://github.com/nltk/nltk/blob/develop/nltk/metrics/distance.py
.. _wiki:
https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance
"""

# set up a 2-D array
len1, len2 = len(x), len(y)
lev = [list(range(len2 + 1))] + [[i] + [0] * len2 for i in range(1, len1 + 1)]
alg = [[2] * (len2 + 1)] + [[1] + [-1] * len2 for _ in range(1, len1 + 1)] if align else None
dists = [list(range(len2 + 1))] + [[i] + [0] * len2 for i in range(1, len1 + 1)]
edits = [[0] + [3] * len2] + [[2] + [-1] * len2 for _ in range(1, len1 + 1)] if align else None

# iterate over the array
# i and j start from 1 and not 0 to stay close to the wikipedia pseudo-code
# see https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance
for i in range(1, len1 + 1):
for j in range(1, len2 + 1):
# substitution / keep
s = lev[i - 1][j - 1] + (costs[0] if x[i - 1] != y[j - 1] else 0)
# keep / substitution
s = dists[i - 1][j - 1] + (costs[0] if x[i - 1] != y[j - 1] else 0)
# deletion
a = lev[i - 1][j] + costs[1]
a = dists[i - 1][j] + costs[1]
# insertion
b = lev[i][j - 1] + costs[2]
b = dists[i][j - 1] + costs[2]

edit, lev[i][j] = min(enumerate((s, a, b)), key=operator.itemgetter(1))
edit, dists[i][j] = min(enumerate((s, a, b), 1), key=operator.itemgetter(1))
if align:
alg[i][j] = edit
distance = lev[-1][-1]
edits[i][j] = edit if edit != 1 else int(x[i - 1] != y[j - 1])

dist = dists[-1][-1]
if align:
i, j = len1, len2
alignments = [(i, j)]
alignments = []
while (i, j) != (0, 0):
alignments.append((i, j, edits[i][j]))
grids = [
(i - 1, j - 1), # keep
(i - 1, j - 1), # substitution
(i - 1, j), # deletion
(i, j - 1), # insertion
]
i, j = grids[alg[i][j]]
alignments.append((i, j))
i, j = grids[edits[i][j]]
alignments = list(reversed(alignments))
return (distance, alignments) if align else distance
return (dist, alignments) if align else dist


class Logsumexp(Function):
Expand Down

0 comments on commit faa7a56

Please sign in to comment.