Skip to content

Commit

Permalink
fix #534
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Dec 13, 2024
1 parent 4a5954c commit 31f7993
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions pycorrector/proper_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@
import os
from codecs import open
from typing import List

import pypinyin
from loguru import logger

from pycorrector.utils.math_utils import edit_distance
from pycorrector.utils.ngram_util import NgramUtil
from pycorrector.utils.text_utils import is_chinese_char
from pycorrector.utils.tokenizer import segment, split_text_into_sentences_by_symbol
from collections import defaultdict

pwd_path = os.path.abspath(os.path.dirname(__file__))

# 五笔笔画字典
stroke_path = os.path.join(pwd_path, 'data/stroke.txt')
# 专名词典,包括成语、俗语、专业领域词等 format: 词语
proper_name_path = os.path.join(pwd_path, 'data/proper_name.txt')
# 专名词典,包括成语、俗语、专业领域词等 format: 词语, 可以自定义
default_proper_name_path = os.path.join(pwd_path, 'data/proper_name.txt')


def load_set_file(path):
Expand Down Expand Up @@ -60,17 +59,45 @@ def load_dict_file(path):
return result


class TrieNode:
    """One node of a character trie.

    Each instance carries two attributes, both set in ``__init__``:
    ``is_end_of_word``, a flag that starts out False, and ``children``,
    a mapping from a character to the child ``TrieNode`` for it.
    """

    def __init__(self):
        # A freshly created node does not terminate any word.
        self.is_end_of_word = False
        # Subscripting with a missing character creates (and stores) a new
        # TrieNode for it; ``in`` / ``.get`` do not create anything.
        self.children = defaultdict(TrieNode)


class Trie:
    """Character-level trie supporting insertion and exact-word lookup."""

    def __init__(self):
        # The root stands for the empty prefix.
        self.root = TrieNode()

    def insert(self, word):
        """Add ``word`` to the trie.

        Args:
            word: iterable of characters (typically a str) to store.
        """
        cursor = self.root
        for ch in word:
            # Subscripting creates the child node when it is missing.
            cursor = cursor.children[ch]
        cursor.is_end_of_word = True

    def search(self, word):
        """Tell whether ``word`` was inserted as a complete word.

        Args:
            word: iterable of characters (typically a str) to look up.

        Returns:
            False as soon as a character has no matching child; otherwise
            the ``is_end_of_word`` flag of the node reached, so a mere
            prefix of a stored word yields False.
        """
        cursor = self.root
        for ch in word:
            # ``.get`` never inserts, so lookups leave the trie unchanged.
            following = cursor.children.get(ch)
            if following is None:
                return False
            cursor = following
        return cursor.is_end_of_word


class ProperCorrector:
def __init__(
self,
# NOTE(review): the next two lines look like the removed/added pair from the diff view
# (default changed from `proper_name_path` to `default_proper_name_path`) — confirm.
proper_name_path=proper_name_path,
proper_name_path=default_proper_name_path,
stroke_path=stroke_path,
):
self.name = 'ProperCorrector'
# Proper-name dictionary, including idioms, common sayings, domain-specific terms, etc. format: word
self.proper_names = load_set_file(proper_name_path)
# Stroke dictionary. format: char:strokes, e.g. 万 has the strokes 横 (h), 折 (z), 撇 (p), combined as: hzp
self.stroke_dict = load_dict_file(stroke_path)
# Insert every loaded proper name into a trie kept on the instance.
self.trie = Trie()
for name in self.proper_names:
self.trie.insert(name)

def get_stroke(self, char):
"""
Expand All @@ -95,6 +122,7 @@ def is_near_stroke_char(self, char1, char2, stroke_threshold=0.8):
def get_char_stroke_similarity_score(self, char1, char2):
"""
获取字符的字形相似度
Args:
char1:
char2:
Expand Down Expand Up @@ -253,6 +281,8 @@ def correct(
# 词长度过滤
ngrams = [i for i in ngrams if min_word_length <= len(i) <= max_word_length]
for cur_item in ngrams:
if self.trie.search(cur_item):
continue
for name in self.proper_names:
if self.get_word_similarity_score(cur_item, name) > sim_threshold:
if cur_item != name:
Expand Down

0 comments on commit 31f7993

Please sign in to comment.