From e788b38dabbb81cea282b049179632f99165c2e8 Mon Sep 17 00:00:00 2001 From: Ayub Date: Fri, 22 Sep 2023 13:41:28 +0330 Subject: [PATCH] fix performance issue in join_abbreviations --- hazm/word_tokenizer.py | 43 ++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/hazm/word_tokenizer.py b/hazm/word_tokenizer.py index 8e56dcf3..f579fd78 100644 --- a/hazm/word_tokenizer.py +++ b/hazm/word_tokenizer.py @@ -243,10 +243,19 @@ def __init__( + ["ن" + bon + "ه" for bon in self.bons], ) - abbreviations_file = Path(abbreviations) + if (join_abbreviations): + abbreviations_file = Path(abbreviations) - with abbreviations_file.open("r", encoding="utf-8") as f: - self.abbreviations = [line.strip() for line in f] + with abbreviations_file.open("r", encoding="utf-8") as f: + lines = [line.strip() for line in f] + sorted_lines= sorted(lines, key=len, reverse=True) + + abbrs = [] + for abbr in sorted_lines: + arr = [item for item in re.split(r'([.()])', abbr) if item] + abbrs.append(arr) + + self.abbreviations = abbrs @@ -361,18 +370,24 @@ def join_abbreviations(self: "WordTokenizer", tokens: List[str]) -> List[str]: """ result = [] i = 0 - abbreviations = self.abbreviations + while i < len(tokens): - longest = None - for j in range(i, len(tokens)): - candidate = "".join(tokens[i:j+1]) - if candidate in abbreviations: - longest = candidate - longest_idx = j - if longest: - result.append(abbreviations[abbreviations.index(longest)]) - i = longest_idx + 1 - else: + found = False + + for abbr in self.abbreviations: + if tokens[i:i + len(abbr)] == abbr: + result.append("".join(abbr)) + i += len(abbr) + found = True + break + + if not found: result.append(tokens[i]) i += 1 + return result + + + + +