From f624976a0a605f910b2145b13d2a23fafeb8bd34 Mon Sep 17 00:00:00 2001 From: huwenxing Date: Sun, 28 Jul 2019 16:25:48 +0800 Subject: [PATCH] update --- README.md | 4 ++-- models/bert.py | 2 +- utils.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index eabf7cd..42ed09d 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,8 @@ tensorboardX 模型|acc|备注 --|--|-- -bert|94.04%|bert + fc -ERNIE|92.75%|说好的中文碾压bert呢 +bert|94.83%|bert + fc +ERNIE|94.61%|说好的中文碾压bert呢 CNN、RNN、DPCNN、RCNN、RNN+Attention、FastText等模型效果,请见我另外一个[仓库](https://github.com/649453932/Chinese-Text-Classification-Pytorch)。 diff --git a/models/bert.py b/models/bert.py index 90aa60d..5c0fd61 100644 --- a/models/bert.py +++ b/models/bert.py @@ -20,7 +20,7 @@ def __init__(self, dataset): self.require_improvement = 1000 # 若超过1000batch效果还没提升,则提前结束训练 self.num_classes = len(self.class_list) # 类别数 - self.num_epochs = 2 # epoch数 + self.num_epochs = 3 # epoch数 self.batch_size = 128 # mini-batch大小 self.pad_size = 32 # 每句话处理成的长度(短填长切) self.learning_rate = 5e-5 # 学习率 diff --git a/utils.py b/utils.py index 470bdfb..697ba3b 100644 --- a/utils.py +++ b/utils.py @@ -4,7 +4,7 @@ import time from datetime import timedelta -PAD = '[PAD]' # padding符号 +PAD, CLS = '[PAD]', '[CLS]' # padding符号, bert中综合信息符号 def build_dataset(config): @@ -18,6 +18,7 @@ def load_dataset(path, pad_size=32): continue content, label = lin.split('\t') token = config.tokenizer.tokenize(content) + token = [CLS] + token seq_len = len(token) mask = [] token_ids = config.tokenizer.convert_tokens_to_ids(token)