Skip to content

Commit a45819b

Browse files
authored
Fix for issue #147 for transformers>4.17.0 (#148)
Fix for issue #147 for transformers>4.17.0
1 parent cb582ed commit a45819b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bert_score/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from packaging import version
1111
from torch.nn.utils.rnn import pad_sequence
1212
from tqdm.auto import tqdm
13-
from transformers import (AutoModel, AutoTokenizer, BertConfig, GPT2Tokenizer,
13+
from transformers import (AutoModel, AutoTokenizer, BertConfig, GPT2Tokenizer, RobertaTokenizer,
1414
RobertaConfig, XLMConfig, XLNetConfig)
1515
from transformers import __version__ as trans_version
1616

@@ -190,7 +190,7 @@ def sent_encode(tokenizer, sent):
190190
sent = sent.strip()
191191
if sent == "":
192192
return tokenizer.build_inputs_with_special_tokens([])
193-
elif isinstance(tokenizer, GPT2Tokenizer):
193+
elif isinstance(tokenizer, GPT2Tokenizer) or isinstance(tokenizer, RobertaTokenizer):
194194
# for RoBERTa and GPT-2
195195
if version.parse(trans_version) >= version.parse("4.0.0"):
196196
return tokenizer.encode(

0 commit comments

Comments
 (0)