Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FasterTokenizer on PPMiniLM #1542

Merged
merged 19 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update modeling
  • Loading branch information
LiuChiachi committed Jan 10, 2022
commit 1a4329d6500a718d42becc7260afaa47242923c2
57 changes: 9 additions & 48 deletions examples/model_compression/pp-minilm/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@
}


def get_example_for_faster_tokenizer(example,
label_list,
is_test=False,
**kwargs):
def convert_example(example,
label_list,
tokenizer=None,
is_test=False,
max_seq_length=512,
**kwargs):
"""convert a glue example into necessary features"""
if not is_test:
# `label_list == None` is for regression task
Expand Down Expand Up @@ -73,57 +75,16 @@ def get_example_for_faster_tokenizer(example,
text_list.insert(query_idx + len(query) + 2 + 1, "_")
text = "".join(text_list)
example['sentence'] = text

return example


def convert_example(example,
tokenizer,
label_list,
max_seq_length=512,
is_test=False):
"""convert a glue example into necessary features"""
if not is_test:
# `label_list == None` is for regression task
label_dtype = "int64" if label_list else "float32"
# Get the label
label = example['label']
label = np.array([label], dtype=label_dtype)
# Convert raw text to feature
if tokenizer is None:
return example
if 'sentence' in example:
example = tokenizer(example['sentence'], max_seq_len=max_seq_length)
elif 'sentence1' in example:
example = tokenizer(
example['sentence1'],
text_pair=example['sentence2'],
max_seq_len=max_seq_length)
elif 'keyword' in example: # CSL
sentence1 = " ".join(example['keyword'])
example = tokenizer(
sentence1, text_pair=example['abst'], max_seq_len=max_seq_length)
elif 'target' in example: # wsc
text, query, pronoun, query_idx, pronoun_idx = example['text'], example[
'target']['span1_text'], example['target']['span2_text'], example[
'target']['span1_index'], example['target']['span2_index']
text_list = list(text)
assert text[pronoun_idx:(pronoun_idx + len(pronoun)
)] == pronoun, "pronoun: {}".format(pronoun)
assert text[query_idx:(query_idx + len(query)
)] == query, "query: {}".format(query)
if pronoun_idx > query_idx:
text_list.insert(query_idx, "_")
text_list.insert(query_idx + len(query) + 1, "_")
text_list.insert(pronoun_idx + 2, "[")
text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]")
else:
text_list.insert(pronoun_idx, "[")
text_list.insert(pronoun_idx + len(pronoun) + 1, "]")
text_list.insert(query_idx + 2, "_")
text_list.insert(query_idx + len(query) + 2 + 1, "_")
text = "".join(text_list)
example = tokenizer(text, max_seq_len=max_seq_length)

if not is_test:
return example['input_ids'], example['token_type_ids'], label
return example['input_ids'], example['token_type_ids'], example['label']
else:
return example['input_ids'], example['token_type_ids']
4 changes: 2 additions & 2 deletions examples/model_compression/pp-minilm/finetuning/run_clue.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def do_eval(args):
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=dev_ds.label_list,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length)

dev_ds = dev_ds.map(trans_func, lazy=True)
Expand Down Expand Up @@ -274,8 +274,8 @@ def do_train(args):

trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length)

train_ds = train_ds.map(trans_func, lazy=True)
Expand Down
8 changes: 3 additions & 5 deletions examples/model_compression/pp-minilm/inference/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from paddlenlp.data import Stack, Tuple, Pad

sys.path.append("../")
from data import convert_example, METRIC_CLASSES, MODEL_CLASSES, get_example_for_faster_tokenizer
from data import convert_example, METRIC_CLASSES, MODEL_CLASSES


