Skip to content

Add Span support in tokenizer's Model abstraction #7035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 71 additions & 51 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -34,20 +36,21 @@ private set
{
_unknownToken = value;

if (value is null)
if (VocabReverse.TryGetValue(0, out string? v))
{
if (VocabReverse.TryGetValue(0, out string? v))
if (v == value)
{
VocabReverse.Remove(0);
if (_vocab.TryGetValue(v, out int id))
{
_vocab.Remove(v);
}
return;
}

VocabReverse.Remove(0);
_vocab.Remove(new StringSpanOrdinalKey(v));
}
else


if (value is not null)
{
_vocab[value] = 0;
_vocab[new StringSpanOrdinalKey(value)] = 0;
VocabReverse[0] = value;
}
}
Expand All @@ -68,7 +71,6 @@ private set
/// </summary>
public bool FuseUnknownTokens { get; }


/// <summary>
/// Construct a new Bpe model object to use for text encoding.
/// </summary>
Expand Down Expand Up @@ -111,23 +113,19 @@ private Bpe(Stream vocabStream, Stream? mergesStream, string? unknownToken, stri
ContinuingSubwordPrefix = continuingSubwordPrefix;
EndOfWordSuffix = endOfWordSuffix;

(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<string, int>();
Cache = new Cache<string, Word>();
(Dictionary<StringSpanOrdinalKey, int>? vocab1, Vec<(string, string)> merges) = ReadModelData(vocabStream, mergesStream);
_vocab = vocab1 ?? new Dictionary<StringSpanOrdinalKey, int>();
Cache = new StringSpanOrdinalKeyCache<Word>();

VocabReverse = new();

foreach (KeyValuePair<string, int> kvp in Vocab)
foreach (KeyValuePair<StringSpanOrdinalKey, int> kvp in _vocab)
{
VocabReverse.Add(kvp.Value, kvp.Key);
VocabReverse.Add(kvp.Value, kvp.Key.Data!);
}

if (unknownToken is null && VocabReverse.TryGetValue(0, out string? unkToken))
{
unknownToken = unkToken;
}

UnknownToken = unknownToken;
UnknownToken = unknownToken ?? (VocabReverse.TryGetValue(0, out string? unkToken) ? unkToken : null);

int prefixLen = ContinuingSubwordPrefix is null ? 0 : ContinuingSubwordPrefix.Length;

Expand Down Expand Up @@ -197,31 +195,23 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
/// <param name="text">The text to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated encoded Ids.</param>
public override void EncodeToIds(string text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);
public override void EncodeToIds(ReadOnlySpan<char> text, bool isSpecialToken, IList<int> accumulatedIds) => EncodeToIdsWithCache(text, accumulatedIds);

/// <summary>
/// Get the number of tokens that the input text will be encoded to.
/// </summary>
/// <param name="text">The text to encode.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input text will be encoded to.</returns>
public override int CountTokens(string text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);
public override int CountTokens(ReadOnlySpan<char> text, bool isSpecialToken) => EncodeToIdsWithCache(text, null);

/// <summary>
/// Map the token to encoded Id.
/// </summary>
/// <param name="token">The token to map to the Id.</param>
/// <param name="considerSpecialTokens">Indicate if want to consider the special tokens during the encoding.</param>
/// <returns>The mapped Id of the token.</returns>
public override int? MapTokenToId(string token, bool considerSpecialTokens = true)
{
if (_vocab.TryGetValue(token, out int value))
{
return value;
}

return null;
}
public override int? MapTokenToId(ReadOnlySpan<char> token, bool considerSpecialTokens = true) => _vocab.TryGetValue(token, out int value) ? value : null;

/// <summary>
/// Map the encoded Id to the token.
Expand All @@ -242,24 +232,27 @@ public override IReadOnlyList<Token> Encode(string text, bool isSpecialToken = f
/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
public IReadOnlyDictionary<string, int> Vocab => _vocab;
public IReadOnlyDictionary<string, int> Vocab => _vocabOriginal ??= _vocab.ToDictionary(kvp => kvp.Key.Data!, kvp => kvp.Value);

/// Read the given files to extract the vocab and merges
internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
internal static (Dictionary<StringSpanOrdinalKey, int>?, Vec<(string, string)>) ReadModelData(Stream vocab, Stream? merges)
{
Dictionary<string, int>? dic = JsonSerializer.Deserialize<Dictionary<string, int>>(vocab) as Dictionary<string, int>;
JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } };
Dictionary<StringSpanOrdinalKey, int>? dic = JsonSerializer.Deserialize<Dictionary<StringSpanOrdinalKey, int>>(vocab, options) as Dictionary<StringSpanOrdinalKey, int>;

return (dic, ConvertMergesToHashmap(merges));
}

/// The vocabulary assigns a number to each token.
private readonly Dictionary<string, int> _vocab;
private readonly Dictionary<StringSpanOrdinalKey, int> _vocab;

private Dictionary<string, int>? _vocabOriginal;

/// Contains the mapping between Pairs and their (rank, newId).
internal Dictionary<Pair<int>, (int, int)> Merges { get; }

/// Contains the cache for optimizing the encoding step.
internal Cache<string, Word>? Cache { get; }
internal StringSpanOrdinalKeyCache<Word>? Cache { get; }

internal static readonly int DefaultCacheCapacity = 10_000;

Expand Down Expand Up @@ -309,9 +302,6 @@ internal static (Dictionary<string, int>?, Vec<(string, string)>) ReadModelData(
return merges;
}

/// Reset the cache.
internal void ClearCache() => Cache?.Clear();

private readonly Dictionary<char, string> _charToString = new Dictionary<char, string>();

[MethodImpl(MethodImplOptions.AggressiveInlining)]
Expand All @@ -327,38 +317,68 @@ internal string CharToString(char c)
return s;
}

