-
Notifications
You must be signed in to change notification settings - Fork 10
/
edit_utils_en.py
129 lines (104 loc) · 4.23 KB
/
edit_utils_en.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# @ hwang258@jh.edu
import re
from argparse import Namespace
def extract_words(sentence):
    """Return the words found in *sentence*, in order of appearance.

    A word is a run of word characters and apostrophes that starts and
    ends on a regex word boundary, so contractions such as "don't" are
    kept whole while surrounding punctuation is dropped.
    """
    word_pattern = re.compile(r"\b[\w']+\b")
    return word_pattern.findall(sentence)
def levenshtein_distance(word1, word2):
    """Align two sequences by edit distance and report where they differ.

    Args:
        word1: Original sequence (a string or a list of words).
        word2: Target sequence of the same kind.

    Returns:
        A ``(distance, operations, positions)`` tuple where

        * ``distance`` is the minimum number of edits,
        * ``operations`` is the edit script as a string with one character
          per step: ``"d"`` delete, ``"i"`` insert, ``"s"`` substitute,
          ``"="`` match,
        * ``positions`` lists ``(start, end, op)`` for every non-match step,
          ordered from the start of the sequences to the end. Indices refer
          to ``word1``; deletions and substitutions cover ``(k - 1, k)``,
          insertions are the zero-width ``(k, k)``.
    """
    rows, cols = len(word1), len(word2)

    # Row 0: an empty original can only reach a target prefix by insertions.
    table = [[(c, "i" * c) for c in range(cols + 1)]]

    for r, source_item in enumerate(word1, start=1):
        above = table[-1]
        # Column 0: an original prefix can only reach "nothing" by deletions.
        row = [(r, "d" * r)]
        for c, target_item in enumerate(word2, start=1):
            mismatch = 0 if source_item == target_item else 1
            via_delete = above[c][0] + 1
            via_insert = row[c - 1][0] + 1
            via_diagonal = above[c - 1][0] + mismatch

            # On ties the preference is deletion, then insertion, then the
            # diagonal step; this order determines the reported script.
            if via_delete <= via_insert and via_delete <= via_diagonal:
                cell = (via_delete, above[c][1] + "d")
            elif via_insert <= via_diagonal:
                cell = (via_insert, row[c - 1][1] + "i")
            else:
                step = "s" if mismatch else "="
                cell = (via_diagonal, above[c - 1][1] + step)
            row.append(cell)
        table.append(row)

    # Walk back from the bottom-right cell; the last character of each
    # cell's script tells which neighbour it was derived from.
    trace = []
    r, c = rows, cols
    while r > 0 and c > 0:
        last_op = table[r][c][1][-1]
        if last_op == "d":
            trace.append((r - 1, r, 'd'))
            r -= 1
        elif last_op == "i":
            trace.append((r, r, 'i'))
            c -= 1
        else:
            if last_op == "s":
                trace.append((r - 1, r, 's'))
            r -= 1
            c -= 1

    # At most one of these remainders is non-empty: leftover original items
    # are deletions, leftover target items are insertions at index 0.
    trace.extend((k - 1, k, 'd') for k in range(r, 0, -1))
    trace.extend([(0, 0, 'i')] * c)

    final_distance, final_script = table[rows][cols]
    return final_distance, final_script, trace[::-1]
def extract_spans(positions, orig_len):
    """Group consecutive same-type edit positions into ``(start, end)`` spans.

    Args:
        positions: ``(start, end, op)`` tuples in ascending order.
        orig_len: Length of the original sequence the indices refer to.

    Returns:
        A list of ``(start, end)`` tuples. A position extends the current
        span when it has the same operation and starts at the span's end or
        one past it; otherwise a new span is opened. If the last span starts
        at or beyond ``orig_len`` it is replaced by ``(orig_len, orig_len)``.
    """
    if not positions:
        return []

    merged = []
    run_start, run_end, run_op = positions[0]
    for start, end, op in positions[1:]:
        if op == run_op and start in (run_end, run_end + 1):
            # Same kind of edit, touching the open run: grow it.
            run_end = end
            continue
        merged.append((run_start, run_end))
        run_start, run_end, run_op = start, end, op
    merged.append((run_start, run_end))

    # A trailing span that begins at/after the end of the original is
    # clamped to the zero-width span at the end.
    if merged[-1][0] >= orig_len:
        merged[-1] = (orig_len, orig_len)
    return merged
def combine_nearby_spans(spans):
    """Merge spans that overlap, touch, or are separated by a gap of one.

    Args:
        spans: ``(start, end)`` pairs in ascending order.

    Returns:
        A new list in which each span whose start is at most one past the
        previous span's end has been folded into that previous span. An
        empty input is returned unchanged.
    """
    if not spans:
        return spans

    merged = [spans[0]]
    for span in spans[1:]:
        prev_start, prev_end = merged[-1][0], merged[-1][1]
        if span[0] > prev_end + 1:
            # Far enough apart: keep as a separate span.
            merged.append(span)
        else:
            merged[-1] = (prev_start, max(prev_end, span[1]))
    return merged
def parse_edit_en(orig_transcript, trgt_transcript):
    """Locate the word spans of the original transcript that were edited.

    Both transcripts are split into words, aligned with a word-level edit
    distance, and the differing positions are grouped into spans which are
    then merged when they lie next to each other.

    Args:
        orig_transcript: Original transcript text.
        trgt_transcript: Edited (target) transcript text.

    Returns:
        ``(operations, spans)``: the edit-script string from the alignment
        and a list of ``(start, end)`` word-index spans into the original.
    """
    source_words = extract_words(orig_transcript)
    target_words = extract_words(trgt_transcript)
    # The distance itself is not needed here, only the script and positions.
    _, operations, positions = levenshtein_distance(source_words, target_words)
    merged_spans = combine_nearby_spans(
        extract_spans(positions, len(source_words))
    )
    return operations, merged_spans
def parse_tts_en(orig_transcript, trgt_transcript):
    """Return the single span from the first changed word to the end.

    Both transcripts are split into words and aligned with a word-level
    edit distance; the result covers everything from the first differing
    word index through the end of the original transcript.

    Args:
        orig_transcript: Original transcript text.
        trgt_transcript: Target transcript text.

    Returns:
        A one-element list ``[[start, n_orig_words]]``. When the two
        transcripts contain the same words there is no differing position,
        and the empty span ``[[n_orig_words, n_orig_words]]`` is returned
        instead of raising ``IndexError``.
    """
    word1 = extract_words(orig_transcript)
    word2 = extract_words(trgt_transcript)
    _, _, positions = levenshtein_distance(word1, word2)
    spans = extract_spans(positions, len(word1))
    # With identical word sequences `spans` is empty; treat the "first
    # change" as lying at the very end so the result is a zero-width span.
    start = spans[0][0] if spans else len(word1)
    return [[start, len(word1)]]
if __name__ == "__main__":
    # Small demo: print the edit script and spans for a sample pair.
    orig_transcript = "But when I had approached so near to them The common object, which the sense deceives, Lost not by distance any of its marks,"
    trgt_transcript = "But when I saw the mirage of the lake in the distance, which the sense deceives, Lost not by distance any of its marks,"
    # The entry points in this module carry the `_en` suffix; the
    # unsuffixed names are not defined here and would raise NameError.
    operations, spans = parse_edit_en(orig_transcript, trgt_transcript)
    print("Operations:", operations)
    print("Spans:", spans)
    spans_tts = parse_tts_en(orig_transcript, trgt_transcript)
    print("TTS Spans:", spans_tts)