Skip to content

Commit

Permalink
Fix backtrace errors in levenshtein
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed May 22, 2023
1 parent ae5faee commit 5f764d9
Showing 1 changed file with 20 additions and 23 deletions.
43 changes: 20 additions & 23 deletions supar/structs/fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Iterable, Tuple, Union

import torch
from supar.utils.common import INF, MIN
from supar.utils.fn import pad
from torch.autograd import Function

from supar.utils.common import MIN
from supar.utils.fn import pad


def tarjan(sequence: Iterable[int]) -> Iterable[int]:
r"""
Expand Down Expand Up @@ -216,30 +217,24 @@ def mst(scores: torch.Tensor, mask: torch.BoolTensor, multiroot: bool = False) -
return pad(preds, total_length=seq_len).to(mask.device)


def levenshtein(x: Iterable, y: Iterable, align: bool = False) -> int:
def levenshtein(x: Iterable, y: Iterable, costs: Tuple = (1, 1, 1), align: bool = False) -> int:
"""
Calculates the Levenshtein edit-distance between two sequences.
The edit distance is the number of characters that need to be
substituted, inserted, or deleted, to transform `x` into `y`.
For example, transforming "rain" to "shine" requires three steps,
consisting of two substitutions and one insertion:
"rain" -> "sain" -> "shin" -> "shine".
These operations could have been done in other orders, but at least three steps are needed.
Allows specifying the cost of substitution edits (e.g., "a" -> "b"),
because sometimes it makes sense to assign greater penalties to substitutions.
Calculates the Levenshtein edit-distance between two sequences,
which refers to the total number of characters that must be
substituted, deleted or inserted to transform `x` into `y`.
The code is revised from `nltk`_ and `wiki`_'s implementations.
Args:
x/y (Iterable):
The sequences to be analysed.
costs (Tuple):
Edit costs for substitution, deletion or insertion. Default: `(1, 1, 1)`.
align (bool):
Whether to return the alignments based on the minimum Levenshtein edit-distance. Default: ``False``.
Examples:
>>> from supar.structs.utils.fn import levenshtein
>>> from supar.structs.fn import levenshtein
>>> levenshtein('intention', 'execution', align=True)
(5, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)])
Expand All @@ -252,32 +247,34 @@ def levenshtein(x: Iterable, y: Iterable, align: bool = False) -> int:
# set up a 2-D array
len1, len2 = len(x), len(y)
lev = [list(range(len2 + 1))] + [[i] + [0] * len2 for i in range(1, len1 + 1)]
alg = [[2] * (len2 + 1)] + [[1] + [-1] * len2 for _ in range(1, len1 + 1)] if align else None

# iterate over the array
# i and j start from 1 and not 0 to stay close to the wikipedia pseudo-code
# see https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance
for i in range(1, len1 + 1):
for j in range(1, len2 + 1):
# substitution
s = lev[i - 1][j - 1] + (x[i - 1] != y[j - 1])
# substitution / keep
s = lev[i - 1][j - 1] + (costs[0] if x[i - 1] != y[j - 1] else 0)
# deletion
a = lev[i - 1][j] + 1
a = lev[i - 1][j] + costs[1]
# insertion
b = lev[i][j - 1] + 1
b = lev[i][j - 1] + costs[2]

lev[i][j] = min(s, a, b)
edit, lev[i][j] = min(enumerate((s, a, b)), key=operator.itemgetter(1))
if align:
alg[i][j] = edit
distance = lev[-1][-1]
if align:
i, j = len1, len2
alignments = [(i, j)]
while (i, j) != (0, 0):
directions = [
grids = [
(i - 1, j - 1), # substitution
(i - 1, j), # deletion
(i, j - 1), # insertion
]
direction_costs = ((lev[i][j] if (i >= 0 and j >= 0) else INF, (i, j)) for i, j in directions)
_, (i, j) = min(direction_costs, key=operator.itemgetter(0))
i, j = grids[alg[i][j]]
alignments.append((i, j))
alignments = list(reversed(alignments))
return (distance, alignments) if align else distance
Expand Down

0 comments on commit 5f764d9

Please sign in to comment.