-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Address the feedback on the tokenizer's library #7024
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
f6e32f5
Fix cache when calling EncodeToIds
tarekgh 0553922
Make EnglishRoberta _mergeRanks thread safe
tarekgh a4cb1f5
Delete Trainer
tarekgh 6a13025
Remove the setters on the Bpe properties
tarekgh 3278aff
Remove Roberta and Tiktoken special casing in the Tokenizer and suppo…
tarekgh b5f7fa2
Support text-embedding-3-small/large embedding
tarekgh a11f4e0
Remove redundant TokenToId abstraction and keep the one with the extr…
tarekgh 865068a
Enable creating Tiktoken asynchronously or directly using the tokeniz…
tarekgh 4077de0
Add cancellationToken support in CreateAsync APIs
tarekgh 5aaf849
Rename sequence to text and Tokenize to Encode
tarekgh b5e0927
Rename skipSpecialTokens to considerSpecialTokens
tarekgh 5e26b3e
Rename TokenizerResult to EncodingResult
tarekgh 985de8a
Make Token publicly immutable
tarekgh b551e7d
Change offset tuples from (Index, End) to (Index, Length)
tarekgh 7ea7f70
Rename NormalizedString method's parameters
tarekgh b0c8244
Rename Model's methods to start with verb
tarekgh 450418a
Convert Model.GetVocab() method to a Vocab property
tarekgh 6f53de8
Some method's parameters and variable renaming
tarekgh 62334c6
Remove Vocab and VocabSize from the abstraction
tarekgh d48b32d
Cleanup normalization support
tarekgh 191ab03
Minor Bpe cleanup
tarekgh b9b0f58
Resolve rebase change
tarekgh 1ad157f
Address the feedback
tarekgh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,7 +20,7 @@ public sealed class EnglishRoberta : Model | |
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence; | ||
private readonly IReadOnlyDictionary<string, int> _vocab; | ||
private readonly SortedDictionary<int, string> _vocabReverse; | ||
private readonly Dictionary<(string, string), int> _mergeRanks; | ||
private readonly Cache<(string, string), int> _mergeRanks; | ||
private readonly IReadOnlyDictionary<char, char> _byteToUnicode; | ||
private readonly IReadOnlyDictionary<char, char> _unicodeToByte; | ||
private readonly string[] _charToString; | ||
|
@@ -205,6 +205,11 @@ public override string[] Save(string path, string? prefix = null) | |
/// <returns>The list of tokens generated from the sequence tokenization.</returns> | ||
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false) | ||
{ | ||
if (string.IsNullOrEmpty(sequence)) | ||
{ | ||
return Bpe.EmptyTokensList; | ||
} | ||
|
||
char[] token = ArrayPool<char>.Shared.Rent(sequence.Length); | ||
int[] indexMapping = ArrayPool<int>.Shared.Rent(sequence.Length); | ||
|
||
|
@@ -258,6 +263,11 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok | |
|
||
private int TokenizeToIds(string sequence, IList<int>? accumulatedIds) | ||
{ | ||
if (string.IsNullOrEmpty(sequence)) | ||
{ | ||
return 0; | ||
} | ||
|
||
if (_cache.TryGet(sequence, out List<Token>? hit)) | ||
{ | ||
if (accumulatedIds is not null) | ||
|
@@ -271,34 +281,17 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds) | |
return hit.Count; | ||
} | ||
|
||
Span<char> token = stackalloc char[100]; | ||
Span<int> indexMapping = stackalloc int[100]; | ||
|
||
if (sequence.Length > 100) | ||
{ | ||
token = new char[sequence.Length].AsSpan(); | ||
indexMapping = new int[sequence.Length].AsSpan(); | ||
} | ||
|
||
int newTokenIndex = 0; | ||
for (int i = 0; i < sequence.Length; i++) | ||
// If the cache doesn't have the sequence, then tokenize it and add it to the cache | ||
IReadOnlyList<Token> tokens = Tokenize(sequence); | ||
if (accumulatedIds is not null) | ||
{ | ||
if (_byteToUnicode.TryGetValue(sequence[i], out var value)) | ||
foreach (var t in tokens) | ||
{ | ||
token[newTokenIndex] = value; | ||
indexMapping[newTokenIndex] = i; | ||
newTokenIndex++; | ||
accumulatedIds.Add(t.Id); | ||
} | ||
} | ||
|
||
if (newTokenIndex == 0) | ||
{ | ||
return 0; | ||
} | ||
|
||
List<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping); | ||
_cache.Set(sequence, result); | ||
return result.Count; | ||
return tokens.Count; | ||
} | ||
|
||
/// <summary> | ||
|
@@ -477,9 +470,9 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream) | |
return vocab; | ||
} | ||
|
||
private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream) | ||
private Cache<(string, string), int> GetMergeRanks(Stream mergeStream) | ||
{ | ||
var mergeRanks = new Dictionary<(string, string), int>(); | ||
var mergeRanks = new Cache<(string, string), int>(60_000); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where does this 60k come from? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The loaded data from the merge file is 50K. I give it 10K more to grow. |
||
try | ||
{ | ||
using StreamReader reader = new StreamReader(mergeStream); | ||
|
@@ -500,7 +493,7 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream) | |
throw new Exception($"Invalid format of merge file: \"{line}\""); | ||
} | ||
|
||
mergeRanks.Add((line.Substring(0, index), line.Substring(index + 1)), rank++); | ||
mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++); | ||
} | ||
} | ||
catch (Exception e) | ||
|
@@ -538,26 +531,19 @@ private static int GetByteToUnicode(out IReadOnlyDictionary<char, char> byteToUn | |
} | ||
|
||
/// <summary> | ||
/// Encode a token into BPE-ed Ids. E.g., "playing" into ["play", "ing"]. | ||
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"]. | ||
/// </summary> | ||
/// <param name="token">The token to encode.</param> | ||
/// <param name="ids">The list of Ids to encode the token into.</param> | ||
/// <returns>The number of encoded ids.</returns> | ||
private int EncodeToIds(Span<char> token, IList<int>? ids) | ||
private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping) | ||
{ | ||
if (token.Length == 0) | ||
{ | ||
return 0; | ||
return Bpe.EmptyTokensList; | ||
} | ||
|
||
if (token.Length == 1) | ||
{ | ||
if (ids is not null) | ||
{ | ||
ids.Add(_vocab[_charToString[token[0]]]); | ||
} | ||
|
||
return 1; | ||
string tokenValue = _charToString[token[0]]; | ||
tarekgh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[0] + 1)) }; | ||
} | ||
|
||
List<string> word = new(token.Length); | ||
|
@@ -586,7 +572,7 @@ private int EncodeToIds(Span<char> token, IList<int>? ids) | |
|
||
// get the most frequent bi-gram pair | ||
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue)); | ||
if (!_mergeRanks.ContainsKey((first, second))) | ||
if (!_mergeRanks.TryGet((first, second), out int _)) | ||
{ | ||
break; | ||
} | ||
|
@@ -605,6 +591,7 @@ private int EncodeToIds(Span<char> token, IList<int>? ids) | |
{ | ||
newWord.Add(word[k]); | ||
} | ||
|
||
break; | ||
} | ||
else | ||
|
@@ -614,104 +601,7 @@ private int EncodeToIds(Span<char> token, IList<int>? ids) | |
{ | ||
newWord.Add(word[k]); | ||
} | ||
i = j; | ||
} | ||
|
||
// check the next element is {second} or not | ||
if (i < word.Count - 1 && word[i + 1] == second) | ||
{ | ||
newWord.Add(first + second); | ||
i += 2; | ||
} | ||
else | ||
{ | ||
newWord.Add(word[i]); | ||
i += 1; | ||
} | ||
} | ||
|
||
List<string> temp = word; | ||
word = newWord; | ||
newWord = temp; | ||
newWord.Clear(); | ||
|
||
// otherwise, continue merging | ||
WordToPairs(word, pairs); | ||
} | ||
|
||
if (ids is not null) | ||
{ | ||
foreach (string w in word) | ||
{ | ||
ids.Add(_vocab[w]); | ||
} | ||
} | ||
|
||
return word.Count; | ||
} | ||
|
||
/// <summary> | ||
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"]. | ||
/// </summary> | ||
private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping) | ||
{ | ||
if (token.Length == 0) | ||
{ | ||
return Bpe.EmptyTokensList; | ||
} | ||
|
||
if (token.Length == 1) | ||
{ | ||
string tokenValue = _charToString[token[0]]; | ||
return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[0] + 1)) }; | ||
} | ||
|
||
List<string> word = new(token.Length); | ||
foreach (char c in token) | ||
{ | ||
Debug.Assert(c < _charToString.Length); | ||
word.Add(_charToString[c]); | ||
} | ||
|
||
HashSet<(string, string)> pairs = new(); | ||
|
||
WordToPairs(word, pairs); | ||
|
||
var newWord = new List<string>(); | ||
|
||
Debug.Assert(pairs.Count != 0, "Pairs should not be empty."); | ||
|
||
while (true) | ||
{ | ||
/* while conditions */ | ||
// if only one element left, merge is finished (with the whole word merged) | ||
if (word.Count == 1) | ||
{ | ||
break; | ||
} | ||
|
||
// get the most frequent bi-gram pair | ||
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue)); | ||
if (!_mergeRanks.ContainsKey((first, second))) | ||
{ | ||
break; | ||
} | ||
/* end while conditions */ | ||
|
||
// search and merge all (first, second) pairs in {word} | ||
var i = 0; | ||
while (i < word.Count) | ||
{ | ||
// find the next occurrence of {first} and add the elements before into {newWord} | ||
var j = word.IndexOf(first, i); | ||
if (j == -1) | ||
{ | ||
newWord.AddRange(word.Skip(i)); | ||
break; | ||
} | ||
else | ||
{ | ||
newWord.AddRange(word.Skip(i).Take(j - i)); | ||
i = j; | ||
} | ||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.