Commit a4f8019

Merge pull request PaddlePaddle#49 from LiuChiachi/fix-lstm-distill-sst2-bug
Update usage of tokenizer in glue and distill lstm
2 parents f4d4eac + 920688f commit a4f8019

4 files changed (+9, -23 lines)

examples/glue/run_glue.py

Lines changed: 4 additions & 18 deletions

@@ -264,26 +264,12 @@ def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1):
     else:
         example = tokenizer(
             example[0], text_pair=example[1], max_seq_len=max_seq_length)
-    '''
-    tokens_raw = [tokenizer(l) for l in example]
-    # Truncate to the truncate_length,
-    tokens_trun = _truncate_seqs(tokens_raw, max_seq_length)
-    # Concate the sequences with special tokens
-    tokens_trun[0] = [tokenizer.cls_token] + tokens_trun[0]
-    tokens, segment_ids, _ = _concat_seqs(tokens_trun, [[tokenizer.sep_token]] *
-                                          len(tokens_trun))
-    # Convert the token to ids
-    input_ids = tokenizer.convert_tokens_to_ids(tokens)
-    valid_length = len(input_ids)
-    # The mask has 1 for real tokens and 0 for padding tokens. Only real
-    # tokens are attended to.
-    # input_mask = [1] * len(input_ids)
-    '''
+
     if not is_test:
-        return example['input_ids'], example['segment_ids'], len(example[
+        return example['input_ids'], example['token_type_ids'], len(example[
             'input_ids']), label
     else:
-        return example['input_ids'], example['segment_ids'], len(example[
+        return example['input_ids'], example['token_type_ids'], len(example[
             'input_ids'])

@@ -312,7 +298,7 @@ def do_train(args):
         train_dataset, batch_size=args.batch_size, shuffle=True)
     batchify_fn = lambda samples, fn=Tuple(
         Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
-        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # segment
+        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
         Stack(),  # length
         Stack(dtype="int64" if train_dataset.get_labels() else "float32")  # label
     ): [data for i, data in enumerate(fn(samples)) if i != 2]
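Note: the substance of this hunk is a key rename. Newer PaddleNLP tokenizers return the sentence-pair segment ids under 'token_type_ids' rather than 'segment_ids', and expose the matching padding value as tokenizer.pad_token_type_id. A minimal sketch of the updated flow, assuming a bert-base-uncased checkpoint and made-up input sentences:

from paddlenlp.transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Encode a sentence pair, as convert_example does for two-sentence tasks.
example = tokenizer(
    'The quick brown fox.',
    text_pair='It jumped over the lazy dog.',
    max_seq_len=128)

# The returned dict now carries 'token_type_ids' (formerly 'segment_ids').
input_ids = example['input_ids']
token_type_ids = example['token_type_ids']
valid_length = len(input_ids)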

examples/model_compression/distill_lstm/README.md

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ wget https://paddlenlp.bj.bcebos.com/data/senta_word_dict.txt
 cd ../../glue
 export CUDA_VISIBLE_DEVICES=0
 export TASK_NAME=SST-2
-python -u ./run_bert_finetune.py \
+python -u ./run_glue.py \
     --model_type bert \
     --model_name_or_path bert-base-uncased \
     --task_name $TASK_NAME \

examples/model_compression/distill_lstm/data.py

Lines changed: 2 additions & 2 deletions

@@ -282,7 +282,7 @@ def create_distill_loader(task_name,
     if task_name == 'qqp':
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=tokenizer.pad_token_id),  # bert input
-            Pad(axis=0, pad_val=tokenizer.pad_token_id),  # bert segment
+            Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # bert segment
             Pad(axis=0, pad_val=pad_val),  # small input_ids
             Stack(dtype="int64"),  # small seq len
             Pad(axis=0, pad_val=pad_val),  # small input_ids
@@ -292,7 +292,7 @@ def create_distill_loader(task_name,
     else:
         batchify_fn = lambda samples, fn=Tuple(
             Pad(axis=0, pad_val=tokenizer.pad_token_id),  # bert input
-            Pad(axis=0, pad_val=tokenizer.pad_token_id),  # bert segment
+            Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # bert segment
             Pad(axis=0, pad_val=pad_val),  # small input_ids
             Stack(dtype="int64"),  # small seq len
             Stack(dtype="int64")  # small label
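For context on the batchify_fn being patched here: paddlenlp.data.Tuple composes one collate function per field, so each Pad/Stack is applied to the corresponding position of every sample in the batch. A small illustrative sketch (pad values and samples are made up; the field layout mirrors the non-QQP branch above):

from paddlenlp.data import Pad, Stack, Tuple

# Hypothetical pad values standing in for tokenizer.pad_token_id,
# tokenizer.pad_token_type_id and the small model's pad_val.
bert_pad_id, bert_pad_type_id, small_pad_val = 0, 0, 1

batchify_fn = Tuple(
    Pad(axis=0, pad_val=bert_pad_id),       # bert input_ids
    Pad(axis=0, pad_val=bert_pad_type_id),  # bert token_type_ids
    Pad(axis=0, pad_val=small_pad_val),     # small input_ids
    Stack(dtype="int64"),                   # small seq len
    Stack(dtype="int64"),                   # small label
)

# Each sample: (bert_ids, bert_type_ids, small_ids, small_len, label).
samples = [
    ([101, 2023, 102], [0, 0, 0], [5, 7], 2, 1),
    ([101, 102], [0, 0], [9], 1, 0),
]
bert_ids, type_ids, small_ids, lens, labels = batchify_fn(samples)
# Shorter rows are padded to the batch maximum, e.g. bert_ids has shape (2, 3).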

examples/model_compression/distill_lstm/utils.py

Lines changed: 2 additions & 2 deletions

@@ -160,8 +160,8 @@ def _concat_seqs(seqs, separators, seq_mask=0, separator_mask=1):
         is_split_into_words=is_tokenized)

     if not is_test:
-        return example['input_ids'], example['segment_ids'], len(example[
+        return example['input_ids'], example['token_type_ids'], len(example[
             'input_ids']), label

-    return example['input_ids'], example['segment_ids'], len(example[
+    return example['input_ids'], example['token_type_ids'], len(example[
         'input_ids'])
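The same key rename once more. If preprocessing code had to run against both old and new tokenizer versions, a defensive lookup would avoid the KeyError this commit fixes; a hypothetical helper, not part of the patch:

def get_token_type_ids(example):
    # Newer PaddleNLP tokenizers emit 'token_type_ids'; older
    # releases used 'segment_ids'. Fall back for compatibility.
    if 'token_type_ids' in example:
        return example['token_type_ids']
    return example['segment_ids']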
