Commit ede5a04

Merge pull request #439 from will-am/chinese_poetry

Add preprocessor for generating Chinese poetry.

2 parents 66f18a1 + 3bbe91d

File tree: 6 files changed (+217 −14 lines)

- generate_chinese_poetry/README.md
- generate_chinese_poetry/data/download.sh
- generate_chinese_poetry/generate.py
- generate_chinese_poetry/network_conf.py
- generate_chinese_poetry/preprocess.py
- generate_chinese_poetry/train.py

generate_chinese_poetry/README.md

Lines changed: 111 additions & 1 deletion

@@ -1 +1,111 @@
-[TBD]

The new README content:

# Chinese Classical Poetry Generation

## Introduction

Based on an encoder-decoder neural network model, this example performs line-to-line (sequence-to-sequence) training on the Complete Tang Poems, so that given one line of a poem, the model generates the next line.

Both the encoder and the decoder are stacked bi-directional LSTMs, 3 layers deep by default, with an attention unit. A framework-neutral sketch of this architecture follows.
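The repo's actual model is defined with PaddlePaddle in network_conf.py (diffed below). The sketch here uses PyTorch purely for illustration: every name and dimension in it is an assumption rather than code from this repo, and the decoder is kept unidirectional for brevity.

```python
# Minimal, framework-neutral sketch (PyTorch, NOT this repo's PaddlePaddle
# code) of the architecture described above. All names and sizes are
# illustrative; the decoder here is unidirectional for brevity.
import torch
import torch.nn as nn


class Seq2SeqSketch(nn.Module):
    def __init__(self, vocab_size, emb_dim=512, hidden_dim=512, depth=3):
        super(Seq2SeqSketch, self).__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        # Encoder: 3-layer stacked bi-directional LSTM.
        self.encoder = nn.LSTM(emb_dim, hidden_dim, num_layers=depth,
                               bidirectional=True, batch_first=True)
        # Projection used to score decoder state against encoder outputs.
        self.attn = nn.Linear(hidden_dim, 2 * hidden_dim)
        # Decoder: 3-layer stacked LSTM fed with [embedding; context].
        self.decoder = nn.LSTM(emb_dim + 2 * hidden_dim, hidden_dim,
                               num_layers=depth, batch_first=True)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, src, trg):
        enc_out, _ = self.encoder(self.embed(src))        # (B, S, 2H)
        dec_emb = self.embed(trg)                         # (B, T, E)
        outputs, hidden = [], None
        for t in range(dec_emb.size(1)):                  # step-wise decoding
            if hidden is None:                            # cold-start context
                context = enc_out.mean(dim=1)             # (B, 2H)
            else:
                # Dot-product attention over all encoder states.
                query = self.attn(hidden[0][-1]).unsqueeze(2)  # (B, 2H, 1)
                weights = torch.softmax(torch.bmm(enc_out, query), dim=1)
                context = (enc_out * weights).sum(dim=1)       # (B, 2H)
            step_in = torch.cat([dec_emb[:, t], context], dim=1).unsqueeze(1)
            step_out, hidden = self.decoder(step_in, hidden)
            outputs.append(self.out(step_out.squeeze(1)))
        return torch.stack(outputs, dim=1)                # (B, T, V)
```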
Below is a brief directory layout of this example:

```text
.
├── data                 # training data and dictionary
│   ├── download.sh      # downloads the raw data
├── README.md            # documentation
├── index.html           # documentation (HTML version)
├── preprocess.py        # raw-data preprocessing
├── generate.py          # line generation script
├── network_conf.py      # model definition
├── reader.py            # data-reading interface
├── train.py             # training script
└── utils.py             # utility functions
```
## Data Processing

### Raw data source

This example uses the Complete Tang Poems collected in the [chinese-poetry database](https://github.com/chinese-poetry/chinese-poetry) as training data, roughly 54,000 Tang poems in total.

### Downloading the raw data

```bash
cd data && ./download.sh && cd ..
```

### Preprocessing

```bash
python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt
```

When the script finishes, it has produced the processed training data poems.txt and the dictionary dict.txt. The dictionary is built at the character level, keeping every character that occurs at least 10 times.

Each line of poems.txt describes one poem in three columns (tab-separated in the file): title, author, and content. Within the content column, the poem's lines are separated by `.`.

Sample training data:

```text
登鸛雀樓 王之渙 白日依山盡.黃河入海流.欲窮千里目.更上一層樓
觀獵 李白 太守耀清威.乘閑弄晚暉.江沙橫獵騎.山火遶行圍.箭逐雲鴻落.鷹隨月兔飛.不知白日暮.歡賞夜方歸
晦日重宴 陳嘉言 高門引冠蓋.下客抱支離.綺席珍羞滿.文場翰藻摛.蓂華彫上月.柳色藹春池.日斜歸戚里.連騎勒金羈
```

During training, each line of a poem is used as the model input and the line that follows it as the prediction target, as in the sketch below.
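A small stand-alone Python sketch of that pairing (not part of the repo; the real data interface is reader.py, which this diff does not touch):

```python
# -*- coding: utf-8 -*-
# Stand-alone sketch of the input/target pairing described above; the repo's
# actual data pipeline is reader.py, which is not part of this diff.
import io


def line_to_pairs(record):
    """Split one poems.txt record into (input line, target line) pairs."""
    title, author, content = record.strip().split("\t")
    lines = content.split(".")          # poem lines are "."-separated
    # each line is an input, the line that follows it is the target
    return zip(lines[:-1], lines[1:])


with io.open("data/poems.txt", encoding="utf8") as f:
    for record in f:
        for src, trg in line_to_pairs(record):
            pass  # hand (src, trg) to the trainer
```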
## Model Training

The command-line arguments of the training script [train.py](./train.py) can be listed with `python train.py --help`. The main arguments are:

- `num_passes`: number of training passes
- `batch_size`: batch size
- `use_gpu`: whether to train on the GPU
- `trainer_count`: number of trainers, 1 by default
- `save_dir_path`: directory for saving models, defaults to the models directory under the current directory
- `encoder_depth`: depth of the encoder LSTM, 3 by default
- `decoder_depth`: depth of the decoder LSTM, 3 by default
- `train_data_path`: path to the training data
- `word_dict_path`: path to the dictionary
- `init_model_path`: path to an initial model; not needed when training from scratch

### Running training

```bash
python train.py \
    --num_passes 50 \
    --batch_size 256 \
    --use_gpu True \
    --trainer_count 1 \
    --save_dir_path models \
    --train_data_path data/poems.txt \
    --word_dict_path data/dict.txt \
    2>&1 | tee train.log
```

After each pass, the model parameters are saved under the models directory, and the training log is written to train.log.

### Choosing the best model parameters

Find the pass with the smallest cost and use that pass's parameters for prediction:

```bash
python -c 'import utils; utils.find_optiaml_pass("./train.log")'
```
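utils.py itself is not part of this diff, so the implementation of find_optiaml_pass is not shown; presumably it scans the `Pass %d, Batch %d, Cost %f, ...` lines that train.py's event handler writes (see the train.py diff below). A stand-alone sketch of that idea, under exactly that assumption about the log format:

```python
# Sketch only: pick the pass with the lowest average cost from train.log,
# assuming the "Pass %d, Batch %d, Cost %f, %s" format logged by train.py.
# The repo's own helper is utils.find_optiaml_pass, not shown in this diff.
import re
import collections


def find_best_pass(log_path):
    costs = collections.defaultdict(list)
    pattern = re.compile(r"Pass (\d+), Batch (\d+), Cost ([0-9.]+)")
    with open(log_path) as f:
        for line in f:
            match = pattern.search(line)
            if match:
                costs[int(match.group(1))].append(float(match.group(3)))
    # average the per-batch costs within each pass, then take the minimum
    averages = dict((p, sum(c) / len(c)) for p, c in costs.items())
    return min(averages, key=averages.get)


print("Best pass: %d" % find_best_pass("./train.log"))
```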
## Generating Poem Lines

Use the [generate.py](./generate.py) script to generate the next line for each input line; its command-line arguments can be listed with `python generate.py --help`. The main arguments are:

- `model_path`: file holding the trained model parameters
- `word_dict_path`: path to the dictionary
- `test_data_path`: path to the input data
- `batch_size`: batch size, 1 by default
- `beam_size`: beam width for beam search, 5 by default
- `save_file`: path for saving the output
- `use_gpu`: whether to run on the GPU

### Running generation

For example, save the line `孤帆遠影碧空盡` in a file `input.txt` as the input for predicting the next line, then run:

```bash
python generate.py \
    --model_path models/pass_00049.tar.gz \
    --word_dict_path data/dict.txt \
    --test_data_path input.txt \
    --save_file output.txt
```

The results are saved to `output.txt`. For the example input above, the generated lines are as follows; each line shows a beam-search score followed by one candidate line:

```text
-9.6987 萬 壑 清 風 黃 葉 多
-10.0737 萬 里 遠 山 紅 葉 深
-10.4233 萬 壑 清 波 紅 一 流
-10.4802 萬 壑 清 風 黃 葉 深
-10.9060 萬 壑 清 風 紅 葉 多
```
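Each candidate line in output.txt pairs a score with a space-separated character sequence; that format comes from infer_a_batch in generate.py, diffed below. A small sketch, assuming that format, of picking the best-scoring candidate:

```python
# -*- coding: utf-8 -*-
# Sketch: select the highest-scoring candidate from output.txt, assuming the
# "<score>\t<space-separated characters>" lines written by generate.py.
import io


def best_candidate(path):
    best_score, best_line = float("-inf"), None
    with io.open(path, encoding="utf8") as f:
        for line in f:
            if "\t" not in line:
                continue                # skip separators and non-scored lines
            score, chars = line.rstrip("\n").split("\t")
            if float(score) > best_score:
                best_score = float(score)
                best_line = u"".join(chars.split())
    return best_score, best_line


score, line = best_candidate("output.txt")
# for the README example: score == -9.6987, line == u"萬壑清風黃葉多"
```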
generate_chinese_poetry/data/download.sh

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
New file:

```bash
#!/bin/bash

git clone https://github.com/chinese-poetry/chinese-poetry.git

if [ ! -d raw ]
then
    mkdir raw
fi

mv chinese-poetry/json/poet.tang.* raw/
rm -rf chinese-poetry
```

generate_chinese_poetry/generate.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -28,7 +28,7 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
         for j in xrange(beam_size):
             end_pos = gen_sen_idx[i * beam_size + j]
             fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
-                id_to_text[w] for w in beam_result[1][start_pos:end_pos]))))
+                id_to_text[w] for w in beam_result[1][start_pos:end_pos - 1]))))
             start_pos = end_pos + 2
         fout.write("\n")
         fout.flush()
@@ -80,9 +80,11 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
         encoder_hidden_dim=512,
         decoder_depth=3,
         decoder_hidden_dim=512,
-        is_generating=True,
+        bos_id=0,
+        eos_id=1,
+        max_length=9,
         beam_size=beam_size,
-        max_length=10)
+        is_generating=True)

     inferer = paddle.inference.Inference(
         output_layer=beam_gen, parameters=parameters)
```
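The first hunk above fixes an off-by-one when slicing candidates out of the flat beam-search result. Judging from how gen_sen_idx is used, beam_result[1] is one flat id stream in which each candidate ends with an <e> id followed by a -1 separator, and the old slice kept the <e>. A toy illustration of the indexing, with that layout as an assumption:

```python
# Toy illustration of the slicing fixed above. Layout is an assumption based
# on how gen_sen_idx is used: each candidate in the flat id stream looks like
# <s>(0) ... content ids ... <e>(1), followed by a -1 separator.
import numpy as np

ids = np.array([0, 23, 57, 1, -1,      # candidate 1: <s> 23 57 <e> |
                0, 23, 99, 1, -1])     # candidate 2: <s> 23 99 <e> |
gen_sen_idx = np.where(ids == -1)[0]   # separator positions: [4, 9]

start_pos = 1                          # skip the leading <s>
for end_pos in gen_sen_idx:
    print(ids[start_pos:end_pos])      # old slice: [23 57 1], keeps <e>
    print(ids[start_pos:end_pos - 1])  # fixed slice: [23 57], content only
    start_pos = end_pos + 2            # skip the -1 and the next <s>
```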

generate_chinese_poetry/network_conf.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -73,8 +73,10 @@ def encoder_decoder_network(word_count,
                             encoder_hidden_dim,
                             decoder_depth,
                             decoder_hidden_dim,
+                            bos_id,
+                            eos_id,
+                            max_length,
                             beam_size=10,
-                            max_length=15,
                             is_generating=False):
     src_emb = paddle.layer.embedding(
         input=paddle.layer.data(
@@ -106,8 +108,8 @@ def encoder_decoder_network(word_count,
         name=decoder_group_name,
         step=_attended_decoder_step,
         input=group_inputs + [gen_trg_emb],
-        bos_id=0,
-        eos_id=1,
+        bos_id=bos_id,
+        eos_id=eos_id,
         beam_size=beam_size,
         max_length=max_length)

```

generate_chinese_poetry/preprocess.py

Lines changed: 76 additions & 0 deletions

@@ -0,0 +1,76 @@
New file:

```python
# -*- coding: utf-8 -*-
import os
import io
import re
import json
import click
import collections


def build_vocabulary(dataset, cutoff=0):
    # count character frequencies over all poem lines
    dictionary = collections.defaultdict(int)
    for data in dataset:
        for sent in data[2]:
            for char in sent:
                dictionary[char] += 1
    # keep characters occurring at least `cutoff` times, most frequent first
    dictionary = filter(lambda x: x[1] >= cutoff, dictionary.items())
    dictionary = sorted(dictionary, key=lambda x: (-x[1], x[0]))
    vocab, _ = list(zip(*dictionary))
    return (u"<s>", u"<e>", u"<unk>") + vocab


@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to raw data")
@click.option("--outfile", type=str, help="Path to save the training data")
@click.option("--dictfile", type=str, help="Path to save the dictionary file")
def preprocess(datadir, outfile, dictfile):
    dataset = []
    # patterns for stripping editorial notes and brackets from the raw text
    note_pattern1 = re.compile(u"((.*?))", re.U)
    note_pattern2 = re.compile(u"〖.*?〗", re.U)
    note_pattern3 = re.compile(u"-.*?-。?", re.U)
    note_pattern4 = re.compile(u"(.*$", re.U)
    note_pattern5 = re.compile(u"。。.*)$", re.U)
    note_pattern6 = re.compile(u"。。", re.U)
    note_pattern7 = re.compile(u"[《》「」\[\]]", re.U)
    print("Load raw data...")
    for fn in os.listdir(datadir):
        with io.open(os.path.join(datadir, fn), "r", encoding="utf8") as f:
            for data in json.load(f):
                title = data['title']
                author = data['author']
                p = "".join(data['paragraphs'])
                p = "".join(p.split())
                p = note_pattern1.sub(u"", p)
                p = note_pattern2.sub(u"", p)
                p = note_pattern3.sub(u"", p)
                p = note_pattern4.sub(u"", p)
                p = note_pattern5.sub(u"。", p)
                p = note_pattern6.sub(u"。", p)
                p = note_pattern7.sub(u"", p)
                # drop poems that still contain irregular characters
                if (p == u"" or u"{" in p or u"}" in p or u"{" in p or
                        u"}" in p or u"、" in p or u":" in p or u";" in p or
                        u"!" in p or u"?" in p or u"●" in p or u"□" in p or
                        u"囗" in p or u")" in p):
                    continue
                # split the poem into its lines at the sentence delimiters
                paragraphs = re.split(u"。|,", p)
                paragraphs = filter(lambda x: len(x), paragraphs)
                if len(paragraphs) > 1:
                    dataset.append((title, author, paragraphs))

    print("Construct vocabularies...")
    vocab = build_vocabulary(dataset, cutoff=10)
    with io.open(dictfile, "w", encoding="utf8") as f:
        for v in vocab:
            f.write(v + "\n")

    print("Write processed data...")
    with io.open(outfile, "w", encoding="utf8") as f:
        for data in dataset:
            title = data[0]
            author = data[1]
            paragraphs = ".".join(data[2])
            f.write("\t".join((title, author, paragraphs)) + "\n")


if __name__ == "__main__":
    preprocess()
```
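A quick demonstration, on a made-up annotated line rather than real corpus data, of what the note-stripping patterns above do: note_pattern1 removes inline annotations in full-width parentheses and note_pattern7 removes editorial brackets.

```python
# -*- coding: utf-8 -*-
# Demo of the note-stripping patterns above on a made-up annotated line.
import re

p = u"《登鸛雀樓》白日依山盡(一作盡山依)。黃河入海流。"
p = re.compile(u"((.*?))", re.U).sub(u"", p)    # note_pattern1: inline notes
p = re.compile(u"[《》「」\[\]]", re.U).sub(u"", p)  # note_pattern7: brackets
# p is now: 白日依山盡。黃河入海流。
```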

generate_chinese_poetry/train.py

Lines changed: 9 additions & 7 deletions

```diff
@@ -44,7 +44,7 @@ def load_initial_model(model_path, parameters):
 @click.option(
     "--decoder_depth",
     default=3,
-    help="The number of stacked LSTM layers in encoder.")
+    help="The number of stacked LSTM layers in decoder.")
 @click.option(
     "--train_data_path", required=True, help="The path of training data.")
 @click.option(
@@ -75,10 +75,9 @@ def train(num_passes,
     paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)

     # define optimization method and the trainer instance
-    optimizer = paddle.optimizer.AdaDelta(
-        learning_rate=1e-3,
-        gradient_clipping_threshold=25.0,
-        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
+    optimizer = paddle.optimizer.Adam(
+        learning_rate=1e-4,
+        regularization=paddle.optimizer.L2Regularization(rate=1e-5),
         model_average=paddle.optimizer.ModelAverage(
             average_window=0.5, max_average_window=2500))

@@ -88,7 +87,10 @@ def train(num_passes,
         encoder_depth=encoder_depth,
         encoder_hidden_dim=512,
         decoder_depth=decoder_depth,
-        decoder_hidden_dim=512)
+        decoder_hidden_dim=512,
+        bos_id=0,
+        eos_id=1,
+        max_length=9)

     parameters = paddle.parameters.create(cost)
     if init_model_path:
@@ -113,7 +115,7 @@ def event_handler(event):
                 (event.pass_id, event.batch_id))
             save_model(trainer, save_path, parameters)

-        if not event.batch_id % 5:
+        if not event.batch_id % 10:
             logger.info("Pass %d, Batch %d, Cost %f, %s" % (
                 event.pass_id, event.batch_id, event.cost, event.metrics))

```