Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address the feedback on the tokenizer's library #7024

Merged
merged 23 commits into from
Feb 26, 2024
Merged
Changes from 1 commit
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6e32f5
Fix cache when calling EncodeToIds
tarekgh Feb 17, 2024
0553922
Make EnglishRoberta _mergeRanks thread safe
tarekgh Feb 17, 2024
a4cb1f5
Delete Trainer
tarekgh Feb 19, 2024
6a13025
Remove the setters on the Bpe properties
tarekgh Feb 19, 2024
3278aff
Remove Roberta and Tiktoken special casing in the Tokenizer and suppo…
tarekgh Feb 19, 2024
b5f7fa2
Support text-embedding-3-small/large embedding
tarekgh Feb 19, 2024
a11f4e0
Remove redundant TokenToId abstraction and keep the one with the extr…
tarekgh Feb 19, 2024
865068a
Enable creating Tiktoken asynchronously or directly using the tokeniz…
tarekgh Feb 20, 2024
4077de0
Add cancellationToken support in CreateAsync APIs
tarekgh Feb 21, 2024
5aaf849
Rename sequence to text and Tokenize to Encode
tarekgh Feb 21, 2024
b5e0927
Rename skipSpecialTokens to considerSpecialTokens
tarekgh Feb 21, 2024
5e26b3e
Rename TokenizerResult to EncodingResult
tarekgh Feb 21, 2024
985de8a
Make Token publicly immutable
tarekgh Feb 21, 2024
b551e7d
Change offset tuples from (Index, End) to (Index, Length)
tarekgh Feb 21, 2024
7ea7f70
Rename NormalizedString method's parameters
tarekgh Feb 21, 2024
b0c8244
Rename Model's methods to start with verb
tarekgh Feb 21, 2024
450418a
Convert Model.GetVocab() method to a Vocab property
tarekgh Feb 21, 2024
6f53de8
Some method's parameters and variable renaming
tarekgh Feb 22, 2024
62334c6
Remove Vocab and VocabSize from the abstraction
tarekgh Feb 22, 2024
d48b32d
Cleanup normalization support
tarekgh Feb 22, 2024
191ab03
Minor Bpe cleanup
tarekgh Feb 22, 2024
b9b0f58
Resolve rebase change
tarekgh Feb 23, 2024
1ad157f
Address the feedback
tarekgh Feb 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 41 additions & 54 deletions src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,26 +29,26 @@ public sealed class Tiktoken : Model
/// <summary>
/// Create a new Tiktoken tokenizer's model object.
/// </summary>
/// <param name="tikTokenBpeFile">The path to the BPE rank file.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="vocabFilePath">The path to the BPE vocab file.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFile"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
public Tiktoken(string tikTokenBpeFile, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
this(string.IsNullOrEmpty(tikTokenBpeFile) ? throw new ArgumentNullException(nameof(tikTokenBpeFile)) : File.OpenRead(tikTokenBpeFile), specialTokensEncoder, cacheSize, disposeStream: true)
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabFilePath"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
public Tiktoken(string vocabFilePath, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
this(string.IsNullOrEmpty(vocabFilePath) ? throw new ArgumentNullException(nameof(vocabFilePath)) : File.OpenRead(vocabFilePath), specialTokens, cacheSize, disposeStream: true)
{
}

/// <summary>
/// Create a new Tiktoken tokenizer's model object.
/// </summary>
/// <param name="tikTokenBpeFileStream">The stream to the BPE rank file.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="tikTokenBpeFileStream"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE rank file.</exception>
public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
this(tikTokenBpeFileStream ?? throw new ArgumentNullException(nameof(tikTokenBpeFileStream)), specialTokensEncoder, cacheSize, disposeStream: false)
/// <exception cref="ArgumentNullException">Thrown when <paramref name="vocabStream"/> is null or empty.</exception>
/// <exception cref="InvalidOperationException">Thrown when failed to load the BPE vocab file.</exception>
public Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens = null, int cacheSize = LruCache<string, int[]>.DefaultCacheSize) :
this(vocabStream ?? throw new ArgumentNullException(nameof(vocabStream)), specialTokens, cacheSize, disposeStream: false)
{
}

Expand All @@ -58,13 +58,13 @@ public Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>?
/// <param name="encoder">The dictionary mapping token utf-8 bytes to Ids.</param>
/// <param name="decoder">The dictionary mapping Ids to token utf-8 bytes.</param>
/// <param name="vocab">The dictionary mapping string tokens to Ids.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The max size of the cache to use.</param>
public Tiktoken(
IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder,
IReadOnlyDictionary<int, byte[]>? decoder,
IReadOnlyDictionary<string, int>? vocab,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
IReadOnlyDictionary<string, int>? specialTokens = null,
int cacheSize = LruCache<string, int[]>.DefaultCacheSize) : this(cacheSize)
{
if (encoder is null)
Expand Down Expand Up @@ -94,7 +94,7 @@ public Tiktoken(
string s = Encoding.UTF8.GetString(kvp.Key.ToArray());
tarekgh marked this conversation as resolved.
Show resolved Hide resolved

// Don't add mal-formed utf8 converted bytes to the vocab.
if (!StringContainInvalidChars(s))
if (!s.Contains('\uFFFD'))
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
{
vocab1[s] = kvp.Value;
}
Expand All @@ -104,20 +104,20 @@ public Tiktoken(

_vocab = vocab;

_specialTokensEncoder = specialTokensEncoder;
_specialTokensEncoder = specialTokens;
if (_specialTokensEncoder is not null)
{
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
}
}

private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>? specialTokensEncoder, int cacheSize, bool disposeStream) : this(cacheSize)
private Tiktoken(Stream vocabStream, IReadOnlyDictionary<string, int>? specialTokens, int cacheSize, bool disposeStream) : this(cacheSize)
{
try
{
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: false).GetAwaiter().GetResult();
(_encoder, _vocab, _decoder) = LoadTikTokenBpeAsync(vocabStream, useAsync: false).GetAwaiter().GetResult();

_specialTokensEncoder = specialTokensEncoder;
_specialTokensEncoder = specialTokens;
if (_specialTokensEncoder is not null)
{
_specialTokensDecoder = _specialTokensEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
Expand All @@ -127,7 +127,7 @@ private Tiktoken(Stream tikTokenBpeFileStream, IReadOnlyDictionary<string, int>?
{
if (disposeStream)
{
tikTokenBpeFileStream.Dispose();
vocabStream.Dispose();
}
}
}
Expand All @@ -148,69 +148,69 @@ private Tiktoken(int cacheSize)
/// <summary>
/// Create a new Tiktoken tokenizer's model object asynchronously.
/// </summary>
/// <param name="tikTokenBpeFileStream">The stream to the BPE rank file.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="vocabStream">The stream to the BPE vocab file.</param>
/// <param name="specialTokens">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Tiktoken tokenizer's object.</returns>
public static async Task<Tiktoken> CreateAsync(
Stream tikTokenBpeFileStream,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
Stream vocabStream,
IReadOnlyDictionary<string, int>? specialTokens = null,
int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (tikTokenBpeFileStream is null)
if (vocabStream is null)
{
throw new ArgumentNullException(nameof(tikTokenBpeFileStream));
throw new ArgumentNullException(nameof(vocabStream));
}

(IReadOnlyDictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) =
await LoadTikTokenBpeAsync(tikTokenBpeFileStream, useAsync: true, cancellationToken).ConfigureAwait(false);
await LoadTikTokenBpeAsync(vocabStream, useAsync: true, cancellationToken).ConfigureAwait(false);

return new Tiktoken(encoder, decoder, vocab, specialTokensEncoder, cacheSize);
return new Tiktoken(encoder, decoder, vocab, specialTokens, cacheSize);
}

/// <summary>
/// Create a new Tiktoken tokenizer's object asynchronously.
/// </summary>
/// <param name="tikTokenBpeFile">The BPE rank file.</param>
/// <param name="vocabFilePath">The BPE vocab file.</param>
/// <param name="specialTokensEncoder">The dictionary mapping special tokens to Ids.</param>
/// <param name="cacheSize">The size of the cache to use.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Tiktoken tokenizer's model object.</returns>
public static async Task<Tiktoken> CreateAsync(
string tikTokenBpeFile,
string vocabFilePath,
IReadOnlyDictionary<string, int>? specialTokensEncoder = null,
int cacheSize = LruCache<string, int[]>.DefaultCacheSize,
CancellationToken cancellationToken = default)
{
if (tikTokenBpeFile is null)
if (vocabFilePath is null)
{
throw new ArgumentNullException(nameof(tikTokenBpeFile));
throw new ArgumentNullException(nameof(vocabFilePath));
}

using Stream tikTokenBpeFileStream = File.OpenRead(tikTokenBpeFile);
return await CreateAsync(tikTokenBpeFileStream, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
using Stream vocabStream = File.OpenRead(vocabFilePath);
return await CreateAsync(vocabStream, specialTokensEncoder, cacheSize, cancellationToken).ConfigureAwait(false);
}

/// <summary>
/// Load BPE rank dictionary from a stream.
/// Load BPE vocab dictionary from a stream.
/// </summary>
/// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
/// <param name="vocabStream">Stream to the BPE vocab file</param>
/// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Map of byte[] to integer token id</returns>
/// <exception cref="InvalidOperationException"></exception>
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
Stream vocabStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<string, int>();
var decoder = new Dictionary<int, byte[]>();

try
{
using (StreamReader reader = new StreamReader(tikTokenBpeFileStream))
using (StreamReader reader = new StreamReader(vocabStream))
{
while (true)
{
Expand All @@ -229,7 +229,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
int spaceIndex = line.IndexOf(' ');
if (spaceIndex <= 0 || spaceIndex >= line.Length - 1 || line.IndexOf(' ', spaceIndex + 1) >= 0)
{
throw new FormatException($"Invalid format in the BPE encoder file stream");
throw new FormatException($"Invalid format in the BPE vocab file stream");
}

if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
Expand All @@ -241,7 +241,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :

string decodedToken = Encoding.UTF8.GetString(tokenBytes);

if (!StringContainInvalidChars(decodedToken))
if (!decodedToken.Contains('\uFFFD'))
{
vocab[decodedToken] = rank;
}
Expand All @@ -255,7 +255,7 @@ await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
}
catch (Exception ex)
{
throw new InvalidOperationException($"Failed to load from BPE encoder file stream: {ex.Message}", ex);
throw new InvalidOperationException($"Failed to load from BPE vocab file stream: {ex.Message}", ex);
}

return (encoder, vocab, decoder);
Expand Down Expand Up @@ -643,18 +643,5 @@ private static unsafe string GetString(ReadOnlySpan<byte> utf8Bytes)
}
#endif
}

private static bool StringContainInvalidChars(string text)
{
for (int i = 0; i < text.Length; i++)
{
if (text[i] == 0xFFFD)
{
return true;
}
}

return false;
}
}
}