
Commit

Initial commit
649453932 committed Jul 22, 2019
1 parent e71c415 commit fc8e882
Showing 8 changed files with 151 additions and 14 deletions.
7 changes: 7 additions & 0 deletions ERNIE_pretrain/README.md
@@ -0,0 +1,7 @@
## Put the ERNIE pre-trained model files here:
pytorch_model.bin
bert_config.json
vocab.txt

## Download:
http://image.nghuyong.top/ERNIE.zip
80 changes: 79 additions & 1 deletion README.md
@@ -1,2 +1,80 @@
# Bert-Chinese-Text-Classification-Pytorch
-Chinese text classification with BERT
[![LICENSE](https://img.shields.io/badge/license-Anti%20996-blue.svg)](https://github.com/996icu/996.ICU/blob/master/LICENSE)

Chinese text classification with BERT and ERNIE, built on PyTorch, ready to use out of the box.

## Introduction
Model walkthrough and data flow: the write-up is not finished yet; I will post the blog link here once it is done.


## Environment
python 3.7
pytorch 1.1
tqdm
sklearn
tensorboardX
[pytorch_pretrained_bert](https://github.com/huggingface/pytorch-transformers)

## Chinese dataset
I extracted 200,000 news headlines from [THUCNews](http://thuctc.thunlp.org/) and uploaded them to GitHub; each headline is 20 to 30 characters long. There are 10 categories with 20,000 headlines each. The model consumes the text character by character.

Categories: finance, realty, stocks, education, science, society, politics, sports, games, entertainment.

Dataset split:

Split|Size
--|--
Train|180,000
Validation|10,000
Test|10,000


### Using your own dataset
- Format your Chinese dataset the same way as mine; a sample is sketched below.
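
For reference, a hypothetical excerpt of train.txt. The assumed format is one example per line: the raw text, a tab, then a numeric label indexing into class.txt. This format is an assumption based on the data-loading code, and the headlines and labels below are made up:

```
海淀区首套房贷款利率上调	2
春季开学第一课走进中小学课堂	3
```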


## Results

Model|Accuracy|Notes
--|--|--
bert|94.04%|BERT + fc
ERNIE|92.75%|slightly worse here

For results with CNN, RNN, DPCNN, RCNN, RNN+Attention, FastText and other models, see my other [repository](https://github.com/649453932/Chinese-Text-Classification-Pytorch).

## Pre-trained language models
The BERT model goes in the bert_pretrain directory and the ERNIE model in the ERNIE_pretrain directory; each directory holds three files:
- pytorch_model.bin
- bert_config.json
- vocab.txt

Pre-trained model downloads:
bert_Chinese: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz from [here](https://github.com/huggingface/pytorch-transformers)
ERNIE_Chinese: http://image.nghuyong.top/ERNIE.zip from [here](https://github.com/nghuyong/ERNIE-Pytorch)
After unpacking, place the files in the corresponding directories as described above and make sure the file names match; the expected layout is shown below.
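
```
bert_pretrain/
├── pytorch_model.bin
├── bert_config.json
└── vocab.txt
ERNIE_pretrain/
├── pytorch_model.bin
├── bert_config.json
└── vocab.txt
```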

## Usage
Once the pre-trained models are downloaded, it is ready to run.
```
# train and test:
# BERT
python run.py --model bert
# ERNIE
python run.py --model ERNIE
```

### Hyperparameters
All models live in the models directory; hyperparameters are defined in the same file as the model, as in the sketch below.
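
For instance, to adjust BERT's training setup you edit the fields of Config in models/bert.py directly; the values below are the defaults visible in this commit's diff:

```
# inside Config.__init__ in models/bert.py
self.num_epochs = 2        # number of epochs
self.batch_size = 128      # mini-batch size
self.pad_size = 32         # sequence length (pad short, truncate long)
self.learning_rate = 5e-5  # learning rate
```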

## To do
- BERT + CNN, RNN, RCNN, DPCNN, etc.
- ERNIE + CNN, RNN, RCNN, DPCNN, etc.
- XLNet
- Also want to try label smoothing (see the sketch after this list)
- Wrap prediction in a standalone interface
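
A minimal sketch of what the label-smoothing item could look like, as a drop-in replacement for the F.cross_entropy call in train_eval.py; the helper name and the smoothing factor eps are assumptions, not part of this commit:

```
import torch.nn.functional as F

def label_smoothing_ce(logits, target, eps=0.1):
    # mix the one-hot target with a uniform distribution over all classes
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)  # per-example NLL
    uniform = -log_probs.mean(dim=-1)  # cross-entropy against the uniform distribution
    return ((1.0 - eps) * nll + eps * uniform).mean()
```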


## Papers
[1] BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[2] ERNIE: Enhanced Representation through Knowledge Integration
7 changes: 7 additions & 0 deletions bert_pretrain/README.md
@@ -0,0 +1,7 @@
## Put the BERT pre-trained model files here:
pytorch_model.bin
bert_config.json
vocab.txt

## Download:
https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz
49 changes: 49 additions & 0 deletions models/ERNIE.py
@@ -0,0 +1,49 @@
# coding: UTF-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from pytorch_pretrained_bert import BertModel, BertTokenizer


class Config(object):

    """Configuration parameters"""
    def __init__(self, dataset):
        self.model_name = 'ERNIE'
        self.train_path = dataset + '/data/train.txt'    # training set
        self.dev_path = dataset + '/data/dev.txt'        # validation set
        self.test_path = dataset + '/data/test.txt'      # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt').readlines()]    # class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'    # where the trained model is saved
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    # device

        self.dropout = 0.1                               # dropout rate
        self.require_improvement = 1000                  # stop training early if dev performance has not improved for 1000 batches
        self.num_classes = len(self.class_list)          # number of classes
        self.num_epochs = 3                              # number of epochs
        self.batch_size = 128                            # mini-batch size
        self.pad_size = 32                               # sequence length (pad short sentences, truncate long ones)
        self.learning_rate = 5e-5                        # learning rate
        self.bert_path = './ERNIE_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        print(self.tokenizer)
        self.hidden_size = 768


class Model(nn.Module):

    def __init__(self, config):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained(config.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = True    # fine-tune all ERNIE parameters
        self.fc = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        context = x[0]    # token ids, shape (batch, pad_size)
        mask = x[2]       # attention mask: 1 for real tokens, 0 for padding
        _, pooled = self.bert(context, attention_mask=mask, output_all_encoded_layers=False)
        out = self.fc(pooled)    # classify from the pooled [CLS] representation
        return out
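
A hypothetical smoke test for this new module; the dataset directory name 'THUCNews' and the dummy tensor shapes are assumptions, and it needs the ERNIE_pretrain files described in the README:

```
import torch
from models.ERNIE import Config, Model

config = Config('THUCNews')    # assumes the dataset folder layout described in the README
model = Model(config).to(config.device)
ids = torch.zeros(2, config.pad_size, dtype=torch.long, device=config.device)   # dummy token ids
mask = torch.ones(2, config.pad_size, dtype=torch.long, device=config.device)   # dummy attention mask
logits = model((ids, None, mask))    # x[1] (sequence lengths) is unused by this model
print(logits.shape)                  # (2, config.num_classes)
```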
7 changes: 3 additions & 4 deletions models/bert.py
@@ -22,13 +22,12 @@ def __init__(self, dataset):
        self.dropout = 0.1                               # dropout rate
        self.require_improvement = 1000                  # stop training early if dev performance has not improved for 1000 batches
        self.num_classes = len(self.class_list)          # number of classes
-        self.num_epochs = 3                             # number of epochs
+        self.num_epochs = 2                             # number of epochs
        self.batch_size = 128                            # mini-batch size
        self.pad_size = 32                               # sequence length (pad short sentences, truncate long ones)
        self.learning_rate = 5e-5                        # learning rate
-        self.bert_path = 'bert_pretrain/bert-base-chinese.tar.gz'
-        self.vocab_path = 'bert_pretrain/bert-base-chinese-vocab.txt'
-        self.tokenizer = BertTokenizer.from_pretrained(self.vocab_path)
+        self.bert_path = './bert_pretrain'
+        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768


7 changes: 2 additions & 5 deletions run.py
@@ -8,9 +8,7 @@
from utils import build_dataset, build_iterator, get_time_dif

parser = argparse.ArgumentParser(description='Chinese Text Classification')
-parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN')
-# parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
-parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
+parser.add_argument('--model', type=str, required=True, help='choose a model: Bert, ERNIE')
args = parser.parse_args()


@@ -27,7 +25,7 @@

start_time = time.time()
print("Loading data...")
-train_data, dev_data, test_data = build_dataset(config, args.word)
+train_data, dev_data, test_data = build_dataset(config)
train_iter = build_iterator(train_data, config)
dev_iter = build_iterator(dev_data, config)
test_iter = build_iterator(test_data, config)
@@ -36,5 +34,4 @@

# train
model = x.Model(config).to(config.device)
-# init_network(model)
train(config, model, train_iter, dev_iter, test_iter)
6 changes: 3 additions & 3 deletions train_eval.py
@@ -39,8 +39,8 @@ def train(config, model, train_iter, dev_iter, test_iter):
    # optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=config.learning_rate,
-                         warmup=0.1,
-                         t_total=4218)
+                         warmup=0.05,
+                         t_total=len(train_iter) * config.num_epochs)
    total_batch = 0  # number of batches run so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch count at the last dev-loss improvement
@@ -54,7 +54,7 @@
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
-            if total_batch % 100 == 0:
+            if total_batch % 100 == 0:
                # report metrics on the train and dev sets every 100 batches
                true = labels.data.cpu()
                predic = torch.max(outputs.data, 1)[1].cpu()
2 changes: 1 addition & 1 deletion utils.py
@@ -12,7 +12,7 @@
PAD = '[PAD]'  # padding token


-def build_dataset(config, ues_word):
+def build_dataset(config):

def load_dataset(path, pad_size=32):
contents = []
