Skip to content

Commit b28b6d4

Browse files
tarekgh and ericstj authored
BpeTokenizer Cleanup (#7514)
* BpeTokenizer Cleanup * Apply suggestions from code review Co-authored-by: Eric StJohn <ericstj@microsoft.com> --------- Co-authored-by: Eric StJohn <ericstj@microsoft.com>
1 parent 37601f3 commit b28b6d4

File tree

2 files changed

+77
-18
lines changed

2 files changed

+77
-18
lines changed

src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ private BpeTokenizer(
320320

321321
if (beginningOfSentenceToken is not null)
322322
{
323-
if (!_vocab.TryGetValue(beginningOfSentenceToken, out int aId))
323+
if (_vocab.TryGetValue(beginningOfSentenceToken, out int aId) is false && specialTokens?.TryGetValue(beginningOfSentenceToken, out aId) is false)
324324
{
325325
throw new InvalidOperationException($"The beginning of sentence token '{beginningOfSentenceToken}' was not present in the vocabulary.");
326326
}
@@ -331,7 +331,7 @@ private BpeTokenizer(
331331

332332
if (endOfSentenceToken is not null)
333333
{
334-
if (!_vocab.TryGetValue(endOfSentenceToken, out int aId))
334+
if (_vocab.TryGetValue(endOfSentenceToken, out int aId) is false && specialTokens?.TryGetValue(endOfSentenceToken, out aId) is false)
335335
{
336336
throw new InvalidOperationException($"The end of sentence token '{endOfSentenceToken}' was not present in the vocabulary.");
337337
}
@@ -792,31 +792,30 @@ public string Decode(IEnumerable<int> ids, bool considerSpecialTokens)
792792

793793
ValueStringBuilder sb = new ValueStringBuilder();
794794

795-
bool decodeUnknownToken = _unknownTokenId.HasValue && considerSpecialTokens;
796-
797-
if (decodeUnknownToken)
795+
foreach (int id in ids)
798796
{
799-
foreach (int id in ids)
797+
if (_specialTokensReverse?.TryGetValue(id, out string? token) is true)
800798
{
801-
if (MapIdToToken(id) is string s)
799+
if (considerSpecialTokens)
802800
{
803-
sb.Append(s);
801+
sb.Append(token);
804802
}
803+
continue;
805804
}
806-
}
807-
else
808-
{
809-
foreach (int id in ids)
805+
806+
if (id == _unknownTokenId)
810807
{
811-
if (id == _unknownTokenId)
808+
if (considerSpecialTokens)
812809
{
813-
continue;
810+
Debug.Assert(UnknownToken is not null);
811+
sb.Append(UnknownToken);
814812
}
813+
continue;
814+
}
815815

816-
if (MapIdToToken(id) is string s)
817-
{
818-
sb.Append(s);
819-
}
816+
if (MapIdToToken(id) is string s)
817+
{
818+
sb.Append(s);
820819
}
821820
}
822821

test/Microsoft.ML.Tokenizers.Tests/BpeTests.cs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,66 @@ public void TestDeepSeekR1Tokenizer(string text, int[] ids, string[] tokens, (in
885885
Assert.Equal(text, tokenizer.Decode(ids, considerSpecialTokens: false));
886886
}
887887

888+
[Fact]
889+
public void TestTokenizerWithSpecialTokens()
890+
{
891+
// "https://huggingface.co/openai-community/gpt2/raw/main/vocab.json";
892+
// "https://huggingface.co/openai-community/gpt2/raw/main/merges.txt";
893+
894+
BpeOptions options = new BpeOptions(Path.Combine(@"Gpt-2", "vocab.json"), Path.Combine(@"Gpt-2", "merges.txt"))
895+
{
896+
UnknownToken = "unk",
897+
898+
SpecialTokens = new Dictionary<string, int> // SpecialTokens not part of the original vocab.json
899+
{
900+
{ "<|sos|>", 50257 },
901+
{ "<|eos|>", 50258 }
902+
},
903+
BeginningOfSentenceToken = "<|sos|>",
904+
EndOfSentenceToken = "<|eos|>"
905+
};
906+
907+
BpeTokenizer bpeTokenizer = BpeTokenizer.Create(options);
908+
Assert.True(bpeTokenizer.Vocabulary.TryGetValue(options.UnknownToken, out int unkId));
909+
910+
string text = "Hello world!\uD800";
911+
912+
var ids = bpeTokenizer.EncodeToIds(text, considerPreTokenization: false);
913+
Assert.Equal([50257, 15496, 2954, 6894, 0, 2954, 50258], ids); // space and u+D800 couldn't be encoded and produced unk tokens
914+
Assert.Equal(unkId, ids[ids.Count - 2]);
915+
Assert.Equal(options.SpecialTokens["<|sos|>"], ids[0]);
916+
Assert.Equal(options.SpecialTokens["<|eos|>"], ids[^1]);
917+
918+
var tokens = bpeTokenizer.EncodeToTokens(text, out _, considerPreTokenization: false).Select(t => t.Value).ToArray();
919+
Assert.Equal(["<|sos|>", "Hello", "unk", "world", "!", "unk", "<|eos|>"], tokens);
920+
921+
Assert.Equal("<|sos|>Hellounkworld!unk<|eos|>", bpeTokenizer.Decode(ids));
922+
Assert.Equal("Helloworld!", bpeTokenizer.Decode(ids, considerSpecialTokens: false));
923+
924+
BpeOptions options1 = new BpeOptions(options.Vocabulary)
925+
{
926+
// Null UnknownToken means no unknown token support
927+
Merges = options.Merges,
928+
SpecialTokens = options.SpecialTokens,
929+
BeginningOfSentenceToken = options.BeginningOfSentenceToken,
930+
EndOfSentenceToken = options.EndOfSentenceToken
931+
};
932+
933+
bpeTokenizer = BpeTokenizer.Create(options1);
934+
ids = bpeTokenizer.EncodeToIds(text, considerPreTokenization: false);
935+
936+
// Because Unknown is not supported in this encoding, the encoding will produce different encoding results
937+
Assert.Equal([50257, 39, 5037, 1764, 0, 50258], ids);
938+
Assert.Equal(options.SpecialTokens["<|sos|>"], ids[0]);
939+
Assert.Equal(options.SpecialTokens["<|eos|>"], ids[^1]);
940+
941+
tokens = bpeTokenizer.EncodeToTokens(text, out _, considerPreTokenization: false).Select(t => t.Value).ToArray();
942+
Assert.Equal(["<|sos|>", "H", "ellow", "orld", "!", "<|eos|>"], tokens);
943+
944+
Assert.Equal("<|sos|>Helloworld!<|eos|>", bpeTokenizer.Decode(ids));
945+
Assert.Equal("Helloworld!", bpeTokenizer.Decode(ids, considerSpecialTokens: false));
946+
}
947+
888948
private static BpeTokenizer CreateBpeTokenizerFromJson()
889949
{
890950
// @"https://huggingface.co/deepseek-ai/DeepSeek-R1/resolve/main/tokenizer.json?download=true"

0 commit comments

Comments
 (0)