Skip to content

Commit

Permalink
Add select_device and init_from_ckpt arg for Bi-LSTM distillation (Pa…
Browse files Browse the repository at this point in the history
…ddlePaddle#41)

* add select_device and init_from_ckpt arg

* fix distill lstm readme bug

* update paddnlp install version in readme
  • Loading branch information
LiuChiachi authored Feb 27, 2021
1 parent 09ca61e commit 9db730d
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
8 changes: 5 additions & 3 deletions examples/model_compression/distill_lstm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
## 简介
本目录下的实验是将特定任务下BERT模型的知识蒸馏到基于Bi-LSTM的小模型中,主要参考论文[《Distilling Task-Specific Knowledge from BERT into Simple Neural Networks》](https://arxiv.org/abs/1903.12136)实现。

在模型蒸馏中,较大的模型(在本例中是BERT)通常被称为教师模型,较小的模型(在本例中是Bi-LSTM)通常被成为学生模型。知识的蒸馏通常是通过模型学习蒸馏相关的损失函数实现,在本实验中,损失函数是均方误差损失函数,传入函数的两个参数分别是学生模型的输出和教师模型的输出。
在模型蒸馏中,较大的模型(在本例中是BERT)通常被称为教师模型,较小的模型(在本例中是Bi-LSTM)通常被称为学生模型。知识的蒸馏通常是通过模型学习蒸馏相关的损失函数实现,在本实验中,损失函数是均方误差损失函数,传入函数的两个参数分别是学生模型的输出和教师模型的输出。

[论文](https://arxiv.org/abs/1903.12136)的模型蒸馏阶段,作者为了能让教师模型表达出更多的知识供学生模型学习,对训练数据进行了数据增强。作者使用了三种数据增强方式,分别是:

1. Masking,即以一定的概率将原数据中的word token替换成`[MASK]`

2. POS—guided word replacement,即以一定的概率将原数据中的词用与其有相同POS tag的词替换;

3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
3. n-gram sampling,即以一定的概率,从每条数据中采样n-gram,其中n的范围可通过人工设置。

通过数据增强,可以产生更多无标签的训练数据,在训练过程中,学生模型可借助教师模型的“暗知识”,在更大的数据集上进行训练,产生更好的蒸馏效果。需要指出的是,实验只使用了第1和第3种数据增强方式。
在英文数据集任务上,本文使用了Google News语料[预训练的Word Embedding](https://code.google.com/archive/p/word2vec/)初始化小模型的Embedding层。

本实验分为三个训练过程:在特定任务上对BERT的fine-tuning、在特定任务上对基于Bi-LSTM的小模型的训练(用于评价蒸馏效果)、将BERT模型的知识蒸馏到基于Bi-LSTM的小模型上。
Expand All @@ -32,7 +34,7 @@
另外,本项目还依赖paddlenlp,可以使用下面的命令进行安装:

```shell
pip install paddlenlp==2.0.0rc
pip install paddlenlp\>=2.0rc
```

## 数据、预训练模型介绍及获取
Expand Down
12 changes: 12 additions & 0 deletions examples/model_compression/distill_lstm/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ def parse_args():
default='models',
help="Directory to save models .")

parser.add_argument(
"--init_from_ckpt",
type=str,
default=None,
help="The path of layer and optimizer to be loaded.")

parser.add_argument(
"--whole_word_mask",
action="store_true",
Expand Down Expand Up @@ -143,5 +149,11 @@ def parse_args():
help="Random seed for model parameter initialization, data augmentation and so on."
)

parser.add_argument(
"--select_device",
default="gpu",
choices=["gpu", "cpu", "xpu"],
help="Device selected for inference.")

args = parser.parse_args()
return args
6 changes: 6 additions & 0 deletions examples/model_compression/distill_lstm/bert_distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def evaluate(task_name, model, metric, data_loader):


def do_train(agrs):
device = paddle.set_device(args.select_device)
train_data_loader, dev_data_loader = create_distill_loader(
args.task_name,
model_name=args.model_name,
Expand Down Expand Up @@ -105,6 +106,11 @@ def do_train(agrs):

print("Start to distill student model.")

if args.init_from_ckpt:
model.set_state_dict(paddle.load(args.init_from_ckpt + ".pdparams"))
optimizer.set_state_dict(paddle.load(args.init_from_ckpt + ".pdopt"))
print("Loaded checkpoint from %s" % args.init_from_ckpt)

global_step = 0
tic_train = time.time()
for epoch in range(args.max_epoch):
Expand Down
6 changes: 6 additions & 0 deletions examples/model_compression/distill_lstm/small.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def evaluate(task_name, model, loss_fct, metric, data_loader):


def do_train(args):
device = paddle.set_device(args.select_device)
metric_class = TASK_CLASSES[args.task_name][1]
metric = metric_class()
if args.task_name == 'qqp':
Expand Down Expand Up @@ -165,6 +166,11 @@ def do_train(args):
optimizer = paddle.optimizer.Adam(
learning_rate=args.lr, parameters=model.parameters())

if args.init_from_ckpt:
model.set_state_dict(paddle.load(args.init_from_ckpt + ".pdparams"))
optimizer.set_state_dict(paddle.load(args.init_from_ckpt + ".pdopt"))
print("Loaded checkpoint from %s" % args.init_from_ckpt)

global_step = 0
tic_train = time.time()
for epoch in range(args.max_epoch):
Expand Down

0 comments on commit 9db730d

Please sign in to comment.