Showing 8 changed files with 151 additions and 14 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
## Place the ERNIE pretrained model files here: | ||
pytorch_model.bin | ||
bert_config.json | ||
vocab.txt | ||
|
||
## Download link: | ||
http://image.nghuyong.top/ERNIE.zip |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,80 @@ | ||
# Bert-Chinese-Text-Classification-Pytorch | ||
Chinese text classification with BERT | ||
[![LICENSE](https://img.shields.io/badge/license-Anti%20996-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE) | ||
|
||
Chinese text classification with Bert and ERNIE, built on PyTorch and ready to use out of the box. | ||
|
||
## Introduction | ||
Model description and data flow: not written yet; a blog link will be added here once it is finished. | ||
|
||
|
||
## Environment | ||
python 3.7 | ||
pytorch 1.1 | ||
tqdm | ||
sklearn | ||
tensorboardX | ||
[pytorch_pretrained_bert](https://github.com/huggingface/pytorch-transformers) | ||
|
||
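## Chinese dataset | ||
I sampled 200,000 news headlines from [THUCNews](http://thuctc.thunlp.org/) and uploaded them to GitHub. Each headline is 20 to 30 characters long. There are 10 classes with 20,000 headlines each. The text is fed to the model character by character. | ||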
|
||
Classes: finance, real estate, stocks, education, technology, society, politics, sports, games, entertainment. | ||
|
||
Dataset split: | ||
|
||
Dataset|Size | ||
--|-- | ||
Training set|180,000 | ||
Validation set|10,000 | ||
Test set|10,000 | ||
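
A split with these proportions could be produced along the following lines. This is only a sketch: the procedure actually used to build the split is not described here, and `split_dataset`, its parameters, and the fixed seed are hypothetical names chosen for illustration.

```python
import random


def split_dataset(samples, n_dev, n_test, seed=0):
    """Shuffle a copy of `samples` and cut it into train/dev/test lists.

    Hypothetical helper: a plain random split, not necessarily the one
    used to build the dataset described above.
    """
    if n_dev + n_test > len(samples):
        raise ValueError("dev and test sizes exceed the number of samples")
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # seeded for reproducibility
    dev = shuffled[:n_dev]
    test = shuffled[n_dev:n_dev + n_test]
    train = shuffled[n_dev + n_test:]
    return train, dev, test


# Scaled-down stand-in for the 200,000 headlines: 200 dummy items,
# split in the same 18:1:1 ratio as the table above.
headlines = [f"headline {i}" for i in range(200)]
train, dev, test = split_dataset(headlines, n_dev=10, n_test=10)
print(len(train), len(dev), len(test))  # 180 10 10
```

A plain shuffle does not guarantee equal class counts in each part; a per-class split would be needed for that.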
|
||
|
||
### Using your own dataset | ||
- Format your Chinese dataset the same way as mine. | ||
|
||
|
||
## Results | ||
|
||
Model|acc|Notes | ||
--|--|-- | ||
bert|94.04%|bert + fc | ||
ERNIE|92.75%|slightly worse | ||
|
||
For results with CNN, RNN, DPCNN, RCNN, RNN+Attention, FastText, and other models, see my other [repository](https://github.com/649453932/Chinese-Text-Classification-Pytorch). | ||
|
||
## Pretrained language models | ||
The bert model goes in the bert_pretain directory and the ERNIE model in the ERNIE_pretrain directory. Each directory holds three files: | ||
- pytorch_model.bin | ||
- bert_config.json | ||
- vocab.txt | ||
|
||
Pretrained model download links: | ||
bert_Chinese: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz from [here](https://github.com/huggingface/pytorch-transformers) | ||
ERNIE_Chinese: http://image.nghuyong.top/ERNIE.zip from [here](https://github.com/nghuyong/ERNIE-Pytorch) | ||
After extracting, place the files in the corresponding directory as described above and check that the file names match. | ||
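
The file-name check can be automated with a few lines of Python. This is a minimal sketch, not part of the repository: `missing_model_files` is a hypothetical helper, and the directory names are taken from the text above.

```python
from pathlib import Path

# The three files each pretrained-model directory is expected to contain.
REQUIRED_FILES = ("pytorch_model.bin", "bert_config.json", "vocab.txt")


def missing_model_files(model_dir):
    """Return the required file names that are absent from `model_dir`."""
    root = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


# Report the state of both directories, relative to the current directory.
for directory in ("bert_pretain", "ERNIE_pretrain"):
    missing = missing_model_files(directory)
    status = "ok" if not missing else "missing " + ", ".join(missing)
    print(f"{directory}: {status}")
```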
|
||
## Usage | ||
Once the pretrained models are downloaded, you are ready to run. | ||
``` | ||
# Train and test: | ||
# bert | ||
python run.py --model bert | ||
# ERNIE | ||
python run.py --model ERNIE | ||
``` | ||
|
||
### Parameters | ||
All models live in the models directory; each file defines both the hyperparameters and the model. | ||
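
One way a `--model` flag could select a file in the models directory is sketched below. `run.py` itself is not shown here, so this is an assumption about the wiring rather than the actual script; `parse_args` and `load_model_module` are hypothetical helpers, while `Config` and `Model` are the class names used in the model files.

```python
import argparse
from importlib import import_module


def parse_args(argv=None):
    """Parse the command line; `--model` names a module in the models package."""
    parser = argparse.ArgumentParser(description="Chinese text classification")
    parser.add_argument("--model", type=str, required=True,
                        help="model to run, e.g. bert or ERNIE")
    return parser.parse_args(argv)


def load_model_module(model_name, package="models"):
    """Import `<package>.<model_name>`, for example `models.bert`."""
    return import_module(f"{package}.{model_name}")


# Intended use inside a script such as run.py (assumed, not verified);
# `dataset` stands for the dataset directory path:
#   args = parse_args()
#   module = load_model_module(args.model)
#   config = module.Config(dataset)
#   model = module.Model(config).to(config.device)
```

Keeping hyperparameters and model in one module means a single import yields everything needed to build a given model.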
|
||
## To be continued | ||
- bert + CNN, RNN, RCNN, DPCNN, etc. | ||
- ERNIE + CNN, RNN, RCNN, DPCNN, etc. | ||
- XLNET | ||
- Try label smoothing and see how it affects the results | ||
- Package a prediction function | ||
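
For reference, label smoothing replaces the one-hot target with a mixture of the one-hot vector and a uniform distribution over the classes. The sketch below computes the resulting loss for a single example in plain Python; it is an illustration only, not code from this repository, and `smoothed_cross_entropy` and `epsilon` are names chosen here.

```python
import math


def smoothed_cross_entropy(logits, target, epsilon=0.1):
    """Cross-entropy of one example against a label-smoothed target.

    The target class gets weight 1 - epsilon + epsilon / K and every
    other class gets epsilon / K, where K is the number of classes.
    """
    if not 0.0 <= epsilon < 1.0:
        raise ValueError("epsilon must be in [0, 1)")
    num_classes = len(logits)
    # Numerically stable log-softmax.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    log_probs = [z - log_norm for z in logits]
    loss = 0.0
    for k, log_prob in enumerate(log_probs):
        weight = epsilon / num_classes
        if k == target:
            weight += 1.0 - epsilon
        loss -= weight * log_prob
    return loss


# A confident, correct prediction is penalised more once smoothing is on.
logits = [10.0, 0.0, 0.0]
print(smoothed_cross_entropy(logits, target=0, epsilon=0.0))
print(smoothed_cross_entropy(logits, target=0, epsilon=0.1))
```

Recent PyTorch releases offer the same idea through the `label_smoothing` argument of `torch.nn.CrossEntropyLoss`, which is newer than the pytorch 1.1 listed in the environment section.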
|
||
|
||
## Related papers | ||
[1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding | ||
[2] ERNIE: Enhanced Representation through Knowledge Integration |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
## Place the bert pretrained model files here: | ||
pytorch_model.bin | ||
bert_config.json | ||
vocab.txt | ||
|
||
## Download link: | ||
https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# coding: UTF-8 | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
from pytorch_pretrained_bert import BertModel, BertTokenizer | ||
|
||
|
||
class Config(object): | ||
|
||
    """Configuration parameters.""" | ||
    def __init__(self, dataset): | ||
        self.model_name = 'ERNIE' | ||
        self.train_path = dataset + '/data/train.txt'  # training set | ||
        self.dev_path = dataset + '/data/dev.txt'  # validation set | ||
        self.test_path = dataset + '/data/test.txt'  # test set | ||
        self.class_list = [x.strip() for x in open( | ||
            dataset + '/data/class.txt').readlines()]  # class names | ||
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # trained model checkpoint | ||
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # device | ||
|
||
        self.dropout = 0.1  # dropout rate | ||
        self.require_improvement = 1000  # stop training early if there is no improvement for more than 1000 batches | ||
        self.num_classes = len(self.class_list)  # number of classes | ||
        self.num_epochs = 3  # number of epochs | ||
        self.batch_size = 128  # mini-batch size | ||
        self.pad_size = 32  # length every sentence is padded or truncated to | ||
        self.learning_rate = 5e-5  # learning rate | ||
        self.bert_path = './ERNIE_pretrain' | ||
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path) | ||
        print(self.tokenizer) | ||
        self.hidden_size = 768 | ||
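
The `require_improvement` setting describes an early-stopping rule. A minimal sketch of such a rule follows. The training loop is not part of this file, so the class below is hypothetical, and tracking validation loss (rather than accuracy) is an assumption.

```python
class EarlyStopper:
    """Signal a stop once the validation loss has not improved
    for more than `patience` batches."""

    def __init__(self, patience=1000):
        self.patience = patience
        self.best_loss = float("inf")
        self.last_improve = 0  # batch index of the last improvement

    def should_stop(self, batch_index, dev_loss):
        # Record an improvement, then check how long ago the last one was.
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.last_improve = batch_index
        return batch_index - self.last_improve > self.patience


# Tiny demonstration with patience=3 instead of 1000.
stopper = EarlyStopper(patience=3)
for batch_index, loss in enumerate([1.0, 0.9, 0.95, 0.97, 0.96, 0.99]):
    if stopper.should_stop(batch_index, loss):
        print(f"stopping at batch {batch_index}")
        break
```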
|
||
|
||
class Model(nn.Module): | ||
|
||
    def __init__(self, config): | ||
        super(Model, self).__init__() | ||
        self.bert = BertModel.from_pretrained(config.bert_path) | ||
        for param in self.bert.parameters(): | ||
            param.requires_grad = True  # fine-tune all pretrained weights | ||
        self.fc = nn.Linear(config.hidden_size, config.num_classes) | ||
|
||
    def forward(self, x): | ||
        context = x[0]  # input token ids | ||
        mask = x[2]  # attention mask | ||
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False) | ||
        out = self.fc(pooled) | ||
        return out |
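
`forward` reads token ids from `x[0]` and an attention mask from `x[2]`. The code that builds these inputs is not shown here, so the sketch below is an assumption about how one sentence could be padded or truncated to `pad_size` and given a matching mask. `pad_or_truncate`, the padding id `0`, and the sequence length in the middle slot are hypothetical choices.

```python
def pad_or_truncate(token_ids, pad_size=32, pad_id=0):
    """Return (ids, seq_len, mask) for one sentence.

    Short inputs are right-padded with `pad_id`; long inputs are cut to
    `pad_size`. The mask is 1 for real tokens and 0 for padding.
    """
    seq_len = min(len(token_ids), pad_size)
    n_pad = pad_size - seq_len
    ids = list(token_ids[:pad_size]) + [pad_id] * n_pad
    mask = [1] * seq_len + [0] * n_pad
    return ids, seq_len, mask


# A short and a long input, using pad_size=5 to keep the output readable.
print(pad_or_truncate([5, 6, 7], pad_size=5))
# ([5, 6, 7, 0, 0], 3, [1, 1, 1, 0, 0])
print(pad_or_truncate([1, 2, 3, 4, 5, 6, 7], pad_size=5))
# ([1, 2, 3, 4, 5], 5, [1, 1, 1, 1, 1])
```

Batches of such tuples would then be stacked into tensors before being passed to the model.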