def parse_args():
Expand Down Expand Up @@ -258,8 +258,8 @@ def convert_predict_batch(self, args, data, tokenizer, batchify_fn,
for example in data:
example = convert_example(
example,
label_list,
tokenizer,
label_list=label_list,
max_seq_length=args.max_seq_length)
examples.append(example)

Expand Down Expand Up @@ -321,9 +321,7 @@ def main():
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
else:
trans_func = partial(
get_example_for_faster_tokenizer,
label_list=dev_ds.label_list,
is_test=False)
convert_example, label_list=dev_ds.label_list, is_test=False)
dev_ds = dev_ds.map(trans_func, lazy=True)
if not args.use_faster_tokenizer:
batchify_fn = lambda samples, fn=Tuple(
Expand Down
2 changes: 1 addition & 1 deletion examples/model_compression/pp-minilm/pruning/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ def do_train(args):

trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=train_ds.label_list,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length)
train_ds = train_ds.map(trans_func, lazy=True)
train_batch_sampler = paddle.io.DistributedBatchSampler(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from paddlenlp.transformers import PPMiniLMTokenizer

sys.path.append("../")
from data import convert_example, METRIC_CLASSES, MODEL_CLASSES, get_example_for_faster_tokenizer
from data import convert_example, METRIC_CLASSES, MODEL_CLASSES

parser = argparse.ArgumentParser()

Expand Down Expand Up @@ -98,16 +98,13 @@ def quant_post(args, batch_size=8, algo='avg'):

dev_ds = load_dataset("clue", args.task_name, splits="dev")
if args.use_faster_tokenizer:
trans_func = partial(
get_example_for_faster_tokenizer,
label_list=dev_ds.label_list,
max_seq_len=args.max_seq_length)
trans_func = partial(convert_example, label_list=dev_ds.label_list)
else:
tokenizer = PPMiniLMTokenizer.from_pretrained("ppminilm-6l-768h")
trans_func = partial(
convert_example,
tokenizer=tokenizer,
label_list=dev_ds.label_list,
tokenizer=tokenizer,
max_seq_length=128,
is_test=True)
dev_ds = dev_ds.map(trans_func, lazy=True)
Expand Down
24 changes: 15 additions & 9 deletions paddlenlp/transformers/ppminilm/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,9 @@ def add_faster_tokenizer_op(self):
def to_static(self,
output_path,
use_faster_tokenizer=True,
is_text_pair=False,
pad_to_max_seq_len=False):
is_text_pair=False):
self.eval()
self.use_faster_tokenizer = use_faster_tokenizer
self.pad_to_max_seq_len = pad_to_max_seq_len
# Convert to static graph with specific input description
if self.use_faster_tokenizer:
self.add_faster_tokenizer_op()
Expand All @@ -180,7 +178,6 @@ def to_static(self,
shape=[None], dtype=core.VarDesc.VarType.STRINGS)
])
else:

model = paddle.jit.to_static(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

基类 FasterPretrainedModel 已经实现了 to_static 函数,此处是否直接调用基类函数即可?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加入了use_faster_tokenizer以及pad_to_max_seq_len2个参数,以及增加了对text pair 为输入的模型的导出,求 @ZeyuChen @guoshengCS 能帮忙看一下~

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经删除了pad_to_max_len参数,因为不是必要的

self,
input_spec=[
Expand Down Expand Up @@ -318,11 +315,16 @@ def forward(self,
attention_mask=None):
r"""
Args:
input_ids (Tensor):
Indices of input sequence tokens in the vocabulary. They are
numerical representations of tokens that build the input sequence.
It's data type should be `int64` and has a shape of [batch_size, sequence_length].
token_type_ids (Tensor, optional):
input_ids (Tensor, List[string]):
If `input_ids` is a Tensor object, it is an indices of input
sequence tokens in the vocabulary. They are numerical
representations of tokens that build the input sequence. It's
data type should be `int64` and has a shape of [batch_size, sequence_length].
If `input_ids` is a list of string, `self.use_faster_tokenizer`
should be True, and the network contains `faster_tokenizer`
operator.
token_type_ids (Tensor, string, optional):
If `token_type_ids` is a Tensor object:
Segment token indices to indicate different portions of the inputs.
Selected in the range ``[0, type_vocab_size - 1]``.
If `type_vocab_size` is 2, which means the inputs have two portions.
Expand All @@ -333,6 +335,10 @@ def forward(self,

Its data type should be `int64` and it has a shape of [batch_size, sequence_length].
Defaults to `None`, which means we don't add segment embeddings.

If `token_type_ids` is a list of string: `self.use_faster_tokenizer`
should be True, and the network contains `faster_tokenizer` operator.

position_ids (Tensor, optional):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
max_position_embeddings - 1]``.
Expand Down