
[BUG] Tokenizer is_split_into_words option does not behave as expected. #3195

Closed
ZHUI opened this issue Sep 5, 2022 · 4 comments · Fixed by #3204
Labels: bug (Something isn't working), question (Further information is requested)

Comments

ZHUI (Collaborator) commented Sep 5, 2022

Thanks for reporting a PaddleNLP usage issue. Please provide the following information so we can quickly locate and resolve the problem:

  1. PaddleNLP version: develop

The Tokenizer's is_split_into_words option does not behave as expected.

>>> ernie_tokenizer = paddlenlp.transformers.AutoTokenizer.from_pretrained("ernie-2.0-base-zh")
>>> ernie_tokenizer.convert_tokens_to_ids(["241", "##5","人"])
[5494, 9486, 8]
>>> ernie_tokenizer(["241", "##5","人"], is_split_into_words=True)
{'input_ids': [1, 5494, 9474, 9474, 317, 8, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0]}

Even with the is_split_into_words option set, ##5 here is not converted into a single token; instead it is split into #, #, 5.
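
A minimal sketch of what seems to be happening (hypothetical REPL output, but consistent with the ids above): with is_split_into_words=True, each list item is re-tokenized from raw text, so the "##" continuation prefix is read as two literal "#" characters.

>>> ernie_tokenizer.tokenize("##5")
['#', '#', '5']
>>> ernie_tokenizer.convert_tokens_to_ids(['#', '#', '5'])
[9474, 9474, 317]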

ZHUI added the question label Sep 5, 2022
ZHUI (Collaborator, Author) commented Sep 5, 2022

This issue causes the data in examples/benchmark/clue/mrc/run_c3.py to exceed max_seq_len after tokenization.
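
The blow-up is easy to see with the tokens from above; both lengths follow directly from the two outputs in the first comment (the 7 includes [CLS]/[SEP], but the real inflation is "##5" becoming three tokens instead of one):

>>> tokens = ["241", "##5", "人"]
>>> len(ernie_tokenizer.convert_tokens_to_ids(tokens))
3
>>> len(ernie_tokenizer(tokens, is_split_into_words=True)["input_ids"])
7

A sequence truncated to max_seq_len by token count can therefore grow again during encoding and overshoot the limit.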

wj-Mcat (Contributor) commented Sep 5, 2022

Description

The paddlenlp tokenizer is currently aligned with transformers. I did reproduce this problem with paddlenlp, and the result matches huggingface. Test code below:

>>> from transformers.models.bert.tokenization_bert import BertTokenizer
>>> tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
>>> tokenizer(["241", "##5","人"], is_split_into_words=True)
{'input_ids': [101, 22343, 1001, 1001, 1019, 1756, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

That said, this problem is fairly easy to work around:

  • First convert to a string: since the tokens contain subwords, converting to a string merges the subwords back together.
  • Then call tokenizer.__call__ directly.

Example code is shown below:

>>> import paddlenlp
>>> ernie_tokenizer = paddlenlp.transformers.AutoTokenizer.from_pretrained("ernie-2.0-base-zh")
>>> sentence = ernie_tokenizer.convert_tokens_to_string(["241", "##5","人"])
>>> sentence
'2415 人'
>>> ernie_tokenizer(sentence, return_attention_mask=True)
{'input_ids': [1, 5494, 9486, 8, 2], 'token_type_ids': [0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1]}

Summary

The new tokenizer version fixes a number of issues and adds some new features, so there may be a few incompatibilities with the old version.

In fact, some of the patterns in examples could be simplified to reduce conceptual inconsistencies with huggingface transformers, such as Pad, Stack, etc.
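
As a hedged sketch of that simplification (not code from this repo), assuming the aligned tokenizer accepts the same padding= argument as huggingface transformers, batching and padding can be left to the tokenizer itself:

from transformers.models.bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Pad the whole batch in one call instead of collating afterwards
# with paddlenlp.data.Pad / Stack.
batch = tokenizer(["2415 people", "hi"], padding=True)
# Every sequence now has the same length, so no manual collation is needed.
assert len(set(len(ids) for ids in batch["input_ids"])) == 1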

ZHUI (Collaborator, Author) commented Sep 5, 2022

  1. The current behavior of the is_split_into_words parameter does not match what the parameter promises, right? If so, the framework should fix this as soon as possible.
  2. @LiuChiachi Because of this issue, input_ids for the C3 dataset here exceed max_seq_len; please fix it as soon as possible. The code in question:

for choice in choice_list:
    tokens_c = tokenizer.tokenize(choice.lower())
    _truncate_seq_tuple(tokens_t, tokens_q, tokens_c,
                        max_seq_length - 4)
    tokens_c = tokens_q + ["[SEP]"] + tokens_c
    tokens_t_list.append(tokens_t)
    tokens_c_list.append(tokens_c)
new_data = tokenizer(tokens_t_list,
                     text_pair=tokens_c_list,
                     is_split_into_words=True)
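
A hedged sketch of how the string-conversion workaround above could be applied to this snippet (same variable names as the excerpt; untested against run_c3.py, and re-tokenizing the joined strings may shift token boundaries slightly):

tokens_t_text = [tokenizer.convert_tokens_to_string(t) for t in tokens_t_list]
tokens_c_text = [tokenizer.convert_tokens_to_string(c) for c in tokens_c_list]
# No is_split_into_words here: the strings are tokenized from scratch,
# so "##" subwords are merged instead of being split character by character.
new_data = tokenizer(tokens_t_text, text_pair=tokens_c_text)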

wj-Mcat added the bug label Sep 6, 2022
wj-Mcat (Contributor) commented Sep 6, 2022

The offline test code for the PR is shown below:

import paddlenlp
from paddlenlp.transformers import PretrainedTokenizer

def test_bad_case():
    ernie_tokenizer: PretrainedTokenizer = paddlenlp.transformers.AutoTokenizer.from_pretrained("ernie-2.0-base-zh")
    tokens = ["241", "##5", "人"]
    # 'token' mode treats the inputs as ready-made tokens, so "##5" stays one id.
    assert ernie_tokenizer([tokens], is_split_into_words='token')['input_ids'] == [[1, 5494, 9486, 8, 2]]
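
For contrast, a hypothetical before/after comparison (the ids are taken from the original report, and this assumes is_split_into_words=True keeps its pre-PR behavior):

>>> ernie_tokenizer(tokens, is_split_into_words=True)["input_ids"]
[1, 5494, 9474, 9474, 317, 8, 2]
>>> ernie_tokenizer([tokens], is_split_into_words='token')["input_ids"]
[[1, 5494, 9486, 8, 2]]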
