diff --git a/supar/structs/fn.py b/supar/structs/fn.py index 1a0af2ad..73c03f42 100644 --- a/supar/structs/fn.py +++ b/supar/structs/fn.py @@ -4,10 +4,11 @@ from typing import Iterable, Tuple, Union import torch -from supar.utils.common import INF, MIN -from supar.utils.fn import pad from torch.autograd import Function +from supar.utils.common import MIN +from supar.utils.fn import pad + def tarjan(sequence: Iterable[int]) -> Iterable[int]: r""" @@ -216,30 +217,24 @@ def mst(scores: torch.Tensor, mask: torch.BoolTensor, multiroot: bool = False) - return pad(preds, total_length=seq_len).to(mask.device) -def levenshtein(x: Iterable, y: Iterable, align: bool = False) -> int: +def levenshtein(x: Iterable, y: Iterable, costs: Tuple = (1, 1, 1), align: bool = False) -> int: """ - Calculates the Levenshtein edit-distance between two sequences. - The edit distance is the number of characters that need to be - substituted, inserted, or deleted, to transform `x` into `y`. - - For example, transforming "rain" to "shine" requires three steps, - consisting of two substitutions and one insertion: - "rain" -> "sain" -> "shin" -> "shine". - These operations could have been done in other orders, but at least three steps are needed. - - Allows specifying the cost of substitution edits (e.g., "a" -> "b"), - because sometimes it makes sense to assign greater penalties to substitutions. + Calculates the Levenshtein edit-distance between two sequences, + which refers to the total number of characters that must be + substituted, deleted or inserted to transform `x` into `y`. The code is revised from `nltk`_ and `wiki`_'s implementations. Args: x/y (Iterable): The sequences to be analysed. + costs (Tuple): + Edit costs for substitution, deletion or insertion. Default: `(1, 1, 1)`. align (bool): Whether to return the alignments based on the minimum Levenshtein edit-distance. Default: ``False``. 
Examples: - >>> from supar.structs.utils.fn import levenshtein + >>> from supar.structs.fn import levenshtein >>> levenshtein('intention', 'execution', align=True) (5, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7), (8, 8), (9, 9)]) @@ -252,32 +247,34 @@ def levenshtein(x: Iterable, y: Iterable, align: bool = False) -> int: # set up a 2-D array len1, len2 = len(x), len(y) lev = [list(range(len2 + 1))] + [[i] + [0] * len2 for i in range(1, len1 + 1)] + alg = [[2] * (len2 + 1)] + [[1] + [-1] * len2 for _ in range(1, len1 + 1)] if align else None # iterate over the array # i and j start from 1 and not 0 to stay close to the wikipedia pseudo-code # see https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance for i in range(1, len1 + 1): for j in range(1, len2 + 1): - # substitution - s = lev[i - 1][j - 1] + (x[i - 1] != y[j - 1]) + # substitution / keep + s = lev[i - 1][j - 1] + (costs[0] if x[i - 1] != y[j - 1] else 0) # deletion - a = lev[i - 1][j] + 1 + a = lev[i - 1][j] + costs[1] # insertion - b = lev[i][j - 1] + 1 + b = lev[i][j - 1] + costs[2] - lev[i][j] = min(s, a, b) + edit, lev[i][j] = min(enumerate((s, a, b)), key=operator.itemgetter(1)) + if align: + alg[i][j] = edit distance = lev[-1][-1] if align: i, j = len1, len2 alignments = [(i, j)] while (i, j) != (0, 0): - directions = [ + grids = [ (i - 1, j - 1), # substitution (i - 1, j), # deletion (i, j - 1), # insertion ] - direction_costs = ((lev[i][j] if (i >= 0 and j >= 0) else INF, (i, j)) for i, j in directions) - _, (i, j) = min(direction_costs, key=operator.itemgetter(0)) + i, j = grids[alg[i][j]] alignments.append((i, j)) alignments = list(reversed(alignments)) return (distance, alignments) if align else distance