update shibing624/chatglm3-6b-csc-chinese-lora model.
shibing624 committed Nov 7, 2023
1 parent 0201ade commit 0ba3b5d
Showing 6 changed files with 71 additions and 44 deletions.
57 changes: 34 additions & 23 deletions README.md
@@ -26,7 +26,6 @@

- [Features](#Features)
- [Evaluation](#Evaluation)
- [Install](#install)
- [Usage](#usage)
- [Deep Model Usage](#deep-model-usage)
- [ContextDataset](#Dataset)
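@@ -49,10 +48,10 @@
* [Kenlm model](pycorrector/corrector.py): this project trains a Chinese n-gram language model with the Kenlm statistical language modeling toolkit; combined with rule-based methods and a confusion set, it can correct Chinese spelling errors. The method is fast and highly extensible; accuracy is moderate
* [DeepContext model](pycorrector/deepcontext): this project implements a DeepContext model for text correction in PyTorch; the architecture follows Stanford University's NLC model, which took first place in a 2014 English error-correction competition; accuracy is moderate
* [Seq2Seq model](pycorrector/seq2seq): this project implements a ConvSeq2Seq model for Chinese text correction in PyTorch; used as a single model, it took third place in the NLPCC-2018 Chinese grammatical error correction task. It can be trained in parallel and converges quickly; accuracy is moderate
* [T5 model](pycorrector/t5) [Recommended]: this project implements a T5 model for Chinese text correction in PyTorch, fine-tuning the Langboat/mengzi-t5-base pretrained model on Chinese correction datasets; the model has considerable room for adaptation and performs well
* [T5 model](pycorrector/t5): this project implements a T5 model for Chinese text correction in PyTorch, fine-tuning the Langboat/mengzi-t5-base pretrained model on Chinese correction datasets; the model has considerable room for adaptation and performs well
* [ERNIE_CSC model](pycorrector/ernie_csc): this project implements the ERNIE_CSC model for Chinese text correction in PaddlePaddle; the model is fine-tuned from ERNIE-1.0 with an architecture adapted to Chinese spelling correction, and performs well
* [MacBERT model](pycorrector/macbert) [Recommended]: this project implements the MacBERT4CSC model for Chinese text correction in PyTorch; the model adds error detection and correction networks, is adapted to Chinese spelling correction, and performs well
* [GPT model](pycorrector/gpt) [Recommended]: this project implements ChatGLM/LLaMA models for Chinese text correction in PyTorch; the models are fine-tuned on Chinese CSC and grammatical error correction datasets, are adapted to Chinese text correction, and perform well
* [GPT model](pycorrector/gpt): this project implements ChatGLM/LLaMA models for Chinese text correction in PyTorch; the models are fine-tuned on Chinese CSC and grammatical error correction datasets, are adapted to Chinese text correction, and perform well

- Further reading: [Chinese text correction: practice and principles](https://github.com/shibing624/pycorrector/blob/master/docs/correction_solution.md)
# Demo
@@ -143,8 +142,12 @@ pip install -r requirements.txt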
```

# Usage
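One of the original goals of this project is to compare and survey different approaches to Chinese text correction, as a starting point for further work.

## Statistical model (kenlm)
The project implements kenlm, macbert, seq2seq, ernie_csc, T5, deepcontext, LLaMA, and other models for text correction; each model can be trained and used for prediction on your own data.


## kenlm model (statistical model)
### Chinese spelling correction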

example: [examples/kenlm/demo.py](https://github.com/shibing624/pycorrector/blob/master/examples/kenlm/demo.py)
@@ -342,21 +345,8 @@ python -m pycorrector input.txt -o out.txt -n -d

> Input file: `input.txt`; output file: `out.txt`; character-level correction disabled; detailed correction information printed; correction results separated by `\t`
## Deep Model for Text Correction

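One of the original goals of this project is to compare and share different text correction methods as a starting point for discussion; if it offers even a little inspiration for your own text correction work, I would be greatly honored.

It implements the macbert, seq2seq, ernie_csc, T5, deepcontext, and GPT deep models for text correction; each model can be trained and used for prediction on your own data.

- Install dependencies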

```
pip install -r requirements-dev.txt
```

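## Usage

### **MacBert4CSC model [Recommended]**
## MacBert4CSC model

A Chinese spelling correction model built on MacBERT with a modified network architecture; the model is open-sourced on HuggingFace Models: https://huggingface.co/shibing624/macbert4csc-base-chinese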

@@ -370,7 +360,7 @@ During training, MacBERT4CSC weights the losses of the detection layer and the correction layer to obtain the
For a detailed tutorial, see [examples/macbert/README.md](https://github.com/shibing624/pycorrector/blob/master/examples/macbert/README.md)


#### Quick prediction using pycorrector
#### pycorrector quick prediction
example: [examples/macbert/demo.py](https://github.com/shibing624/pycorrector/blob/master/examples/macbert/demo.py)

```python
Expand Down Expand Up @@ -405,7 +395,7 @@ output:
#### 使用原生transformers库快速预测
[examples/macbert/README.md](https://github.com/shibing624/pycorrector/blob/master/examples/macbert/README.md)

### ErnieCSC模型
## ErnieCSC模型

基于ERNIE的中文拼写纠错模型,模型已经开源在[PaddleNLP](https://bj.bcebos.com/paddlenlp/taskflow/text_correction/csc-ernie-1.0/csc-ernie-1.0.pdparams)
模型网络结构:
Expand All @@ -416,7 +406,7 @@ output:



#### Quick prediction using pycorrector
#### pycorrector quick prediction
example: [examples/ernie_csc/demo.py](https://github.com/shibing624/pycorrector/blob/master/examples/ernie_csc/demo.py)
```python
from pycorrector import ErnieCscCorrector
@@ -441,7 +431,7 @@ output:
```


### Bart model
## Bart model

A Bart4CSC model trained on the SIGHAN+Wang271K Chinese correction datasets, released on HuggingFace Models: https://huggingface.co/shibing624/bart4csc-base-chinese

@@ -467,6 +457,27 @@ output:

To train the Bart model, see https://github.com/shibing624/textgen/blob/main/examples/seq2seq/training_bartseq2seq_zh_demo.py

## GPT model
Correction models are fine-tuned from models such as ChatGLM3, LLaMA, Baichuan, and QWen; for the training procedure, see [examples/gpt/README.md](https://github.com/shibing624/pycorrector/blob/master/examples/gpt/README.md)

A correction model fine-tuned with SFT on ChatGLM3-6B has been released on HuggingFace Models: https://huggingface.co/shibing624/chatglm3-6b-csc-chinese-lora

#### pycorrector quick prediction

example: [examples/gpt/demo.py](https://github.com/shibing624/pycorrector/blob/master/examples/gpt/demo.py)
```python
from pycorrector import GptCorrector
m = GptCorrector()
print(m.correct_batch(['今天新情很好', '你找到你最喜欢的工作,我也很高心。']))
```

output:
```shell
[{'source': '今天新情很好', 'target': '今天心情很好', 'errors': [('', '', 2)]},
{'source': '你找到你最喜欢的工作,我也很高心。', 'target': '你找到你最喜欢的工作,我也很高兴。', 'errors': [('', '', 15)]}]
```
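
The `errors` tuples in the output above carry a position but empty strings for the wrong and corrected characters. A minimal sketch for recovering those pairs from `source` and `target`, assuming a substitution-only correction where both strings have the same length; `diff_chars` is a hypothetical helper for illustration, not part of pycorrector:

```python
from typing import List, Tuple


def diff_chars(source: str, target: str) -> List[Tuple[str, str, int]]:
    """Return (wrong_char, corrected_char, index) for each differing position.

    Hypothetical helper: assumes source and target have equal length, which
    holds for substitution-only spelling corrections.
    """
    if len(source) != len(target):
        raise ValueError("source and target must have the same length")
    return [(s, t, i) for i, (s, t) in enumerate(zip(source, target)) if s != t]


print(diff_chars('今天新情很好', '今天心情很好'))
```

For corrections that insert or delete characters the lengths differ, so an alignment-based comparison would be needed instead.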



# Dataset

@@ -567,7 +578,7 @@ BibTeX:
@misc{Xu_Pycorrector_Text_error,
title={Pycorrector: Text error correction tool},
author={Ming Xu},
year={2021},
year={2023},
howpublished={\url{https://github.com/shibing624/pycorrector}},
}
```
24 changes: 12 additions & 12 deletions pycorrector/deepcontext/deepcontext_corrector.py
@@ -13,21 +13,20 @@
from loguru import logger

sys.path.append('../..')
from pycorrector.utils.text_utils import is_chinese_string
from pycorrector.utils.text_utils import is_chinese_char
from pycorrector.corrector import Corrector
from pycorrector.utils.tokenizer import split_text_into_sentences_by_length
from pycorrector.utils.tokenizer import split_text_into_sentences_by_symbol
from pycorrector.utils.get_file import get_file
from pycorrector.detector import USER_DATA_DIR
from pycorrector.deepcontext.deepcontext_model import DeepContextModel

pwd_path = os.path.abspath(os.path.dirname(__file__))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unk_tokens = [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤', '\t', '֍', '玕', '', '《', '》']
pretrained_deepcontext_models = {
# LM model
# LM model (45MB)
'deepcontext_lm.tar.gz':
'https://github.com/shibing624/pycorrector/releases/download/0.4.5/deepcontext_lm.tar.gz'
'https://github.com/shibing624/pycorrector/releases/download/0.4.6/deepcontext_lm.tar.gz'
}


@@ -57,23 +56,24 @@ def __init__(
)
t1 = time.time()
self.model = DeepContextModel(model_dir=model_dir, max_length=max_length)
self.model.load_model()
self.max_length = max_length
logger.debug('Loaded model: %s, spend: %.4f s.' % (model_dir, time.time() - t1))

def correct(self, sentence: str, **kwargs):
def correct(self, sentence: str, topk: int = 10, **kwargs):
"""Correct the Chinese sentence with deep context language model."""
details = []
text_new = ''
blocks = split_text_into_sentences_by_length(sentence, self.max_length)
blocks = split_text_into_sentences_by_symbol(sentence)
for blk, start_idx in blocks:
blk_new = ''
for idx, s in enumerate(blk):
# handle Chinese errors
if is_chinese_string(s):
sentence_lst = list(blk_new + blk[idx:])
sentence_lst[idx] = self.model.mask
# predict, taking the top 10 by default
predict_words = self.model.predict_mask_token(sentence_lst, idx, k=10)
if is_chinese_char(s):
tokens = list(blk_new + blk[idx:])
tokens[idx] = self.model.mask
# predict
predict_words = self.model.predict_mask_token(tokens, idx, topk=topk)
top_tokens = []
for w, _ in predict_words:
top_tokens.append(w)
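
The loop above masks each Chinese character in turn and asks the language model for its `topk` most likely replacements. A minimal sketch of that pattern with a stand-in predictor; `MASK`, `is_cjk`, `correct_block`, and the rule of substituting the top candidate when the original character is absent from the candidates are assumptions for illustration, not the logic in `deepcontext_corrector.py`:

```python
from typing import Callable, List, Tuple

MASK = "<mask>"  # hypothetical mask token; the real one comes from the model

Predictor = Callable[[List[str], int, int], List[Tuple[str, float]]]


def is_cjk(ch: str) -> bool:
    """Rough CJK ideograph check (assumed stand-in for is_chinese_char)."""
    return '\u4e00' <= ch <= '\u9fff'


def correct_block(blk: str, predict: Predictor, topk: int = 10):
    """Mask each Chinese character and replace it if the model disagrees."""
    blk_new = ''
    details = []
    for idx, ch in enumerate(blk):
        if is_cjk(ch):
            # Corrected prefix + untouched suffix, with position idx masked.
            tokens = list(blk_new + blk[idx:])
            tokens[idx] = MASK
            top_tokens = [w for w, _ in predict(tokens, idx, topk)]
            # Assumed rule: replace only when the original is not a candidate.
            if top_tokens and ch not in top_tokens:
                details.append((ch, top_tokens[0], idx))
                ch = top_tokens[0]
        blk_new += ch
    return blk_new, details


# Stand-in predictor: always proposes the character from a reference sentence.
reference = '今天心情很好'
fake_predict = lambda tokens, idx, topk: [(reference[idx], 1.0)]
print(correct_block('今天新情很好', fake_predict))
```

Because the prefix is rebuilt from `blk_new`, each prediction sees the corrections already made to its left.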
21 changes: 18 additions & 3 deletions pycorrector/deepcontext/deepcontext_model.py
@@ -9,6 +9,7 @@
from typing import List

import numpy as np
import pandas as pd
import torch
from loguru import logger
from torch import optim
@@ -140,6 +141,12 @@ def train_model(

interval = 1e5
best_loss = 1e3
global_step = 0
training_progress_scores = {
"epoch": [],
"global_step": [],
"train_loss": [],
}
logger.info("train start...")
for epoch in range(num_epochs):
begin_time = time.time()
@@ -160,13 +167,15 @@ def train_model(
loss = model(sentence, target)
loss.backward()
optimizer.step()
global_step += 1
total_loss += loss.data.mean()

minibatch_size, sentence_length = target.size()
word_count += minibatch_size * sentence_length
accum_mean_loss = float(total_loss) / word_count if total_loss > 0.0 else 0.0
cur_mean_loss = (float(total_loss) - last_accum_loss) / (word_count - last_word_count)
cur_loss = cur_mean_loss

if word_count >= next_count:
now = time.time()
duration = now - cur_at
@@ -181,7 +190,13 @@ def train_model(
# find best model
is_best = cur_loss < best_loss
best_loss = min(cur_loss, best_loss)
logger.info('epoch: {}/{}, loss: {}, best_loss: {}'.format(epoch + 1, num_epochs, cur_loss, best_loss))
logger.info('epoch: {}/{}, global_step: {}, loss: {}, best_loss: {}'.format(
epoch + 1, num_epochs, global_step, cur_loss, best_loss))
training_progress_scores["epoch"].append(epoch + 1)
training_progress_scores["global_step"].append(global_step)
training_progress_scores["train_loss"].append(cur_loss)
report = pd.DataFrame(training_progress_scores)
report.to_csv(os.path.join(self.model_dir, "training_progress_scores.csv"), index=False)
if is_best:
self.save_model(model_dir=self.model_dir, model=model, optimizer=optimizer)
logger.info('save new model: {}'.format(epoch + 1, self.model_dir))
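
The hunk above accumulates `epoch`, `global_step`, and `train_loss` in a dict of lists and rewrites `training_progress_scores.csv` after each epoch. A minimal sketch of the same bookkeeping; it swaps `pandas.DataFrame.to_csv` for the standard-library `csv` module so it runs without extra dependencies, and the helper names and loss values are hypothetical:

```python
import csv
import os
import tempfile
from typing import Dict, List


def new_progress() -> Dict[str, List]:
    """Column-oriented store, one list per CSV column."""
    return {"epoch": [], "global_step": [], "train_loss": []}


def log_step(scores: Dict[str, List], epoch: int, global_step: int, train_loss: float) -> None:
    """Append one row's worth of values to the column lists."""
    scores["epoch"].append(epoch)
    scores["global_step"].append(global_step)
    scores["train_loss"].append(train_loss)


def save_progress(scores: Dict[str, List], path: str) -> None:
    """Rewrite the whole report: header row, then one row per logged step."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(scores.keys())
        writer.writerows(zip(*scores.values()))


scores = new_progress()
log_step(scores, epoch=1, global_step=100, train_loss=0.5)  # illustrative values
log_step(scores, epoch=2, global_step=200, train_loss=0.4)
report_path = os.path.join(tempfile.mkdtemp(), "training_progress_scores.csv")
save_progress(scores, report_path)
```

Rewriting the full file each time keeps the report valid even if training is interrupted between epochs.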
@@ -198,7 +213,7 @@ def save_model(self, model_dir=None, model=None, optimizer=None):
if optimizer:
torch.save(optimizer.state_dict(), self.optimizer_file)

def predict_mask_token(self, tokens: List[str], mask_index: int = 0, k: int = 10):
def predict_mask_token(self, tokens: List[str], mask_index: int = 0, topk: int = 10):
if not self.model:
self.load_model()
unk_token = self.config_dict['unk_token']
@@ -211,7 +226,7 @@ def predict_mask_token(self, tokens: List[str], mask_index: int = 0, k: int = 10
tokens = [sos_token] + tokens + [eos_token]
indexed_sentence = [self.stoi[token] if token in self.stoi else self.stoi[unk_token] for token in tokens]
input_tokens = torch.tensor(indexed_sentence, dtype=torch.long, device=device).unsqueeze(0)
topv, topi = self.model.run_inference(input_tokens, target=None, target_pos=mask_index, k=k)
topv, topi = self.model.run_inference(input_tokens, target=None, target_pos=mask_index, topk=topk)
for value, key in zip(topv, topi):
score = value.item()
word = self.itos[key.item()]
8 changes: 4 additions & 4 deletions pycorrector/deepcontext/deepcontext_utils.py
@@ -229,10 +229,10 @@ def init_hidden(self, batch_size):
return (weight.new_zeros(self.n_layers, batch_size, self.hidden_size),
weight.new_zeros(self.n_layers, batch_size, self.hidden_size))

def run_inference(self, input_tokens, target, target_pos, k=10):
def run_inference(self, input_tokens, target, target_pos, topk=10):
context_vector = self.forward(input_tokens, target=None, target_pos=target_pos)
if target is None:
topv, topi = ((self.criterion.W.weight * context_vector).sum(dim=1)).data.topk(k)
topv, topi = ((self.criterion.W.weight * context_vector).sum(dim=1)).data.topk(topk)
return topv, topi
else:
context_vector /= torch.norm(context_vector, p=2)
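
When `target is None`, `run_inference` scores every vocabulary row of `self.criterion.W.weight` by its dot product with the context vector and keeps the `topk` highest scores. A pure-Python sketch of that computation, with nested lists standing in for tensors; `score_topk` and the toy weights are hypothetical:

```python
import heapq
from typing import List, Sequence, Tuple


def score_topk(weight: Sequence[Sequence[float]],
               context_vector: Sequence[float],
               topk: int = 10) -> Tuple[List[float], List[int]]:
    """Return (top values, top indices) of the row-wise dot products."""
    # Equivalent of (W * context_vector).sum(dim=1): one score per vocabulary row.
    scores = [sum(w * c for w, c in zip(row, context_vector)) for row in weight]
    # Keep the topk highest scores together with their row indices.
    top = heapq.nlargest(topk, enumerate(scores), key=lambda pair: pair[1])
    topi = [i for i, _ in top]
    topv = [v for _, v in top]
    return topv, topi


toy_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 vocabulary entries, dim 2
print(score_topk(toy_weight, [2.0, 3.0], topk=2))
```

The returned indices play the role of `topi`, which `predict_mask_token` maps back to words through `self.itos`.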
@@ -303,8 +303,8 @@ def load_word_dict(save_path):
items = line.split('\t')
try:
dict_data[items[0]] = int(items[1])
except IndexError:
logger.warning(f"IndexError: {line}")
except Exception as e:
logger.warning(f"Exception: {e}, {line}")
return dict_data
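
Widening the handler from `IndexError` to `Exception` also covers lines whose second field is not an integer, which raise `ValueError`. A minimal sketch of the same parsing that names both exceptions explicitly, offered as one possible alternative rather than the author's choice; `parse_word_dict` is hypothetical and uses the standard `logging` module in place of loguru:

```python
import logging
from typing import Dict, Iterable

logger = logging.getLogger(__name__)


def parse_word_dict(lines: Iterable[str]) -> Dict[str, int]:
    """Parse 'word<TAB>index' lines, skipping malformed ones with a warning."""
    dict_data = {}
    for line in lines:
        items = line.rstrip('\n').split('\t')
        try:
            dict_data[items[0]] = int(items[1])
        except (IndexError, ValueError) as e:
            # IndexError: no tab-separated second field; ValueError: not an integer.
            logger.warning("skip line %r: %s", line, e)
    return dict_data


sample = ["的\t0\n", "no tab here\n", "了\tx\n", "是\t2\n"]
print(parse_word_dict(sample))
```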


4 changes: 2 additions & 2 deletions pycorrector/gpt/gpt_corrector.py
@@ -19,9 +19,9 @@
class GptCorrector(GptModel):
def __init__(
self,
model_name_or_path: str = "shibing624/chatglm3-6b-csc-chinese-merged",
model_name_or_path: str = "THUDM/chatglm3-6b",
model_type: str = 'chatglm',
peft_name: Optional[str] = None,
peft_name: Optional[str] = "shibing624/chatglm3-6b-csc-chinese-lora",
**kwargs,
):
t1 = time.time()
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
jieba
pypinyin
numpy
pandas
six
loguru
kenlm
