Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address the feedback on the tokenizer's library #7024

Merged
merged 23 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6e32f5
Fix cache when calling EncodeToIds
tarekgh Feb 17, 2024
0553922
Make EnglishRoberta _mergeRanks thread safe
tarekgh Feb 17, 2024
a4cb1f5
Delete Trainer
tarekgh Feb 19, 2024
6a13025
Remove the setters on the Bpe properties
tarekgh Feb 19, 2024
3278aff
Remove Roberta and Tiktoken special casing in the Tokenizer and suppo…
tarekgh Feb 19, 2024
b5f7fa2
Support text-embedding-3-small/large embedding
tarekgh Feb 19, 2024
a11f4e0
Remove redundant TokenToId abstraction and keep the one with the extr…
tarekgh Feb 19, 2024
865068a
Enable creating Tiktoken asynchronously or directly using the tokeniz…
tarekgh Feb 20, 2024
4077de0
Add cancellationToken support in CreateAsync APIs
tarekgh Feb 21, 2024
5aaf849
Rename sequence to text and Tokenize to Encode
tarekgh Feb 21, 2024
b5e0927
Rename skipSpecialTokens to considerSpecialTokens
tarekgh Feb 21, 2024
5e26b3e
Rename TokenizerResult to EncodingResult
tarekgh Feb 21, 2024
985de8a
Make Token publicly immutable
tarekgh Feb 21, 2024
b551e7d
Change offset tuples from (Index, End) to (Index, Length)
tarekgh Feb 21, 2024
7ea7f70
Rename NormalizedString method's parameters
tarekgh Feb 21, 2024
b0c8244
Rename Model's methods to start with verb
tarekgh Feb 21, 2024
450418a
Convert Model.GetVocab() method to a Vocab property
tarekgh Feb 21, 2024
6f53de8
Some method's parameters and variable renaming
tarekgh Feb 22, 2024
62334c6
Remove Vocab and VocabSize from the abstraction
tarekgh Feb 22, 2024
d48b32d
Cleanup normalization support
tarekgh Feb 22, 2024
191ab03
Minor Bpe cleanup
tarekgh Feb 22, 2024
b9b0f58
Resolve rebase change
tarekgh Feb 23, 2024
1ad157f
Address the feedback
tarekgh Feb 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,9 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicates whether special tokens should be skipped during decoding.</param>
/// <param name="filterUnsupportedChars">Indicates whether unsupported characters should be filtered out during decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? IdToToken(int id, bool skipSpecialTokens = false)
public override string? IdToToken(int id, bool skipSpecialTokens = false, bool filterUnsupportedChars = true)
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
{
if (VocabReverse.TryGetValue(id, out string? value))
{
Expand Down
31 changes: 16 additions & 15 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -127,29 +127,30 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
/// <summary>
/// Map the tokenized Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? IdToToken(int id, bool skipSpecialTokens = false)
{
    // Ids below zero are treated as special tokens here; suppress them when requested.
    if (skipSpecialTokens && id < 0)
    {
        return null;
    }

    // Look the id up in the reverse vocabulary; unknown ids yield null.
    return _vocabReverse.TryGetValue(id, out var token) ? token : null;
}

/// <summary>
/// Map the tokenized Id to the original string while filtering out unsupported characters.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public string? IdToFilteredToken(int id, bool skipSpecialTokens = false)
public override string? IdToToken(int id, bool skipSpecialTokens = false, bool filterUnsupportedChars = true)
{
if (skipSpecialTokens && id < 0)
{
return null;
}

if (_vocabReverse.TryGetValue(id, out var value))
{
var textChars = string.Join("", value)
.Where(c => _unicodeToByte.ContainsKey(c))
.Select(c => _unicodeToByte[c]);
var text = new string(textChars.ToArray());
return text;
if (filterUnsupportedChars)
{
var textChars = string.Join("", value)
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
.Where(c => _unicodeToByte.ContainsKey(c))
.Select(c => _unicodeToByte[c]);
return new string(textChars.ToArray());
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
}
else
{
return value;
}
}

return null;
Expand Down
23 changes: 22 additions & 1 deletion src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,29 @@ public virtual int CountTokens(string sequence, bool isSpecialToken)
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public abstract string? IdToToken(int id, bool skipSpecialTokens = false);
public abstract string? IdToToken(int id, bool skipSpecialTokens = false, bool filterUnsupportedChars = true);

/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="skipSpecialTokens">Whether the special tokens should be removed from the decoded string.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
/// <returns>The decoded string.</returns>
public virtual string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool skipSpecialTokens = false, bool filterUnsupportedChars = true)
{
    // Match the Tiktoken override, which returns null for a null id sequence,
    // instead of failing with a NullReferenceException in the foreach below.
    if (ids is null)
    {
        return null;
    }

    List<string> tokens = new List<string>();

    foreach (int id in ids)
    {
        // Ids with no mapped token contribute an empty string rather than
        // aborting the whole decode.
        tokens.Add(IdToToken(id, skipSpecialTokens, filterUnsupportedChars) ?? "");
    }

    // Merge via the supplied decoder when present; otherwise plain concatenation.
    return decoder?.Decode(tokens) ?? string.Join("", tokens);
}

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
Expand Down
19 changes: 17 additions & 2 deletions src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,9 @@ public override int CountTokens(string sequence, bool isSpecialToken)
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? IdToToken(int id, bool skipSpecialTokens = false)
public override string? IdToToken(int id, bool skipSpecialTokens = false, bool filterUnsupportedChars = true)
{
if (!skipSpecialTokens && _specialTokensDecoder is not null && _specialTokensDecoder.TryGetValue(id, out string? token))
{
Expand All @@ -413,8 +414,22 @@ public override int CountTokens(string sequence, bool isSpecialToken)
return null;
}

internal string? IdsToString(IEnumerable<int> ids, bool skipSpecialTokens = false)
/// <summary>
/// Decode the given ids back to a string.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="skipSpecialTokens">Whether the special tokens should be removed from the decoded string.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <param name="decoder">The optional Decoder to merge the given list of tokens in a string.</param>
/// <returns>The decoded string.</returns>
public override string? Decode(IEnumerable<int> ids, TokenizerDecoder? decoder = null, bool skipSpecialTokens = false, bool filterUnsupportedChars = true)
{
// Tiktoken does not ensure a one-to-one mapping between IDs and tokens. Consequently, decoding individual IDs into tokens is not supported;
// instead, decoding all IDs must be done collectively.
// Here is an example of a case that maps one character to multiple Ids:
// '⭐' U+2B50 is mapped to the Ids [2928, 99834] in the Tiktoken model.
// In other words, the character '⭐' is encoded in UTF-8 as the bytes 0xE2, 0xAD, 0x90; Tiktoken maps 0xE2 to [2928] and 0xAD, 0x90 to [99834].

if (ids is null)
{
return null;
Expand Down
48 changes: 10 additions & 38 deletions src/Microsoft.ML.Tokenizers/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@ public partial class Tokenizer
/// <param name="model">The Model in use by the Tokenizer.</param>
/// <param name="preTokenizer">The optional PreTokenizer in use by the Tokenizer. WhiteSpace PreTokenizer will be used if this parameter is null.</param>
/// <param name="normalizer">The optional Normalizer in use by the Tokenizer.</param>
public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null)
/// <param name="decoder">The optional Decoder in use by the Tokenizer during the decoding operation to merge the given list of tokens in a string.</param>
public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? normalizer = null, TokenizerDecoder? decoder = null)
{
    Model = model;
    Normalizer = normalizer;
    Decoder = decoder;
    // Fall back to whitespace pre-tokenization when no pre-tokenizer is supplied.
    PreTokenizer = preTokenizer ?? WhiteSpace.Instance;
}

/// <summary>
Expand All @@ -40,17 +42,17 @@ public Tokenizer(Model model, PreTokenizer? preTokenizer = null, Normalizer? nor
/// <summary>
/// Gets or sets the PreTokenizer used by the Tokenizer.
/// </summary>
public PreTokenizer PreTokenizer { get; set; }
public PreTokenizer PreTokenizer { get; private set; }

/// <summary>
/// Gets or sets the Normalizer in use by the Tokenizer.
/// </summary>
public Normalizer? Normalizer { get; set; }
public Normalizer? Normalizer { get; private set; }

/// <summary>
/// Gets or sets the Decoder in use by the Tokenizer.
/// </summary>
public TokenizerDecoder? Decoder { get; set; }
public TokenizerDecoder? Decoder { get; private set; }

/// <summary>
/// Encodes input text to object has the tokens list, tokens Ids, tokens offset mapping.
Expand Down Expand Up @@ -171,53 +173,23 @@ public int CountTokens(string sequence, bool skipSpecialTokens = false)
return idsCount;
}

// skipSpecialTokens is used in post processing we don't support yet. We are keeping it to allow using it when we support post processing.
/// <summary>
/// Decodes the Id to the mapped token.
/// </summary>
/// <param name="id">The id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <returns>The decoded string or null if there is no token mapped to the input id.</returns>
public string? Decode(int id, bool skipSpecialTokens = false) => Model.IdToToken(id, skipSpecialTokens);
public string? Decode(int id, bool skipSpecialTokens = false, bool filterUnsupportedChars = true) => Model.IdToToken(id, skipSpecialTokens, filterUnsupportedChars);

// skipSpecialTokens is used in post processing we don't support yet. We are keeping it to allow using it when we support post processing.
/// <summary>
/// Decode the given ids, back to a String.
/// </summary>
/// <param name="ids">The list of ids that we want to decode.</param>
/// <param name="skipSpecialTokens">Whether the special tokens should be removed from the decoded string.</param>
/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
/// <returns>The decoded string.</returns>
public string? Decode(IEnumerable<int> ids, bool skipSpecialTokens = false)
{
    if (Model is Tiktoken tiktoken)
    {
        // Tiktoken ids do not map one-to-one to tokens, so individual ids cannot be
        // decoded in isolation — the whole sequence must be decoded together.
        // Example: '⭐' U+2B50 maps to the ids [2928, 99834]; its UTF-8 encoding is the
        // bytes 0xE2, 0xAD, 0x90, where 0xE2 maps to [2928] and 0xAD, 0x90 to [99834].
        return tiktoken.IdsToString(ids, skipSpecialTokens);
    }

    var tokens = new List<string>();

    // EnglishRoberta uses a dedicated mapping that filters unsupported characters;
    // resolve the model type once instead of per id.
    EnglishRoberta? robertaModel = Model as EnglishRoberta;

    foreach (int id in ids)
    {
        string? token = robertaModel is not null
            ? robertaModel.IdToFilteredToken(id, skipSpecialTokens)
            : Model.IdToToken(id, skipSpecialTokens);

        tokens.Add(token ?? "");
    }

    return Decoder?.Decode(tokens) ?? string.Join("", tokens);
}
public string? Decode(IEnumerable<int> ids, bool skipSpecialTokens = false, bool filterUnsupportedChars = true) => Model.Decode(ids, Decoder, skipSpecialTokens, filterUnsupportedChars);

private const string EndOfText = "<|endoftext|>";
private const string FimPrefix = "<|fim_prefix|>";
Expand Down
23 changes: 16 additions & 7 deletions test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public static IEnumerable<object[]> BertaData
new string[] { "Hello", "\u0120Bert", "a" },
new (int, int)[] { (0, 5), (5, 10), (10, 11) },
"Hello Berta",
new int[] { 35245, 144292, 18759122 }
new int[] { 35245, 144292, 18759122 },
new string[] { "Hello", " Bert", "a" },
};

// Intentionally repeating the same case data to test caching.
Expand All @@ -42,7 +43,8 @@ public static IEnumerable<object[]> BertaData
new string[] { "Hello", "\u0120Bert", "a" },
new (int, int)[] { (0, 5), (5, 10), (10, 11) },
"Hello Berta",
new int[] { 35245, 144292, 18759122 }
new int[] { 35245, 144292, 18759122 },
new string[] { "Hello", " Bert", "a" },
};

// Sentence, Expected Ids, Expected Tokens, Expected Offsets, Decoded Tokens, Token occurrence values
Expand All @@ -53,7 +55,8 @@ public static IEnumerable<object[]> BertaData
new string[] { "In", "\u0120the", "\u0120night", "." },
new (int, int)[] { (0, 2), (2, 6), (6, 12), (12, 13) },
"In the night.",
new int[] { 2224123, 800385005, 6062347, 850314647 }
new int[] { 2224123, 800385005, 6062347, 850314647 },
new string[] { "In", " the", " night", "." },
};

