Address the feedback on the tokenizer's library #7024

Merged Feb 26, 2024 (23 commits)

Commits (the diff below shows changes from 5 of the 23 commits):
f6e32f5  Fix cache when calling EncodeToIds (tarekgh, Feb 17, 2024)
0553922  Make EnglishRoberta _mergeRanks thread safe (tarekgh, Feb 17, 2024)
a4cb1f5  Delete Trainer (tarekgh, Feb 19, 2024)
6a13025  Remove the setters on the Bpe properties (tarekgh, Feb 19, 2024)
3278aff  Remove Roberta and Tiktoken special casing in the Tokenizer and suppo… (tarekgh, Feb 19, 2024)
b5f7fa2  Support text-embedding-3-small/large embedding (tarekgh, Feb 19, 2024)
a11f4e0  Remove redundant TokenToId abstraction and keep the one with the extr… (tarekgh, Feb 19, 2024)
865068a  Enable creating Tiktoken asynchronously or directly using the tokeniz… (tarekgh, Feb 20, 2024)
4077de0  Add cancellationToken support in CreateAsync APIs (tarekgh, Feb 21, 2024)
5aaf849  Rename sequence to text and Tokenize to Encode (tarekgh, Feb 21, 2024)
b5e0927  Rename skipSpecialTokens to considerSpecialTokens (tarekgh, Feb 21, 2024)
5e26b3e  Rename TokenizerResult to EncodingResult (tarekgh, Feb 21, 2024)
985de8a  Make Token publicly immutable (tarekgh, Feb 21, 2024)
b551e7d  Change offset tuples from (Index, End) to (Index, Length) (tarekgh, Feb 21, 2024)
7ea7f70  Rename NormalizedString method's parameters (tarekgh, Feb 21, 2024)
b0c8244  Rename Model's methods to start with verb (tarekgh, Feb 21, 2024)
450418a  Convert Model.GetVocab() method to a Vocab property (tarekgh, Feb 21, 2024)
6f53de8  Some method's parameters and variable renaming (tarekgh, Feb 22, 2024)
62334c6  Remove Vocab and VocabSize from the abstraction (tarekgh, Feb 22, 2024)
d48b32d  Cleanup normalization support (tarekgh, Feb 22, 2024)
191ab03  Minor Bpe cleanup (tarekgh, Feb 22, 2024)
b9b0f58  Resolve rebase change (tarekgh, Feb 23, 2024)
1ad157f  Address the feedback (tarekgh, Feb 25, 2024)
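
Taken together, the rename commits reshape the public surface. A hedged before/after sketch of calling code, inferred only from the commit messages above (identifiers and exact signatures are assumptions, not verified against the merged API):

using Microsoft.ML.Tokenizers;

// Hedged sketch; assumes a Bpe model wrapped in a Tokenizer, and the
// file paths are illustrative.
Bpe model = new Bpe("vocab.json", "merges.txt");
Tokenizer tokenizer = new Tokenizer(model);

// "Rename sequence to text and Tokenize to Encode" together with
// "Rename TokenizerResult to EncodingResult":
EncodingResult encoding = tokenizer.Encode("Hello world");

// "Change offset tuples from (Index, End) to (Index, Length)": each token's
// offset now reports its start index and its length within the input text.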
184 changes: 89 additions & 95 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
@@ -54,87 +54,114 @@ public string? UnknownToken
 }

 /// <summary>
-/// An optional prefix to use on any sub-word that exist only behind another one
+/// A prefix to be used for every subword that is not a beginning-of-word
 /// </summary>
-public string? ContinuingSubwordPrefix { get; set; }
+public string? ContinuingSubwordPrefix { get; private set; }

 /// <summary>
 /// An optional suffix to characterize and end-of-word sub-word
 /// </summary>
-public string? EndOfWordSuffix { get; set; }
+public string? EndOfWordSuffix { get; private set; }

 /// <summary>
 /// Gets or sets whether allowing multiple unknown tokens get fused
 /// </summary>
-public bool FuseUnknownTokens { get; set; }
+public bool FuseUnknownTokens { get; private set; }


 /// <summary>
-/// Construct a new Bpe model object with no tokenization vocabulary. This constructor is useful only in the training scenario.
+/// Construct a new Bpe model object to use for sentence tokenization.
 /// </summary>
-public Bpe()
+/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
+/// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
+/// <param name="unknownToken">The unknown token to be used by the model.</param>
+/// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
+/// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
+/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
Comment (Member):
I'm having trouble understanding what this means.

Reply (Member Author):
When encoding text with the Bpe model, any token the model doesn't recognize is replaced with the unknown token. Most users use [Unk] as the unknown token, so it is possible to get multiple [Unk] tokens next to each other in the result. Setting fuseUnknownTokens to true causes such an [Unk] sequence to collapse into a single [Unk]. The "fuse" term is used by Hugging Face, and users of Bpe are familiar with it. If you have a better explanation we can use here, I'll be happy to use it :-)

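To make the fusing behavior concrete, a hedged sketch (the file names, the unknown-token choice, and the input string are illustrative assumptions, not taken from the PR):

// Illustrative only: assumes vocab.json/merges.txt exist and that the
// characters in "§§§" are absent from the vocabulary.
Bpe unfused = new Bpe("vocab.json", "merges.txt", unknownToken: "[Unk]");
Bpe fused = new Bpe("vocab.json", "merges.txt", unknownToken: "[Unk]", fuseUnknownTokens: true);

// Conceptually:
//   unfused encodes "§§§" as [Unk] [Unk] [Unk]  (one unknown per unrecognized piece)
//   fused encodes "§§§" as a single [Unk]       (the run of unknowns collapses)
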
+public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+    this(vocabFile is null ? throw new ArgumentNullException(nameof(vocabFile)) : File.Open(vocabFile, FileMode.Open, FileAccess.Read),
+        mergesFile is null ? null : File.Open(mergesFile, FileMode.Open, FileAccess.Read), unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: true)
 {
-    Vocab = new();
-    VocabReverse = new();
-    Merges = new();
-
-    UnknownToken = "[Unk]";
 }

 /// <summary>
-/// Construct a new Bpe model object to use for sentence tokenization and tokenizer training.
+/// Construct a new Bpe model object to use for sentence tokenization.
 /// </summary>
-/// <param name="vocabFile">The JSON file path containing the dictionary of string keys and their ids.</param>
-/// <param name="mergesFile">The file path containing the tokens's pairs list.</param>
+/// <param name="vocabStream">The JSON stream containing the dictionary of string keys and their ids.</param>
+/// <param name="mergesStream">The stream containing the tokens's pairs list.</param>
 /// <param name="unknownToken">The unknown token to be used by the model.</param>
 /// <param name="continuingSubwordPrefix">The prefix to attach to sub-word units that don’t represent a beginning of word.</param>
 /// <param name="endOfWordSuffix">The suffix to attach to sub-word units that represent an end of word.</param>
-public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null)
+/// <param name="fuseUnknownTokens">Indicate whether allowing multiple unknown tokens get fused.</param>
+public Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken = null, string? continuingSubwordPrefix = null, string? endOfWordSuffix = null, bool fuseUnknownTokens = false) :
+    this(vocabStream, mergesStream, unknownToken, continuingSubwordPrefix, endOfWordSuffix, fuseUnknownTokens, disposeStreams: false)
 {
-    ContinuingSubwordPrefix = continuingSubwordPrefix;
-    EndOfWordSuffix = endOfWordSuffix;
-
-    (Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
-    Vocab = vocab1 ?? new Dictionary<string, int>();
-    Cache = new Cache<string, Word>();
-
-    VocabReverse = new();
-
-    foreach (KeyValuePair<string, int> kvp in Vocab)
-    {
-        VocabReverse.Add(kvp.Value, kvp.Key);
-    }
+}

-    if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
+private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, string? continuingSubwordPrefix, string? endOfWordSuffix, bool fuseUnknownTokens, bool disposeStreams)
 {
+    try
     {
-        unknownToken = unkToken;
-    }
+        if (vocabStream is null)
+        {
+            throw new ArgumentNullException(nameof(vocabStream));
+        }

-    UnknownToken = unknownToken;
+        FuseUnknownTokens = fuseUnknownTokens;
+        ContinuingSubwordPrefix = continuingSubwordPrefix;
+        EndOfWordSuffix = endOfWordSuffix;

-    int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;
+        (Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
+        Vocab = vocab1 ?? new Dictionary<string, int>();
+        Cache = new Cache<string, Word>();

-    Merges = new();
-    for (int i = 0; i < merges.Count; i++)
-    {
-        (string a, string b) mergeValues = merges[i];
+        VocabReverse = new();

-        if (!Vocab.TryGetValue(mergeValues.a, out int aId))
+        foreach (KeyValuePair<string, int> kvp in Vocab)
         {
-            throw new InvalidOperationException($"Trying to merge a token {mergeValues.a} which not exist in the vocabulary.");
+            VocabReverse.Add(kvp.Value, kvp.Key);
         }

-        if (!Vocab.TryGetValue(mergeValues.b, out int bId))
+        if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
         {
-            throw new InvalidOperationException($"Trying to merge a token {mergeValues.b} which not exist in the vocabulary.");
+            unknownToken = unkToken;
         }

-        string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}";
-        if (!Vocab.TryGetValue(newToken, out int newId))
+        UnknownToken = unknownToken;

+        int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;

+        Merges = new();
+        for (int i = 0; i < merges.Count; i++)
         {
-            throw new InvalidOperationException($"Trying to merge a token {newToken} which not exist in the vocabulary.");
-        }
+            (string a, string b) mergeValues = merges[i];

+            if (!Vocab.TryGetValue(mergeValues.a, out int aId))
+            {
+                throw new InvalidOperationException($"Trying to merge a token '{mergeValues.a}' which not exist in the vocabulary.");
+            }

+            if (!Vocab.TryGetValue(mergeValues.b, out int bId))
+            {
+                throw new InvalidOperationException($"Trying to merge a token '{mergeValues.b}' which not exist in the vocabulary.");
+            }

-        Merges.Add(new Pair<int>(aId, bId), (i, newId));
+            string newToken = $"{mergeValues.a}{mergeValues.b.Substring(prefixLen)}";
+            if (!Vocab.TryGetValue(newToken, out int newId))
+            {
+                throw new InvalidOperationException($"Trying to merge a token '{newToken}' which not exist in the vocabulary.");
+            }

+            Merges.Add(new Pair<int>(aId, bId), (i, newId));
+        }
     }
+    finally
+    {
+        if (disposeStreams)
+        {
+            vocabStream.Dispose();
+            mergesStream?.Dispose();
+        }
+    }
 }
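
For reference, a minimal usage sketch of the two public constructors above (file paths are illustrative assumptions):

using System.IO;
using Microsoft.ML.Tokenizers;

// File-path overload: the model opens the files and disposes the streams
// itself (disposeStreams: true in the chained call).
Bpe fromFiles = new Bpe("vocab.json", "merges.txt", unknownToken: "[Unk]");

// Stream overload: the caller owns the streams (disposeStreams: false),
// so they are disposed here once construction finishes.
using (FileStream vocab = File.OpenRead("vocab.json"))
using (FileStream merges = File.OpenRead("merges.txt"))
{
    Bpe fromStreams = new Bpe(vocab, merges, unknownToken: "[Unk]");
}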

Expand Down Expand Up @@ -195,8 +222,9 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
 /// </summary>
 /// <param name="id">The Id to map to the token.</param>
 /// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
+/// <param name="filterUnsupportedChars">Indicate if want to filter the unsupported characters during the decoding.</param>
 /// <returns>The mapped token of the Id.</returns>
-public override string? IdToToken(int id, bool skipSpecialTokens = false)
+public override string? IdToToken(int id, bool skipSpecialTokens = false, bool filterUnsupportedChars = true)
 {
     if (VocabReverse.TryGetValue(id, out string? value))
     {
@@ -216,53 +244,10 @@
 /// </summary>
 public override int GetVocabSize() => Vocab.Count;

-/// <summary>
-/// Gets a trainer object to use in training the model and generate the vocabulary and merges data.
-/// </summary>
-public override Trainer? GetTrainer() => new BpeTrainer();
-
-/// <summary>
-/// Save the model data into the vocabulary and merges files.
-/// </summary>
-/// <param name="path">The file system path to store the generated files at.</param>
-/// <param name="prefix">Optional prefix for the generated file names.</param>
-/// <returns>The list of all saved files.</returns>
-public override string[] Save(string path, string? prefix = null)
-{
-    // Write vocab.json
-    string vocabFileNname = prefix is null ? "vocab.json" : $"{prefix}-vocab.json";
-    string vocabPath = Path.Combine(path, vocabFileNname);
-    string serialized = JsonSerializer.Serialize(VocabReverse, new JsonSerializerOptions { Converters = { new DictReversingConverter() } });
-    File.WriteAllText(vocabPath, serialized, System.Text.Encoding.UTF8);
-
-    // Write merges.txt
-    string mergeFileName = prefix is null ? "merges.txt" : $"{prefix}-merges.txt";
-    string mergePath = Path.Combine(path, mergeFileName);
-    (Pair<int> pair, int rank)[] pairsArray = new (Pair<int>, int)[Merges.Count];
-    int i = 0;
-    foreach (var p in Merges)
-    {
-        pairsArray[i++] = (p.Key, p.Value.Item1 /* rank */);
-    }
-    Array.Sort(pairsArray, (x, y) => x.rank.CompareTo(y.rank));
-    using StreamWriter file = new(mergePath, append: false, System.Text.Encoding.UTF8);
-    file.WriteLine("#version: 0.2 - Trained by `huggingface/tokenizers`");
-    foreach (var p in pairsArray)
-    {
-        file.WriteLine($"{VocabReverse[p.pair.First]} {VocabReverse[p.pair.Second]}");
-    }
-
-    return new string[] { vocabPath, mergePath };
-}
-
 /// Read the given files to extract the vocab and merges
-internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadFile(string vocab, string? merges)
+internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
 {
-    Dictionary<string, int>? dic;
-    using (Stream stream = File.OpenRead(vocab))
-    {
-        dic = JsonSerializer.Deserialize<Dictionary<string, int>>(stream) as Dictionary<string, int>;
-    }
+    Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;

     return (dic, ConvertMergesToHashmap(merges));
 }
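
As context for ReadModelData: the vocabulary stream is deserialized directly into a Dictionary<string, int>, so vocab.json is a flat token-to-id map. A hypothetical minimal file (contents are an illustrative assumption, not from the PR):

{ "[Unk]": 0, "h": 1, "e": 2, "l": 3, "o": 4, "he": 5, "hel": 6, "hell": 7, "hello": 8 }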
@@ -287,23 +272,32 @@

 /// Converts the merges strings (for example from `merges.txt` file) with the format
 /// "{pair_a} {pair_b}" into the format expected by the BPE struct
-internal static Vec<(string, string)> ConvertMergesToHashmap(string? mergesFile)
+internal static Vec<(string, string)> ConvertMergesToHashmap(Stream? mergesStream)
 {
-    if (mergesFile is null)
+    if (mergesStream is null)
     {
         return new Vec<(string, string)>();
     }

+    using StreamReader reader = new StreamReader(mergesStream);
+
     Vec<(string, string)> merges = new(1000);

     int lineNumber = 0;
-    foreach (string line in System.IO.File.ReadLines(mergesFile))
+    while (true)
     {
+        string? line = reader.ReadLine();
+        if (line is null)
+        {
+            break;
+        }
+
         lineNumber++;
         if (line.StartsWith("#version", StringComparison.Ordinal) || line.Length == 0)
         {
             continue;
         }

         int index = line.IndexOf(' ');
         if (index < 0 || index == line.Length - 1 || line.IndexOf(' ', index + 1) >= 0)
         {
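
For reference, ConvertMergesToHashmap expects one merge pair per line in the form "{pair_a} {pair_b}" separated by exactly one space, skipping blank lines and a leading "#version" header. A hypothetical merges.txt this parser would accept (consistent with the hypothetical vocab.json sketched earlier, so every operand and merged result exists in the vocabulary):

#version: 0.2 - Trained by `huggingface/tokenizers`
h e
he l
hel l
hell o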