2 changes: 2 additions & 0 deletions knlp/common/constant.py
@@ -9,6 +9,8 @@
 # -----------------------------------------------------------------------#
 import os
 
+GIT_DATA_URL = "https://github.com/global-nlp/knlp_data/archive/refs/heads/main.zip"  # download location for /knlp/data
+GIT_MODEL_URL = "https://github.com/global-nlp/knlp_model/archive/refs/heads/main.zip"  # download location for /knlp/model
 KNLP_PATH = os.path.dirname(os.path.realpath(__file__)) + "/../.."
 sentence_delimiters = ['?', '!', ';', '?', '!', '。', ';', '……', '…', '\n']
 allow_speech_tags = ['an', 'i', 'j', 'l', 'n', 'nr', 'nrfg', 'ns', 'nt', 'nz', 't', 'v', 'vd', 'vn', 'eng']
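The check_file helper that consumes these URLs is called throughout this PR, but its implementation in knlp/utils/util.py is not part of the diff. A minimal sketch of what it plausibly does, assuming it downloads and unpacks the GitHub branch archive on first use (names and behavior here are assumptions, not the repository's actual code):

import io
import os
import urllib.request
import zipfile

def check_file(path, git_url):
    """Fetch and unpack the backing archive from GitHub if `path` is missing (sketch only)."""
    if os.path.exists(path):
        return
    # GitHub's archive endpoint serves the whole branch as a single zip file.
    with urllib.request.urlopen(git_url) as resp:
        archive = zipfile.ZipFile(io.BytesIO(resp.read()))
    os.makedirs(path, exist_ok=True)
    archive.extractall(path)
    # Note: the archive unpacks into a top-level folder such as knlp_data-main/,
    # which the real helper would still have to map onto the expected
    # /knlp/data layout; that step is omitted in this sketch.

Guarding every default path with a check like this lets the samples below run from a fresh clone without shipping data and model files in the main repository.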
knlp/information_extract/keywords_extraction/textrank_sentence.py
@@ -14,7 +14,7 @@
 
 from knlp.common.constant import sentence_delimiters, allow_speech_tags
 from knlp.information_extract.keywords_extraction.textrank_keyword import TextRank4Keyword
-from knlp.utils.util import AttrDict, get_default_stop_words_file
+from knlp.utils.util import AttrDict, get_default_stop_words_file, get_pytest_data_file
 
 
 class TextRank4Sentence(TextRank4Keyword):
@@ -194,7 +194,7 @@ def sort_sentence_by_keyword(self, num=6, window=2, word_min_len=1, page_rank_co
 
 
 if __name__ == '__main__':
-    with open("knlp/data/pytest_data.txt", encoding='utf-8') as f:
+    with open(get_pytest_data_file(), encoding='utf-8') as f:
         text = f.read()
 
     tr4s = TextRank4Sentence()
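get_pytest_data_file, like the other get_*_file helpers this PR introduces, is not defined anywhere in the diff. Judging from the literals they replace, each helper presumably just joins KNLP_PATH with the old hardcoded location, so the samples no longer depend on being launched from the repository root. A plausible sketch, assuming the helper lives in knlp/utils/util.py:

from knlp.common.constant import KNLP_PATH

def get_pytest_data_file():
    # Assumed mapping: the same file the removed literal
    # "knlp/data/pytest_data.txt" pointed at, anchored to the package root.
    return KNLP_PATH + "/knlp/data/pytest_data.txt"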
3 changes: 2 additions & 1 deletion knlp/samples/IE_sample.py
@@ -10,6 +10,7 @@
 
 
 from knlp.information_extract.keywords_extraction import TextRank4Keyword, TextRank4Sentence
+from knlp.utils.util import get_pytest_data_file
 
 
 def get_keyword(text, window=5, num=20, word_min_len=2, need_key_phrase=True):
@@ -80,7 +81,7 @@ def get_key_sentences_by_keyword(text):
 
 
 if __name__ == '__main__':
-    with open("knlp/data/pytest_data.txt", encoding='utf-8') as f:
+    with open(get_pytest_data_file(), encoding='utf-8') as f:
         text = f.read()
     print(get_key_sentences(text))
     print(get_key_sentences_by_keyword(text))
7 changes: 3 additions & 4 deletions knlp/samples/crf_sample.py
@@ -1,8 +1,8 @@
 # !/usr/bin/python
 # -*- coding:UTF-8 -*-
-from knlp.common.constant import KNLP_PATH
 from knlp.seq_labeling.crf.inference import Inference
 from knlp.seq_labeling.crf.train import Train
+from knlp.utils.util import get_model_crf_hanzi_file, get_data_hanzi_segment_file
 
 # init trainer and inferencer
 crf_inferencer = Inference()
@@ -46,9 +46,8 @@ def load_and_test_inference(model_save_file, sentence):
 
 
 if __name__ == '__main__':
-    training_data_path = KNLP_PATH + "/knlp/data/hanzi_segment.txt"
-    model_save_file = KNLP_PATH + "/knlp/model/crf/hanzi_segment.pkl"
-    crf_train(training_data_path=training_data_path, model_save_file=model_save_file)
+    model_save_file = get_model_crf_hanzi_file()
+    crf_train(training_data_path=get_data_hanzi_segment_file(), model_save_file=model_save_file)
 
     sentence = "从明天起,做一个幸福的人,关心粮食与蔬菜。"
     load_and_test_inference(model_save_file=model_save_file, sentence=sentence)
5 changes: 2 additions & 3 deletions knlp/samples/hmm_sample.py
@@ -10,6 +10,7 @@
 from knlp.common.constant import KNLP_PATH
 from knlp.seq_labeling.hmm.inference import Inference
 from knlp.seq_labeling.hmm.train import Train
+from knlp.utils.util import get_pku_vocab_train_file, get_pku_hmm_train_file
 
 # init trainer and inferencer
 hmm_inferencer = Inference()
@@ -67,9 +68,7 @@ def test_inference(sentence):
 
 
 if __name__ == '__main__':
-    vocab_set_path = KNLP_PATH + "/knlp/data/seg_data/train/pku_vocab.txt"
-    training_data_path = KNLP_PATH + "/knlp/data/seg_data/train/pku_hmm_training_data_sample.txt"
     model_save_path = KNLP_PATH + "/knlp/model/hmm/"
-    hmm_train(vocab_set_path=vocab_set_path, training_data_path=training_data_path, model_save_path=model_save_path)
+    hmm_train(vocab_set_path=get_pku_vocab_train_file(), training_data_path=get_pku_hmm_train_file(), model_save_path=model_save_path)
     hmm_inference_load_model(model_save_path=model_save_path)
     print(test_inference("大家好,我是你们的好朋友"))
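The PKU helpers are likewise not shown in this diff, and it leaves one detail open: this sample previously read pku_hmm_training_data_sample.txt, while hmm/train.py (below) defaulted to pku_hmm_training_data.txt, so which of the two get_pku_hmm_train_file returns is not visible. A sketch under that caveat:

from knlp.common.constant import KNLP_PATH

def get_pku_vocab_train_file():
    # Assumed mapping, mirroring the removed literal.
    return KNLP_PATH + "/knlp/data/seg_data/train/pku_vocab.txt"

def get_pku_hmm_train_file():
    # Assumption: the full training file; the _sample variant this
    # script used before the change is equally plausible.
    return KNLP_PATH + "/knlp/data/seg_data/train/pku_hmm_training_data.txt"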
5 changes: 2 additions & 3 deletions knlp/seq_labeling/crf/inference.py
@@ -2,7 +2,7 @@
 import re
 
 from knlp.seq_labeling.crf.crf import CRFModel
-from knlp.common.constant import KNLP_PATH
+from knlp.utils.util import get_model_crf_hanzi_file
 
 
 class Inference:
@@ -62,11 +62,10 @@ def cut(self, sentence1, sentence2):
 
 if __name__ == "__main__":
     test = Inference()
-    CRF_MODEL_PATH = KNLP_PATH + "/knlp/model/crf/hanzi_segment.pkl"
 
     print("读取数据...")
     to_be_pred = "冬天到了,春天还会远吗?"
 
-    test.spilt_predict(to_be_pred, CRF_MODEL_PATH)
+    test.spilt_predict(to_be_pred, get_model_crf_hanzi_file())
     print("POS结果:" + str(test.label_prediction))
     print("模型预测结果:" + str(test.out_sentence))
8 changes: 4 additions & 4 deletions knlp/seq_labeling/crf/train.py
@@ -3,7 +3,7 @@
 import sys
 
 from knlp.seq_labeling.crf.crf import CRFModel
-from knlp.common.constant import KNLP_PATH
+from knlp.utils.util import get_model_crf_hanzi_file, get_data_hanzi_segment_file
 
 
 class Train:
@@ -24,7 +24,7 @@ def __init__(self, data_path=None):
         self.init_variable(training_data_path=data_path)
 
     def init_variable(self, training_data_path=None):
-        self.training_data_path = KNLP_PATH + "/knlp/data/hanzi_segment.txt" if not training_data_path else training_data_path
+        self.training_data_path = get_data_hanzi_segment_file() if not training_data_path else training_data_path
 
         with open(self.training_data_path, encoding='utf-8') as f:
             self.training_data = f.readlines()
@@ -65,7 +65,7 @@ def save_model(self, model_save_path):
 if __name__ == "__main__":
 
     args = sys.argv
-    train_data_path = KNLP_PATH + "/knlp/data/hanzi_segment.txt"
+    train_data_path = get_data_hanzi_segment_file()
 
     if len(args) > 1:
         train_data_path = args[1]
@@ -77,6 +77,6 @@ def save_model(self, model_save_path):
 
     print("正在保存模型...")
 
-    CRF_trainer.save_model(KNLP_PATH + "/knlp/model/crf/hanzi_segment.pkl")
+    CRF_trainer.save_model(get_model_crf_hanzi_file())
 
     print("训练完成。")
8 changes: 4 additions & 4 deletions knlp/seq_labeling/data_helper.py
@@ -7,7 +7,7 @@
 # Created Time: 2021-03-20
 # Description:
 # -----------------------------------------------------------------------#
-from knlp.utils.util import funtion_time_cost
+from knlp.utils.util import funtion_time_cost, get_pku_vocab_train_file, check_file
 
 
 class DataHelper:
@@ -99,10 +99,11 @@ def generate_vocab(cls, input_file, output_file):
 
 
 if __name__ == '__main__':
-    from knlp.common.constant import KNLP_PATH
+    from knlp.common.constant import KNLP_PATH, GIT_DATA_URL
 
     # make pku training data
     # input_file = KNLP_PATH + "/knlp/data/seg_data/icwb2-data/training/pku_training.utf8"
+    # check_file(KNLP_PATH + "/knlp/data/seg_data/train", GIT_DATA_URL)
    # output_file = KNLP_PATH + "/knlp/data/seg_data/train/pku_hmm_training_data.txt"
    # DataHelper.make_smbe_data(input_file, output_file)
    #
@@ -113,5 +114,4 @@
 
     # make pku vocab data
     input_file = KNLP_PATH + "/knlp/data/seg_data/icwb2-data/testing/pku_test.utf8"
-    output_file = KNLP_PATH + "/knlp/data/seg_data/train/pku_vocab.txt"
-    DataHelper.generate_vocab(input_file, output_file)
+    DataHelper.generate_vocab(input_file, get_pku_vocab_train_file())
4 changes: 3 additions & 1 deletion knlp/seq_labeling/hmm/inference.py
@@ -10,7 +10,8 @@
 import json
 import re
 
-from knlp.common.constant import KNLP_PATH
+from knlp.common.constant import KNLP_PATH, GIT_MODEL_URL
+from knlp.utils.util import check_file
 
 
 class Inference:
@@ -45,6 +46,7 @@ def helper(file_path, save_format="json"):
            with open(file_path, encoding='utf-8') as f:
                return json.load(f)
 
+        check_file(KNLP_PATH + "/knlp/model/hmm/seg", GIT_MODEL_URL)
        state_set = KNLP_PATH + "/knlp/model/hmm/seg/state_set.json" if not state_set_save_path else state_set_save_path + "/state_set.json"
        transition_pro = KNLP_PATH + "/knlp/model/hmm/seg/transition_pro.json" if not transition_pro_save_path else transition_pro_save_path + "/transition_pro.json"
        emission_pro = KNLP_PATH + "/knlp/model/hmm/seg/emission_pro.json" if not emission_pro_save_path else emission_pro_save_path + "/emission_pro.json"
8 changes: 5 additions & 3 deletions knlp/seq_labeling/hmm/train.py
@@ -68,7 +68,8 @@
 import sys
 from collections import defaultdict
 
-from knlp.common.constant import KNLP_PATH
+from knlp.common.constant import KNLP_PATH, GIT_MODEL_URL
+from knlp.utils.util import get_pku_vocab_train_file, get_pku_hmm_train_file, check_file
 
 
 class Train:
@@ -99,8 +100,8 @@ def __init__(self, vocab_set_path=None, training_data_path=None, test_data_path=
                            test_data_path=test_data_path)
 
     def init_variable(self, vocab_set_path=None, training_data_path=None, test_data_path=None):
-        self.vocab_set_path = KNLP_PATH + "/knlp/data/seg_data/train/pku_vocab.txt" if not vocab_set_path else vocab_set_path
-        self.training_data_path = KNLP_PATH + "/knlp/data/seg_data/train/pku_hmm_training_data.txt" if not training_data_path else training_data_path
+        self.vocab_set_path = get_pku_vocab_train_file() if not vocab_set_path else vocab_set_path
+        self.training_data_path = get_pku_hmm_train_file() if not training_data_path else training_data_path
         # self.test_data_path = KNLP_PATH + "/knlp/data/seg_data/train/pku_hmm_test_data.txt" if not test_data_path else test_data_path
         with open(self.vocab_set_path, encoding='utf-8') as f:
             self.vocab_data = f.readlines()
@@ -225,6 +226,7 @@ def build_model(self, state_set_save_path=None, transition_pro_save_path=None, e
 
         Returns:
         """
+        check_file(KNLP_PATH + "/knlp/model/hmm/seg", GIT_MODEL_URL)
         state_set = KNLP_PATH + "/knlp/model/hmm/seg/state_set.json" if not state_set_save_path else state_set_save_path + "/state_set.json"
         transition_pro = KNLP_PATH + "/knlp/model/hmm/seg/transition_pro.json" if not transition_pro_save_path else transition_pro_save_path + "/transition_pro.json"
         emission_pro = KNLP_PATH + "/knlp/model/hmm/seg/emission_pro.json" if not emission_pro_save_path else emission_pro_save_path + "/emission_pro.json"
7 changes: 4 additions & 3 deletions knlp/seq_labeling/pinyin_input_method/inference.py
@@ -2,8 +2,9 @@
 import json
 import math
 
-from knlp.common.constant import KNLP_PATH
+from knlp.common.constant import KNLP_PATH, GIT_MODEL_URL
 from knlp.seq_labeling.crf.crf import CRFModel
+from knlp.utils.util import get_model_crf_pinyin_file, check_file
 
 
 class Inference:
@@ -25,6 +26,7 @@ def helper(file_path):
            with open(file_path) as f:
                return json.load(f)
 
+        check_file(KNLP_PATH + "/knlp/model/hmm/pinyin_input_data", GIT_MODEL_URL)
         self.pinyin_hanzi = helper(KNLP_PATH + '/knlp/model/hmm/pinyin_input_data/pinyin_hanzi.json')
         self.start_state = helper(KNLP_PATH + '/knlp/model/hmm/pinyin_input_data/start_state.json')
         self.emission_pro = helper(KNLP_PATH + '/knlp/model/hmm/pinyin_input_data/emission_pro.json')
@@ -164,12 +166,11 @@ def cut(self, sentence1, sentence2):
     test = Inference()
 
     CRF = CRFModel()
-    CRF_MODEL_PATH = KNLP_PATH + "/knlp/model/crf/pinyin.pkl"
 
     print("读取数据...")
     to_be_pred = "dongtianlailechuntianyejiangdaolai"
 
-    test.spilt_predict(to_be_pred, CRF_MODEL_PATH)
+    test.spilt_predict(to_be_pred, get_model_crf_pinyin_file())
     print("POS结果:" + str(test.label_prediction))
     print("拼音分割结果:" + str(test.out_sentence))
 
8 changes: 3 additions & 5 deletions knlp/seq_labeling/pinyin_input_method/pinyin_segment_train.py
@@ -1,8 +1,8 @@
 # -*-coding:utf-8-*-
 import pickle
 
-from knlp.common.constant import KNLP_PATH
 from knlp.seq_labeling.crf.crf import CRFModel
+from knlp.utils.util import get_model_crf_pinyin_file, get_data_pinyin_segment_file
 
 
 class Train:
@@ -47,15 +47,13 @@ def save_model(self, file_name):
 
 
 if __name__ == "__main__":
-    train_data_path = KNLP_PATH + "/knlp/data/pinyin_segment.txt"
-
     print("正在读入数据进行训练...")
 
-    CRF_trainer = Train(train_data_path)
+    CRF_trainer = Train(get_data_pinyin_segment_file())
     CRF_trainer.load_and_train()
 
     print("正在保存模型...")
 
-    CRF_trainer.save_model(KNLP_PATH + "/knlp/model/crf/pinyin.pkl")
+    CRF_trainer.save_model(get_model_crf_pinyin_file())
 
     print("训练完成。")