internal Word MergeWord(string w)
internal Word MergeWord(ReadOnlySpan<char> w)
{
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;

Span<char> buffer = stackalloc char[256];
scoped ReadOnlySpan<char> s;

while (i < w.Length)
{
int length;
string s;

if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
s = w.Substring(i, length);
s = w.Slice(i, 2);
}
else
{
length = 1;
s = CharToString(w[i]);
s = w.Slice(i, 1);
}

// Add the `continuing_subword_prefix` if relevant
if (i > 0 && ContinuingSubwordPrefix is not null)
{
s = $"{ContinuingSubwordPrefix}{s}";
if (ContinuingSubwordPrefix.Length + s.Length <= buffer.Length)
{
ContinuingSubwordPrefix.AsSpan().CopyTo(buffer);
s.CopyTo(buffer.Slice(ContinuingSubwordPrefix.Length));
s = buffer.Slice(0, ContinuingSubwordPrefix.Length + s.Length);
}
else
{
#if NETCOREAPP
s = $"{ContinuingSubwordPrefix}{s}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{ContinuingSubwordPrefix}{s1}".AsSpan();
#endif
}
}

// Add the `end_of_word_suffix` if relevant
if (i + length >= w.Length && EndOfWordSuffix is not null)
{
s = $"{s}{EndOfWordSuffix}";
if (s.Length + EndOfWordSuffix.Length <= buffer.Length)
{
s.CopyTo(buffer);
EndOfWordSuffix.AsSpan().CopyTo(buffer.Slice(s.Length));
s = buffer.Slice(0, s.Length + EndOfWordSuffix.Length);
}
else
{
#if NETCOREAPP
s = $"{s}{EndOfWordSuffix}".AsSpan();
#else
string s1 = s.Length == 1 ? CharToString(s[0]) : s.ToString();
s = $"{s1}{EndOfWordSuffix}".AsSpan();
#endif
}
}

if (_vocab.TryGetValue(s, out int id))
Expand Down Expand Up @@ -419,17 +439,17 @@ internal List<Token> EncodeWithCache(string text)
Word word;
if (Cache is not null)
{
if (Cache.TryGet(text, out word))
if (Cache.TryGetValue(text, out word))
{
return WordToTokens(ref word);
}

word = MergeWord(text);
word = MergeWord(text.AsSpan());
Cache.Set(text, word);
}
else
{
word = MergeWord(text);
word = MergeWord(text.AsSpan());
}

return WordToTokens(ref word);
Expand All @@ -445,19 +465,19 @@ internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
return word.SymbolsCount;
}

internal int EncodeToIdsWithCache(string text, IList<int>? accumulatedIds)
internal int EncodeToIdsWithCache(ReadOnlySpan<char> text, IList<int>? accumulatedIds)
{
Word word;

if (Cache is not null)
{
if (Cache.TryGet(text, out Word hit))
if (Cache.TryGetValue(text, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}

word = MergeWord(text);
Cache.Set(text, word);
Cache.Set(text.ToString(), word);
}
else
{
Expand Down
Loading