// Sentence, Expected Ids, Expected Tokens, Expected Offsets, Decoded Tokens, Token occurrence values
Expand All @@ -64,7 +67,8 @@ public static IEnumerable<object[]> BertaData
new string[] { "He", "llo", "ĠBer", "ta" },
new (int, int)[] { (0, 2), (4, 7), (7, 11), (13, 15) },
"Hello Berta",
new int[] { 2759525, 207306, 565286, 560191 }
new int[] { 2759525, 207306, 565286, 560191 },
new string[] { "He", "llo", " Ber", "ta" },
};

// Sentence, Expected Ids, Expected Tokens, Expected Offsets, Decoded Tokens, Token occurrence values
Expand All @@ -75,7 +79,8 @@ public static IEnumerable<object[]> BertaData
new string[] { },
new (int, int)[] { },
"",
new int[] { }
new int[] { },
new string[] { },
};
}
}
Expand Down Expand Up @@ -180,9 +185,13 @@ private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = Call

for (int i = 0; i < encoding.Tokens.Count; i++)
{
Assert.Equal(encoding.Tokens[i], tokenizer.Model.IdToToken(encoding.Ids[i]));
Assert.Equal(encoding.Tokens[i], tokenizer.Model.IdToToken(encoding.Ids[i], skipSpecialTokens: true, filterUnsupportedChars: false));
Assert.Equal(encoding.Ids[i], tokenizer.Model.TokenToId(encoding.Tokens[i]));
Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i]));
Assert.Equal(encoding.Tokens[i], tokenizer.Decode(encoding.Ids[i], skipSpecialTokens: true, filterUnsupportedChars: false));

string[]? filteredToken = p[6] as string[];

Assert.Equal(filteredToken![i], tokenizer.Model.IdToToken(encoding.Ids[i], skipSpecialTokens: true, filterUnsupportedChars: true));
}
}
}
Expand Down