Add FasterTokenizer on PPMiniLM #1542

Merged · 19 commits · merged Jan 11, 2022
Changes from 1 commit
Commit: supports pruning
LiuChiachi committed Dec 31, 2021
commit 51af8689be6c5a733caeca6002dcff0cf7b2eec7
1 change: 1 addition & 0 deletions examples/model_compression/pp-minilm/pruning/export.sh

@@ -15,6 +15,7 @@
 MODEL_PATH=$1
 TASK_NAME=$2
 python export_model.py --model_type ppminilm \
+                       --task_name ${TASK_NAME} \
                        --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \
                        --sub_model_output_dir ${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \
                        --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \
87 changes: 64 additions & 23 deletions examples/model_compression/pp-minilm/pruning/export_model.py

@@ -20,21 +20,25 @@
 import random
 import time
 import json
+import distutils.util
 from functools import partial

 import numpy as np
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
+import paddle.fluid.core as core

 from paddlenlp.transformers import PPMiniLMModel
 from paddlenlp.utils.log import logger
+from paddlenlp.experimental import FasterTokenizer

 from paddleslim.nas.ofa import OFA, utils
 from paddleslim.nas.ofa.convert_super import Convert, supernet
 from paddleslim.nas.ofa.layers import BaseBlock

 sys.path.append("../")
-from data import MODEL_CLASSES
+from data import MODEL_CLASSES, METRIC_CLASSES


 def ppminilm_forward(self,
@@ -44,14 +48,23 @@ def ppminilm_forward(self,
                      attention_mask=None):
     wtype = self.pooler.dense.fn.weight.dtype if hasattr(
         self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
+    if self.use_faster_tokenizer:
+        input_ids, token_type_ids = self.tokenizer(
+            text=input_ids,
+            text_pair=token_type_ids,
+            max_seq_len=self.max_seq_len)
     if attention_mask is None:
         attention_mask = paddle.unsqueeze(
             (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2])
-    embedding_output = self.embeddings(input_ids, token_type_ids, position_ids)
-    encoded_layer = self.encoder(embedding_output, attention_mask)
-    pooled_output = self.pooler(encoded_layer)
+    embedding_output = self.embeddings(
+        input_ids=input_ids,
+        position_ids=position_ids,
+        token_type_ids=token_type_ids)

-    return encoded_layer, pooled_output
+    encoder_outputs = self.encoder(embedding_output, attention_mask)
+    sequence_output = encoder_outputs
+    pooled_output = self.pooler(sequence_output)
+    return sequence_output, pooled_output


 PPMiniLMModel.forward = ppminilm_forward
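The tokenization branch patched in above simply runs PaddleNLP's in-graph FasterTokenizer op. As a point of reference, here is a minimal dygraph sketch of the call the patched forward makes; the checkpoint name and vocab source are illustrative assumptions, and the constructor arguments mirror add_faster_tokenizer_op in modeling.py further down:

```python
from paddlenlp.experimental import FasterTokenizer
from paddlenlp.transformers import PPMiniLMTokenizer

# Checkpoint name assumed for illustration; any PPMiniLM vocab works the same way.
vocab = PPMiniLMTokenizer.from_pretrained("ppminilm-6l-768h").vocab
tokenizer = FasterTokenizer(vocab, do_lower_case=True, is_split_into_words=False)

# Same call shape as the patched forward: raw strings in, id tensors out.
input_ids, token_type_ids = tokenizer(
    text=["今天天气不错"],
    text_pair=["适合出门散步"],
    max_seq_len=128)
print(input_ids.shape, token_type_ids.shape)
```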
@@ -79,6 +92,13 @@ def parse_args():
         list(classes[-1].pretrained_init_configuration.keys())
         for classes in MODEL_CLASSES.values()
     ], [])), )
+    parser.add_argument(
+        "--task_name",
+        default=None,
+        type=str,
+        required=True,
+        help="The name of the task to train selected in the list: " +
+        ", ".join(METRIC_CLASSES.keys()), )
     parser.add_argument(
         "--sub_model_output_dir",
         default=None,

@@ -98,6 +118,11 @@ def parse_args():
         type=int,
         help="The maximum total input sequence length after tokenization. Sequences longer "
         "than this will be truncated, sequences shorter will be padded.", )
+    parser.add_argument(
+        "--save_inference_model_with_tokenizer",
+        type=distutils.util.strtobool,
+        default=True,
+        help="Whether to save inference model with tokenizer.")
     parser.add_argument(
         "--n_gpu",
         type=int,

@@ -117,20 +142,10 @@ def parse_args():
     return args


-def export_static_model(model, model_path, max_seq_length):
-    input_shape = [
-        paddle.static.InputSpec(
-            shape=[None, max_seq_length], dtype='int64'),
-        paddle.static.InputSpec(
-            shape=[None, max_seq_length], dtype='int64')
-    ]
-    net = paddle.jit.to_static(model, input_spec=input_shape)
-    paddle.jit.save(net, model_path)
-
-
 def do_train(args):

[Review comment] Function naming: rename do_train to export_model?

[LiuChiachi (Contributor, Author)] Thanks for pointing this out. export_model would clash with the script name, so to keep the naming format consistent it has been changed to do_export.

paddle.set_device("gpu" if args.n_gpu else "cpu")
args.model_type = args.model_type.lower()
args.task_name = args.task_name.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config_path = os.path.join(args.model_name_or_path, 'model_config.json')
cfg_dict = dict(json.loads(open(config_path).read()))
Expand All @@ -151,6 +166,7 @@ def do_train(args):

model = model_class.from_pretrained(
args.model_name_or_path, num_classes=num_labels)
model.use_faster_tokenizer = True

origin_model = model_class.from_pretrained(
args.model_name_or_path, num_classes=num_labels)
@@ -183,11 +199,33 @@ def do_train(args):
         if isinstance(sublayer, paddle.nn.MultiHeadAttention):
             sublayer.num_heads = int(args.width_mult * sublayer.num_heads)

-    origin_model_new = ofa_model.export(
-        best_config,
-        input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
-        input_dtypes=['int64', 'int64'],
-        origin_model=origin_model)
+    is_text_pair = True
+    if args.task_name in ('tnews', 'iflytek', 'cluewsc2020'):
+        is_text_pair = False
+
+    if args.save_inference_model_with_tokenizer:
+        ofa_model.model.add_faster_tokenizer_op()
+        if is_text_pair:
+            origin_model_new = ofa_model.export(
+                best_config,
+                input_shapes=[[1], [1]],
+                input_dtypes=[
+                    core.VarDesc.VarType.STRINGS, core.VarDesc.VarType.STRINGS
+                ],
+                origin_model=origin_model)
+        else:
+            origin_model_new = ofa_model.export(
+                best_config,
+                input_shapes=[1],
+                input_dtypes=core.VarDesc.VarType.STRINGS,
+                origin_model=origin_model)
+    else:
+        origin_model_new = ofa_model.export(
+            best_config,
+            input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]],
+            input_dtypes=['int64', 'int64'],
+            origin_model=origin_model)

     for name, sublayer in origin_model_new.named_sublayers():
         if isinstance(sublayer, paddle.nn.MultiHeadAttention):
             sublayer.num_heads = int(args.width_mult * sublayer.num_heads)
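The task-name check above is the only place where single-sentence versus sentence-pair handling is decided: the three hard-coded names are the single-sentence CLUE tasks in this example, and everything else feeds a text pair. A tiny self-contained restatement (the helper name is mine, not the diff's):

```python
# Single-sentence CLUE tasks hard-coded in the diff; all other tasks in this
# example pass a text pair to the model.
SINGLE_TEXT_TASKS = {"tnews", "iflytek", "cluewsc2020"}

def needs_text_pair(task_name: str) -> bool:
    """Mirrors the is_text_pair flag computed in do_train above."""
    return task_name.lower() not in SINGLE_TEXT_TASKS

assert needs_text_pair("afqmc")        # question-pair matching
assert not needs_text_pair("tnews")    # single-sentence news classification
```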
@@ -200,8 +238,11 @@ def do_train(args):
         model_to_save.save_pretrained(output_dir)

     if args.static_sub_model != None:
-        export_static_model(origin_model_new, args.static_sub_model,
-                            args.max_seq_length)
+        origin_model_new.use_faster_tokenizer = True
+        origin_model_new.to_static(
+            args.static_sub_model,
+            use_faster_tokenizer=args.save_inference_model_with_tokenizer,
+            is_text_pair=is_text_pair)


 def print_arguments(args):
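Condensing this file, the with-tokenizer export path boils down to roughly the following sketch. It deliberately skips the OFA width-pruning steps and uses a placeholder checkpoint path, task, and label count; MODEL_CLASSES comes from the example's data.py:

```python
import sys
sys.path.append("../")
from data import MODEL_CLASSES  # maps "ppminilm" to (model class, tokenizer class)

model_class, tokenizer_class = MODEL_CLASSES["ppminilm"]

# Placeholder checkpoint directory and label count for illustration.
model = model_class.from_pretrained("afqmc/0.75/best_model", num_classes=2)
model.use_faster_tokenizer = True
model.add_faster_tokenizer_op()  # bake the FasterTokenizer op into the model

# afqmc is a sentence-pair task, so the exported graph takes two string inputs.
model.to_static(
    "afqmc/0.75/sub_static/float",
    use_faster_tokenizer=True,
    is_text_pair=True)
```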
5 changes: 5 additions & 0 deletions examples/model_compression/pp-minilm/pruning/prune.py

@@ -199,6 +199,11 @@ def ppminilm_forward(self,
                      token_type_ids=None,
                      position_ids=None,
                      attention_mask=[None, None]):
+    if self.use_faster_tokenizer:
+        input_ids, token_type_ids = self.tokenizer(
+            text=input_ids,
+            text_pair=token_type_ids,
+            max_seq_len=self.max_seq_len)
     wtype = self.pooler.dense.fn.weight.dtype if hasattr(
         self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype
     if attention_mask[0] is None:
20 changes: 10 additions & 10 deletions paddlenlp/transformers/ppminilm/modeling.py

@@ -153,6 +153,12 @@ def init_weights(self, layer):
         elif isinstance(layer, nn.LayerNorm):
             layer._epsilon = 1e-12

+    def add_faster_tokenizer_op(self):
+        self.ppminilm.tokenizer = FasterTokenizer(
+            self.ppminilm.vocab,
+            do_lower_case=self.ppminilm.do_lower_case,
+            is_split_into_words=self.ppminilm.is_split_into_words)
+
     def to_static(self,
                   output_path,
                   use_faster_tokenizer=True,
@@ -161,25 +167,23 @@ def to_static(self,
         self.use_faster_tokenizer = use_faster_tokenizer
         # Convert to static graph with specific input description
         if self.use_faster_tokenizer:
+            self.add_faster_tokenizer_op()
             if is_text_pair:
                 model = paddle.jit.to_static(
                     self,
                     input_spec=[

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么 faster 版本的输入 shape 是 [None], 非 Faster 版本的 shape 是 [None, None]?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的 to_static 逻辑是放在 FasterPretrainedModel 里还是暴露给 FasterTokenizer 的用户比较合适?@Steffy-zxf
@wawltor @LiuChiachi

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype=core.VarDesc.VarType.STRINGS时shape参数无效

                         paddle.static.InputSpec(
-                            shape=[None, None],
-                            dtype=core.VarDesc.VarType.STRINGS),
+                            shape=[None], dtype=core.VarDesc.VarType.STRINGS),
                         paddle.static.InputSpec(
-                            shape=[None, None],
-                            dtype=core.VarDesc.VarType.STRINGS)
+                            shape=[None], dtype=core.VarDesc.VarType.STRINGS)
                     ])
             else:
                 model = paddle.jit.to_static(

[Review comment] The base class FasterPretrainedModel already implements to_static; could this simply call the base-class method?

[LiuChiachi (Contributor, Author)] This version adds the two parameters use_faster_tokenizer and pad_to_max_seq_len, and adds export support for models whose input is a text pair. @ZeyuChen @guoshengCS could you help take a look?

[LiuChiachi (Contributor, Author)] The pad_to_max_seq_len parameter has since been removed, because it is not necessary.

                     self,
                     input_spec=[
                         paddle.static.InputSpec(
-                            shape=[None, None],
-                            dtype=core.VarDesc.VarType.STRINGS)
+                            shape=[None], dtype=core.VarDesc.VarType.STRINGS)
                     ])
         else:
             model = paddle.jit.to_static(

@@ -283,10 +287,6 @@ def __init__(self,
         self.is_split_into_words = is_split_into_words
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
-        self.tokenizer = FasterTokenizer(
-            self.vocab,
-            do_lower_case=self.do_lower_case,
-            is_split_into_words=self.is_split_into_words)
         weight_attr = paddle.ParamAttr(
             initializer=nn.initializer.TruncatedNormal(
                 mean=0.0, std=self.initializer_range))
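Once exported this way, the static model consumes raw strings at serving time. Below is a hedged sketch of driving it with Paddle Inference, assuming copy_from_cpu accepts a list of Python strings for STRINGS inputs (as in PaddleNLP's FasterTokenizer inference examples); the file paths are placeholders matching the export.sh layout:

```python
import numpy as np
from paddle import inference

# Placeholder paths to the exported static graph and weights.
config = inference.Config("sub_static/float.pdmodel", "sub_static/float.pdiparams")
predictor = inference.create_predictor(config)

input_names = predictor.get_input_names()
handles = [predictor.get_input_handle(name) for name in input_names]

# For a text-pair task the graph exposes two string inputs; feed raw text and
# let the in-graph FasterTokenizer handle tokenization.
handles[0].copy_from_cpu(["这家餐厅的味道不错"])
handles[1].copy_from_cpu(["服务也很好"])

predictor.run()
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])
logits = output_handle.copy_to_cpu()
print(np.argmax(logits, axis=-1))
```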