Address the feedback on the tokenizer's library #7024

Merged: 23 commits, Feb 26, 2024

The diff below shows changes from 2 of the 23 commits.

Commits (23)
f6e32f5 Fix cache when calling EncodeToIds (tarekgh, Feb 17, 2024)
0553922 Make EnglishRoberta _mergeRanks thread safe (tarekgh, Feb 17, 2024)
a4cb1f5 Delete Trainer (tarekgh, Feb 19, 2024)
6a13025 Remove the setters on the Bpe properties (tarekgh, Feb 19, 2024)
3278aff Remove Roberta and Tiktoken special casing in the Tokenizer and suppo… (tarekgh, Feb 19, 2024)
b5f7fa2 Support text-embedding-3-small/large embedding (tarekgh, Feb 19, 2024)
a11f4e0 Remove redundant TokenToId abstraction and keep the one with the extr… (tarekgh, Feb 19, 2024)
865068a Enable creating Tiktoken asynchronously or directly using the tokeniz… (tarekgh, Feb 20, 2024)
4077de0 Add cancellationToken support in CreateAsync APIs (tarekgh, Feb 21, 2024)
5aaf849 Rename sequence to text and Tokenize to Encode (tarekgh, Feb 21, 2024)
b5e0927 Rename skipSpecialTokens to considerSpecialTokens (tarekgh, Feb 21, 2024)
5e26b3e Rename TokenizerResult to EncodingResult (tarekgh, Feb 21, 2024)
985de8a Make Token publicly immutable (tarekgh, Feb 21, 2024)
b551e7d Change offset tuples from (Index, End) to (Index, Length) (tarekgh, Feb 21, 2024)
7ea7f70 Rename NormalizedString method's parameters (tarekgh, Feb 21, 2024)
b0c8244 Rename Model's methods to start with verb (tarekgh, Feb 21, 2024)
450418a Convert Model.GetVocab() method to a Vocab property (tarekgh, Feb 21, 2024)
6f53de8 Some method's parameters and variable renaming (tarekgh, Feb 22, 2024)
62334c6 Remove Vocab and VocabSize from the abstraction (tarekgh, Feb 22, 2024)
d48b32d Cleanup normalization support (tarekgh, Feb 22, 2024)
191ab03 Minor Bpe cleanup (tarekgh, Feb 22, 2024)
b9b0f58 Resolve rebase change (tarekgh, Feb 23, 2024)
1ad157f Address the feedback (tarekgh, Feb 25, 2024)
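
Several of these commits are coordinated API renames (Tokenize to Encode, sequence to text, skipSpecialTokens to considerSpecialTokens, TokenizerResult to EncodingResult, and (Index, End) offsets to (Index, Length)). The sketch below illustrates the direction of the renames at a hypothetical call site; it is inferred from the commit messages alone, so the exact signatures may differ from the final API.

// Hypothetical call site before this PR:
//   TokenizerResult result = tokenizer.Tokenize(sequence, skipSpecialTokens: true);
//   (int Index, int End) offset = result.Offsets[0];
//
// The same hypothetical call site after the renames (note the flipped polarity
// of the special-tokens flag and the (Index, Length) offset shape):
//   EncodingResult result = tokenizer.Encode(text, considerSpecialTokens: false);
//   (int Index, int Length) offset = result.Offsets[0];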
36 changes: 36 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/Cache.cs
@@ -4,6 +4,7 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;

@@ -95,5 +96,40 @@ internal void Set(TKey k, TValue v)
}
finally { _cacheLock.ExitWriteLock(); }
}

internal KeyValuePair<TKey, TValue>[] ToArray()
{
_cacheLock.EnterReadLock();
try
{
return Map.ToArray();
}
finally { _cacheLock.ExitReadLock(); }
}

internal TValue GetOrAdd(TKey key, TValue value)
{
_cacheLock.EnterUpgradeableReadLock();
try
{
if (Map.TryGetValue(key, out TValue? v))
{
return v;
}

_cacheLock.EnterWriteLock();
try
{
if (Capacity > Map.Count)
{
Map[key] = value;
}
}
finally { _cacheLock.ExitWriteLock(); }

return value;
}
finally { _cacheLock.ExitUpgradeableReadLock(); }
}
}
}
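
The GetOrAdd added above pairs an upgradeable read lock with a nested write lock, and only stores a new entry while the cache is under capacity. A minimal standalone sketch of the same pattern, assuming a plain Dictionary guarded by ReaderWriterLockSlim (the names here are illustrative, not the library's internals):

using System.Collections.Generic;
using System.Threading;

internal sealed class BoundedCache<TKey, TValue> where TKey : notnull
{
    private readonly Dictionary<TKey, TValue> _map = new();
    private readonly ReaderWriterLockSlim _lock = new();
    private readonly int _capacity;

    public BoundedCache(int capacity) => _capacity = capacity;

    public TValue GetOrAdd(TKey key, TValue value)
    {
        // The upgradeable read lock allows concurrent plain readers while letting
        // this thread escalate to a write lock without releasing the lock first.
        _lock.EnterUpgradeableReadLock();
        try
        {
            if (_map.TryGetValue(key, out TValue? existing))
            {
                return existing;
            }

            _lock.EnterWriteLock();
            try
            {
                // Store only while under capacity; past the bound, the value is
                // still returned to the caller, just not cached.
                if (_capacity > _map.Count)
                {
                    _map[key] = value;
                }
            }
            finally { _lock.ExitWriteLock(); }

            return value;
        }
        finally { _lock.ExitUpgradeableReadLock(); }
    }
}

Because ReaderWriterLockSlim admits only one upgradeable reader at a time, the check-then-add sequence cannot race with another GetOrAdd call, while read-only lookups stay unblocked until the write lock is actually taken.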
164 changes: 27 additions & 137 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
@@ -20,7 +20,7 @@ public sealed class EnglishRoberta : Model
private readonly HighestOccurrenceMapping _vocabIdToHighestOccurrence;
private readonly IReadOnlyDictionary<string, int> _vocab;
private readonly SortedDictionary<int, string> _vocabReverse;
private readonly Dictionary<(string, string), int> _mergeRanks;
private readonly Cache<(string, string), int> _mergeRanks;
private readonly IReadOnlyDictionary<char, char> _byteToUnicode;
private readonly IReadOnlyDictionary<char, char> _unicodeToByte;
private readonly string[] _charToString;
@@ -205,6 +205,11 @@ public override string[] Save(string path, string? prefix = null)
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
{
if (string.IsNullOrEmpty(sequence))
{
return Bpe.EmptyTokensList;
}

char[] token = ArrayPool<char>.Shared.Rent(sequence.Length);
int[] indexMapping = ArrayPool<int>.Shared.Rent(sequence.Length);

@@ -258,6 +263,11 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok

private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
{
if (string.IsNullOrEmpty(sequence))
{
return 0;
}

if (_cache.TryGet(sequence, out List<Token>? hit))
{
if (accumulatedIds is not null)
@@ -271,34 +281,17 @@ private int TokenizeToIds(string sequence, IList<int>? accumulatedIds)
return hit.Count;
}

Span<char> token = stackalloc char[100];
Span<int> indexMapping = stackalloc int[100];

if (sequence.Length > 100)
{
token = new char[sequence.Length].AsSpan();
indexMapping = new int[sequence.Length].AsSpan();
}

int newTokenIndex = 0;
for (int i = 0; i < sequence.Length; i++)
// If the cache doesn't have the sequence, then tokenize it and add it to the cache
IReadOnlyList<Token> tokens = Tokenize(sequence);
if (accumulatedIds is not null)
{
if (_byteToUnicode.TryGetValue(sequence[i], out var value))
foreach (var t in tokens)
{
token[newTokenIndex] = value;
indexMapping[newTokenIndex] = i;
newTokenIndex++;
accumulatedIds.Add(t.Id);
}
}

if (newTokenIndex == 0)
{
return 0;
}

List<Token> result = EncodeToTokens(token.Slice(0, newTokenIndex), indexMapping);
_cache.Set(sequence, result);
return result.Count;
return tokens.Count;
}

/// <summary>
@@ -477,9 +470,9 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
return vocab;
}

private Dictionary<(string, string), int> GetMergeRanks(Stream mergeStream)
private Cache<(string, string), int> GetMergeRanks(Stream mergeStream)
{
var mergeRanks = new Dictionary<(string, string), int>();
var mergeRanks = new Cache<(string, string), int>(60_000);
Review comment (Member): Where does this 60k come from?

tarekgh (Member, Author): The loaded data from the merge file is 50K. I give it 10K more to grow.

try
{
using StreamReader reader = new StreamReader(mergeStream);
@@ -500,7 +493,7 @@ private Dictionary<string, int> GetVocabulary(Stream vocabularyStream)
throw new Exception($"Invalid format of merge file: \"{line}\"");
}

mergeRanks.Add((line.Substring(0, index), line.Substring(index + 1)), rank++);
mergeRanks.Set((line.Substring(0, index), line.Substring(index + 1)), rank++);
}
}
catch (Exception e)
@@ -538,26 +531,19 @@ private static int GetByteToUnicode(out IReadOnlyDictionary<char, char> byteToUn
}

/// <summary>
/// Encode a token into BPE-ed Ids. E.g., "playing" into ["play", "ing"].
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
/// </summary>
/// <param name="token">The token to encode.</param>
/// <param name="ids">The list of Ids to encode the token into.</param>
/// <returns>The number of encoded ids.</returns>
private int EncodeToIds(Span<char> token, IList<int>? ids)
private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
{
if (token.Length == 0)
{
return 0;
return Bpe.EmptyTokensList;
}

if (token.Length == 1)
{
if (ids is not null)
{
ids.Add(_vocab[_charToString[token[0]]]);
}

return 1;
string tokenValue = _charToString[token[0]];
return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[0] + 1)) };
}

List<string> word = new(token.Length);
@@ -586,7 +572,7 @@ private int EncodeToIds(Span<char> token, IList<int>? ids)

// get the most frequent bi-gram pair
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
if (!_mergeRanks.ContainsKey((first, second)))
if (!_mergeRanks.TryGet((first, second), out int _))
{
break;
}
@@ -605,6 +591,7 @@ private int EncodeToIds(Span<char> token, IList<int>? ids)
{
newWord.Add(word[k]);
}

break;
}
else
Expand All @@ -614,104 +601,7 @@ private int EncodeToIds(Span<char> token, IList<int>? ids)
{
newWord.Add(word[k]);
}
i = j;
}

// check the next element is {second} or not
if (i < word.Count - 1 && word[i + 1] == second)
{
newWord.Add(first + second);
i += 2;
}
else
{
newWord.Add(word[i]);
i += 1;
}
}

List<string> temp = word;
word = newWord;
newWord = temp;
newWord.Clear();

// otherwise, continue merging
WordToPairs(word, pairs);
}

if (ids is not null)
{
foreach (string w in word)
{
ids.Add(_vocab[w]);
}
}

return word.Count;
}

/// <summary>
/// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"].
/// </summary>
private List<Token> EncodeToTokens(Span<char> token, Span<int> indexMapping)
{
if (token.Length == 0)
{
return Bpe.EmptyTokensList;
}

if (token.Length == 1)
{
string tokenValue = _charToString[token[0]];
return new List<Token> { new Token(_vocab[tokenValue], tokenValue, (indexMapping[0], indexMapping[0] + 1)) };
}

List<string> word = new(token.Length);
foreach (char c in token)
{
Debug.Assert(c < _charToString.Length);
word.Add(_charToString[c]);
}

HashSet<(string, string)> pairs = new();

WordToPairs(word, pairs);

var newWord = new List<string>();

Debug.Assert(pairs.Count != 0, "Pairs should not be empty.");

while (true)
{
/* while conditions */
// if only one element left, merge is finished (with the whole word merged)
if (word.Count == 1)
{
break;
}

// get the most frequent bi-gram pair
var (first, second) = pairs.ArgMin(pair => _mergeRanks.GetOrAdd(pair, int.MaxValue));
if (!_mergeRanks.ContainsKey((first, second)))
{
break;
}
/* end while conditions */

// search and merge all (first, second) pairs in {word}
var i = 0;
while (i < word.Count)
{
// find the next occurrence of {first} and add the elements before into {newWord}
var j = word.IndexOf(first, i);
if (j == -1)
{
newWord.AddRange(word.Skip(i));
break;
}
else
{
newWord.AddRange(word.Skip(i).Take(j - i));
i = j;
}

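For orientation, the merge loop retained in EncodeToTokens above repeatedly picks the adjacent pair with the lowest merge rank and merges every occurrence of it, until no ranked pair remains. A minimal standalone sketch of that loop, assuming a plain dictionary of merge ranks (illustrative only, not the library code):

using System.Collections.Generic;
using System.Linq;

internal static class BpeSketch
{
    internal static List<string> Merge(string token, IReadOnlyDictionary<(string, string), int> mergeRanks)
    {
        // Start from single-character symbols, e.g. "playing" -> ["p","l","a","y","i","n","g"].
        List<string> word = token.Select(c => c.ToString()).ToList();

        while (word.Count > 1)
        {
            // Find the adjacent pair with the lowest (most preferred) merge rank.
            (string First, string Second, int Rank) best = ("", "", int.MaxValue);
            for (int i = 0; i < word.Count - 1; i++)
            {
                if (mergeRanks.TryGetValue((word[i], word[i + 1]), out int rank) && rank < best.Rank)
                {
                    best = (word[i], word[i + 1], rank);
                }
            }

            // No known pair left to merge: the segmentation is final.
            if (best.Rank == int.MaxValue)
            {
                break;
            }

            // Merge every occurrence of the best pair in a single left-to-right pass.
            var merged = new List<string>(word.Count);
            for (int i = 0; i < word.Count; i++)
            {
                if (i < word.Count - 1 && word[i] == best.First && word[i + 1] == best.Second)
                {
                    merged.Add(best.First + best.Second);
                    i++; // skip the second element of the merged pair
                }
                else
                {
                    merged.Add(word[i]);
                }
            }

            word = merged;
        }

        return word;
    }
}

For example, with ranks {("p","l"): 0, ("pl","ay"): 1} the token "play" merges as ["p","l","a","y"] -> ["pl","a","y"] and stops there, because neither remaining adjacent pair has a rank. The library version additionally tracks offsets so each sub-token can be mapped back to its position in the input.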
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -336,7 +336,7 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
{ "code-search-babbage-code-001", ModelEncoding.R50kBase },
{ "code-search-ada-code-001", ModelEncoding.R50kBase },

//open source
// open source
{ "gpt2", ModelEncoding.GPT2 }
};

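The hunk above is part of a table mapping model-name prefixes to their encodings. As a sketch of how such a table can be consulted, here is one plausible longest-prefix lookup over a few entries copied from the diff; the library's actual resolution logic is not shown in this hunk and may differ:

using System;
using System.Linq;

internal enum ModelEncoding { GPT2, R50kBase } // illustrative subset, not the library's enum

internal static class PrefixLookupSketch
{
    // A few entries modeled on the table in the diff above.
    private static readonly (string Prefix, ModelEncoding Encoding)[] s_prefixes =
    {
        ("code-search-babbage-code-001", ModelEncoding.R50kBase),
        ("code-search-ada-code-001", ModelEncoding.R50kBase),
        ("gpt2", ModelEncoding.GPT2),
    };

    // Resolve a model name by its longest matching prefix, so more specific
    // entries win over shorter ones; returns null when nothing matches.
    internal static ModelEncoding? Resolve(string modelName) =>
        s_prefixes
            .Where(p => modelName.StartsWith(p.Prefix, StringComparison.Ordinal))
            .OrderByDescending(p => p.Prefix.Length)
            .Select(p => (ModelEncoding?)p.Encoding)
            .FirstOrDefault();
}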
44 changes: 40 additions & 4 deletions test/Microsoft.ML.Tokenizers.Tests/EnglishRobertaTests.cs
@@ -106,6 +106,13 @@ public async void TokenizationTest()
using Stream translationStream = File.OpenRead(translationFile);
tokenizer = new Tokenizer(new EnglishRoberta(vocabStream, mergeStream, translationStream), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer);

// Ensure caching works regardless of which method is called first.
for (CallingOrder order = CallingOrder.Encode; order <= CallingOrder.CountTokens; order++)
{
tokenizer = new Tokenizer(new EnglishRoberta(vocabFile, mergeFile, translationFile), RobertaPreTokenizer.Instance);
TestTokenizer(tokenizer, order);
}
}
finally
{
@@ -122,17 +129,46 @@ public async void TokenizationTest()
}
}

private void TestTokenizer(Tokenizer tokenizer)
private enum CallingOrder
{
Encode,
EncodeToIds,
CountTokens
}

// Calling EncodeToIds after calling Encode causes EncodeToIds to use the cached data from the previous Encode call.
// The other CallingOrder values exercise the remaining orderings.
private void TestTokenizer(Tokenizer tokenizer, CallingOrder callingOrder = CallingOrder.Encode)
{
Assert.NotNull(tokenizer.Model);
Assert.True(tokenizer.Model is EnglishRoberta);
Assert.True(tokenizer.PreTokenizer is RobertaPreTokenizer);

foreach (object[] p in BertaData)
{
TokenizerResult encoding = tokenizer.Encode((string)p[0]);
IReadOnlyList<int> ids = tokenizer.EncodeToIds((string)p[0]);
int idsCount = tokenizer.CountTokens((string)p[0]);
IReadOnlyList<int> ids;
TokenizerResult encoding;
int idsCount;

if (callingOrder == CallingOrder.Encode)
{
encoding = tokenizer.Encode((string)p[0]);
ids = tokenizer.EncodeToIds((string)p[0]);
idsCount = tokenizer.CountTokens((string)p[0]);
}
else if (callingOrder == CallingOrder.EncodeToIds)
{
ids = tokenizer.EncodeToIds((string)p[0]);
encoding = tokenizer.Encode((string)p[0]);
idsCount = tokenizer.CountTokens((string)p[0]);
}
else // CountTokens
{
idsCount = tokenizer.CountTokens((string)p[0]);
ids = tokenizer.EncodeToIds((string)p[0]);
encoding = tokenizer.Encode((string)p[0]);
}

Assert.Equal(p[1], encoding.Ids);
Assert.Equal(p[1], ids);
Assert.Equal(((int[])p[1]).Length, idsCount);
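The calling-order loop above reduces to a single invariant: whichever of the three entry points runs first and warms the cache, Encode, EncodeToIds, and CountTokens must agree on the same text. A minimal sketch of that invariant, assuming the API shape used in this version of the test (TokenizerResult is renamed to EncodingResult in a later commit of this PR):

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Tokenizers;

internal static class CacheInvariantSketch
{
    internal static void AssertConsistent(Tokenizer tokenizer, string text)
    {
        TokenizerResult encoding = tokenizer.Encode(text);
        IReadOnlyList<int> ids = tokenizer.EncodeToIds(text);
        int count = tokenizer.CountTokens(text);

        // All three views of the same text must agree, no matter which call
        // populated the cache first.
        if (!encoding.Ids.SequenceEqual(ids) || ids.Count != count)
        {
            throw new InvalidOperationException("Tokenizer cache produced inconsistent results.");
        }
    }
}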