
Add FasterTokenizer on PPMiniLM #1542

Merged · 19 commits · Jan 11, 2022

Conversation

@LiuChiachi (Contributor) commented on Dec 31, 2021

PR types

New features

PR changes

Models | APIs

Description

Add FasterTokenizer on PPMiniLM

TODO:

  • Update the PPMiniLM class to support different behavior for training and export:
    • Dynamic-graph fine-tuning and pruning train normally, without FasterTokenizer;
    • Fine-tuning and pruning export support exporting a graph that contains the FasterTokenizer op.
  • Support offline quantization of graphs that contain the FasterTokenizer op.

Note: this PR depends on PaddlePaddle/Paddle#38686 and PaddlePaddle/PaddleSlim#964 being merged in order to run correctly.
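The training/export split described in the TODO list boils down to a model flag that selects between id inputs (training) and raw-string inputs that are tokenized inside the graph (export). A minimal pure-Python sketch of that control flow, with hypothetical names; the real model uses Paddle tensors and the in-graph FasterTokenizer op:

```python
class PPMiniLMSketch:
    """Hypothetical stand-in for how PPMiniLM's forward switches paths.

    Not the real model: Paddle tensors and the in-graph faster_tokenizer
    op are replaced with plain Python so the control flow is visible.
    """

    # Class-level default keeps training on the plain-tokenizer path.
    use_faster_tokenizer = False

    def _tokenize(self, text, text_pair=None):
        # Stand-in for the FasterTokenizer op embedded in the exported graph.
        tokens = text.split() + (text_pair.split() if text_pair else [])
        input_ids = list(range(len(tokens)))   # fake ids, one per token
        token_type_ids = [0] * len(tokens)
        return input_ids, token_type_ids

    def forward(self, input_ids, token_type_ids=None):
        if self.use_faster_tokenizer:
            # Export path: the two arguments actually carry raw text and
            # text_pair; the real ids are computed inside the "graph".
            input_ids, token_type_ids = self._tokenize(input_ids, token_type_ids)
        return input_ids, token_type_ids


model = PPMiniLMSketch()
model.forward([101, 7342, 102], [0, 0, 0])   # training: ids pass straight through
model.use_faster_tokenizer = True            # flipped just before export
model.forward("a sentence", "its pair")      # export: raw strings go in
```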

| Model | #Params | #FLOPs | Speedup | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | CLUEWSC2020 | CSL | CLUE Avg |
|-------|---------|--------|---------|-------|-------|---------|-------|-------|-------------|-----|----------|
| BERT<sub>base</sub> | 102.3M | 10.87B | 1.00x | 74.14 | 56.81 | 61.10 | 81.19 | 74.85 | 79.93 | 81.47 | 72.78 |
| TinyBERT<sub>6</sub> | 59.7M | 5.44B | 2.04x | 72.59 | 55.70 | 57.64 | 79.57 | 73.97 | 76.32 | 80.00 | 70.83 |
| UER-py RoBERTa L6-H768 | 59.7M | 5.44B | 2.04x | 69.62 | 66.45 | 59.91 | 76.89 | 71.36 | 71.05 | 82.87 | 71.16 |
| RBT6, Chinese | 59.7M | 5.44B | 2.04x | 73.93 | 56.63 | 59.79 | 79.28 | 73.12 | 77.30 | 80.80 | 71.55 |
| ERNIE-Tiny | 90.7M | 4.83B | 2.30x | 71.55 | 58.34 | 61.41 | 76.81 | 71.46 | 72.04 | 79.13 | 70.11 |
| PP-MiniLM 6L-768H | 59.7M | 5.44B | 2.12x | 74.14 | 57.43 | 61.75 | 81.01 | 76.17 | 86.18 | 79.17 | 73.69 |
| PP-MiniLM (pruned) | 49.1M | 4.08B | 2.60x | 73.91 | 57.44 | 61.64 | 81.10 | 75.59 | 85.86 | 78.53 | 73.44 |
| PP-MiniLM (pruned + quantized) | 49.2M | - | 9.26x | 74.00 | 57.37 | 61.33 | 81.09 | 75.56 | 85.85 | 78.57 | 73.40 |

@tianxin1860 tianxin1860 self-requested a review January 5, 2022 09:23
| PP-MiniLM (pruned) | 49.1M | 4.08B | 2.39x | 73.91 | 57.44 | 61.64 | 81.10 | 75.59 | 85.86 | 78.53 | 73.44 |
| PP-MiniLM (quantized) | 49.2M | - | 5.35x | 74.00 | 57.37 | 61.33 | 81.09 | 75.56 | 85.85 | 78.57 | 73.40 |
| TinyBERT<sub>6</sub> | 59.7M | 5.44B | 2.04x | 72.59 | 55.70 | 57.64 | 79.57 | 73.97 | 76.32 | 80.00 | 70.83 |
| UER-py RoBERTa L6- H768 | 59.7M | 5.44B | 2.04x | 69.62 | 66.45 | 59.91 | 76.89 | 71.36 | 71.05 | 82.87 | 71.16 |
Member: "L6-H768" has an extra space here.

@LiuChiachi (author): thanks for pointing this out; fixed.

@tianxin1860 left a comment:

Do we need to expose FasterTokenizer to users as a switch? Given that FasterTokenizer is significantly faster than the regular tokenizer with no other downsides, can we just provide the optimal FasterTokenizer version directly?

@@ -404,6 +409,7 @@ def print_arguments(args):
print_arguments(args)
if args.do_train:
do_train(args)
export_model(args)
if args.save_inference_model:

@tianxin1860: why do we need two command-line arguments, save_inference_model and save_inference_model_with_tokenizer?

@LiuChiachi (author): save_inference_model_with_tokenizer has been removed; the existing use_faster_tokenizer flag provides that functionality.

@tianxin1860 left a comment:

Leave some comments

pad_to_max_seq_len=False):
self.eval()
self.use_faster_tokenizer = use_faster_tokenizer
self.pad_to_max_seq_len = pad_to_max_seq_len

@tianxin1860: the pad_to_max_seq_len passed in here is never used?

@LiuChiachi (author): it is used in forward during dynamic-to-static conversion.

])
else:

model = paddle.jit.to_static(

@tianxin1860: the base class FasterPretrainedModel already implements a to_static function; can we just call the base-class function directly here?

@LiuChiachi (author): added the use_faster_tokenizer and pad_to_max_seq_len parameters, plus export for models that take a text pair as input. @ZeyuChen @guoshengCS could you help take a look?

@LiuChiachi (author): the pad_to_max_len parameter has been removed, since it is not necessary.

if is_text_pair:
model = paddle.jit.to_static(
self,
input_spec=[

@tianxin1860: why is the input shape [None] for the faster version but [None, None] for the non-Faster version?

@tianxin1860: is the to_static logic better placed in FasterPretrainedModel, or exposed to users of FasterTokenizer? @Steffy-zxf @wawltor @LiuChiachi

@LiuChiachi (author): when dtype=core.VarDesc.VarType.STRINGS, the shape argument has no effect.

}
}
base_model_prefix = "ppminilm"
use_faster_tokenizer = False

@tianxin1860: this variable doesn't seem to do anything here?

@LiuChiachi (author): the default is False, so training goes through the logic without use_faster_tokenizer.
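The class-level `use_faster_tokenizer = False` default means every instance starts on the plain-tokenizer path, and export code shadows it on its own instance only. A small sketch of that Python attribute behavior (hypothetical class name, not the real model):

```python
class PPMiniLMLike:
    # Class-level default: instances fall back to this until they set
    # their own value, so training never enters FasterTokenizer.
    use_faster_tokenizer = False


train_model = PPMiniLMLike()
export_model = PPMiniLMLike()

# Export flips the flag on one instance, shadowing the class attribute.
export_model.use_faster_tokenizer = True

assert train_model.use_faster_tokenizer is False
assert export_model.use_faster_tokenizer is True
assert PPMiniLMLike.use_faster_tokenizer is False  # class default untouched
```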

Comment on lines +384 to +388
input_ids, token_type_ids = self.tokenizer(
text=input_ids,
text_pair=token_type_ids,
max_seq_len=self.max_seq_len,
pad_to_max_seq_len=self.pad_to_max_seq_len)

@tianxin1860: the tokenizer here is FasterTokenizer, right? Why are its inputs named input_ids and token_type_ids?

@LiuChiachi (author): they share the same parameter list; when use_faster_tokenizer == True, the real input_ids and token_type_ids are computed here.

@tianxin1860: input_ids and token_type_ids actually carry text and text_pair here; why not use semantically consistent variable names?

@LiuChiachi (author): the docstring has been updated.

@tianxin1860 left a comment:

Leave a comment

@@ -34,6 +34,49 @@
}


def get_example_for_faster_tokenizer(example,

@tianxin1860: is the only difference from the convert_example function that it skips the tokenize call?

@LiuChiachi (author): yes, the difference is small; I have merged them into one function.

@LiuChiachi (author): fixed, thanks for pointing it out :)

@tianxin1860 left a comment:

Leave some comments

net = paddle.jit.to_static(model, input_spec=input_shape)
paddle.jit.save(net, model_path)


def do_train(args):

@tianxin1860: rename the function do_train -> export_model?

@LiuChiachi (author): thanks for pointing this out. export_model duplicates the script name, so to keep the format consistent it is now do_export.

input_dtypes=core.VarDesc.VarType.STRINGS,
origin_model=origin_model)
else:
ofa_model.model.use_faster_tokenizer = args.use_faster_tokenizer

@tianxin1860:

  1. Why is ofa_model.model.use_faster_tokenizer = args.use_faster_tokenizer needed here?
  2. I recall the default of ofa_model.model.use_faster_tokenizer was set to False precisely so that training does not enter FasterTokenizer.

@tianxin1860 left a comment:

Looking at the whole PR, supporting both the FasterTokenizer and the regular tokenizer in fine-tuning, pruning, and quantization makes the logic somewhat complex and introduces semantic inconsistencies. I suggest supporting only the FasterTokenizer version; there is no need to support the regular tokenizer.

if 'sentence' in data:
batch_data.append(data['sentence'])
if len(batch_data) == batch_size:
yield {"input_ids": batch_data}

@tianxin1860: what is passed in here is text data; does the variable name have to be fixed as "input_ids"?

@LiuChiachi (author): thanks for pointing this out; the input name used when exporting with text inputs has been changed.

Comment on lines 139 to 140
"input_ids": batch_data[0],
"token_type_ids": batch_data[1]

@tianxin1860: the data for both input_ids and token_type_ids is text; must these two variable names be used?

@LiuChiachi (author): thanks for pointing this out; the input names for text export are now "text" and "text_pair".
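With the renamed inputs, a caller feeding the two exported graph variants would build feeds along these lines. This is a sketch with a stand-in tokenizer and hypothetical helper name; the real code feeds Paddle Inference input handles:

```python
def build_feed(batch, with_faster_tokenizer):
    """Build the feed dict for an exported PP-MiniLM-style graph.

    batch: list of (sentence1, sentence2) string pairs.
    """
    if with_faster_tokenizer:
        # String-input graph: tokenization happens inside the graph,
        # so the feed names refer to raw text.
        return {
            "text": [s1 for s1, _ in batch],
            "text_pair": [s2 for _, s2 in batch],
        }
    # Id-input graph: the caller tokenizes first (toy tokenizer here,
    # mapping each whitespace token to a fake id).
    tokenize = lambda s: [len(w) for w in s.split()]
    input_ids = [tokenize(s1) + tokenize(s2) for s1, s2 in batch]
    token_type_ids = [[0] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "token_type_ids": token_type_ids}


batch = [("battery life is great", "very durable")]
string_feed = build_feed(batch, with_faster_tokenizer=True)
id_feed = build_feed(batch, with_faster_tokenizer=False)
```

The point of the rename is visible in the keys: the string-input graph receives `text`/`text_pair`, while only the id-input graph keeps `input_ids`/`token_type_ids`.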

@tianxin1860 left a comment:

LGTM

@LiuChiachi LiuChiachi merged commit 4ac2811 into PaddlePaddle:develop Jan 11, 2022