Tweak CreateByModelNameAsync (#7015)
- Add a CancellationToken to CreateByModelNameAsync, allowing the download and parsing to be canceled (a usage sketch follows the changed-files summary below).
- Use ReadLineAsync(cancellationToken), which not only allows the read to be canceled but also avoids ~100K task allocations.
- Fix Helpers.FromBase64String to support tokens whose decoded form exceeds 300 bytes.
stephentoub authored Feb 20, 2024
1 parent 3282f44 commit 2c9f775
Showing 4 changed files with 85 additions and 36 deletions.
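Before the per-file diffs, here is a minimal usage sketch of the new cancellation support. This is not part of the commit; "gpt-4" is assumed to be one of the model names the library recognizes, and the other parameters keep their defaults.

```csharp
// Hypothetical caller of the updated API; only the cancellationToken parameter is new.
using System;
using System.Threading;
using Microsoft.ML.Tokenizers;

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

try
{
    // The token now flows through both the vocabulary download and the BPE-rank parsing.
    Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync("gpt-4", cancellationToken: cts.Token);
    Console.WriteLine("Tokenizer created.");
}
catch (OperationCanceledException)
{
    Console.WriteLine("Creation canceled before the vocabulary finished downloading and parsing.");
}
```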
13 changes: 8 additions & 5 deletions src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -9,6 +9,7 @@
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
@@ -111,9 +112,11 @@ private Tiktoken(int cacheSize)
/// </summary>
/// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
/// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Map of byte[] to integer token id</returns>
/// <exception cref="InvalidOperationException"></exception>
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<string, int>();
@@ -126,7 +129,7 @@ private Tiktoken(int cacheSize)
while (true)
{
string? line = useAsync ?
await reader.ReadLineAsync().ConfigureAwait(false) :
await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
reader.ReadLine();
if (string.IsNullOrWhiteSpace(line))
{
@@ -143,10 +146,10 @@ await reader.ReadLineAsync().ConfigureAwait(false) :
throw new FormatException($"Invalid format in the BPE encoder file stream");
}

byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);

if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
{
byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);

encoder[tokenBytes] = rank;
decoder[rank] = tokenBytes;

@@ -221,7 +224,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
// cache miss
if (_vocab.TryGetValue(sequence, out int mappedId))
{
return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
}

byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
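For context on the LoadTikTokenBpeAsync loop in this file: each line of the tiktoken BPE rank file pairs a base64-encoded token with its integer rank, separated by a space, which is why the code splits at spaceIndex and feeds the two halves to Helpers.FromBase64String and Helpers.TryParseInt32. A standalone sketch of that per-line format, using equivalent BCL calls rather than the internal helpers; the sample line is hypothetical.

```csharp
using System;

// Hypothetical rank-file line: "IQ==" is the base64 encoding of "!" (0x21), mapped to rank 0.
string line = "IQ== 0";

int spaceIndex = line.IndexOf(' ');
byte[] tokenBytes = Convert.FromBase64String(line.Substring(0, spaceIndex));
int rank = int.Parse(line.AsSpan(spaceIndex + 1));

// LoadTikTokenBpeAsync stores the pair in both directions:
// encoder[tokenBytes] = rank and decoder[rank] = tokenBytes.
Console.WriteLine($"token bytes: {BitConverter.ToString(tokenBytes)}, rank: {rank}");
```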
62 changes: 38 additions & 24 deletions src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -9,6 +9,7 @@
using System.IO;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
public static async Task<Tokenizer> CreateByModelNameAsync(
public static Task<Tokenizer> CreateByModelNameAsync(
string modelName,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
Normalizer? normalizer = null)
Normalizer? normalizer = null,
CancellationToken cancellationToken = default)
{
ModelEncoding encoder;

if (!_modelToEncoding.TryGetValue(modelName, out encoder))
try
{
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
ModelEncoding encoder;

if (!_modelToEncoding.TryGetValue(modelName, out encoder))
{
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
{
encoder = Encoding;
break;
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
{
encoder = Encoding;
break;
}
}
}
}

if (encoder == ModelEncoding.None)
if (encoder == ModelEncoding.None)
{
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
}

return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
}
catch (Exception ex)
{
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
return Task.FromException<Tokenizer>(ex);
}

return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
}

private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
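A note on the shape change in CreateByModelNameAsync above: the method is no longer marked async. The model-name lookup and validation now run inside a try/catch, and any failure is returned as a faulted task via Task.FromException, so a caller that only awaits the returned task observes the same exception it would have seen from the async version, while the success path simply forwards the task from CreateByEncoderNameAsync without an extra state machine. A generic sketch of that pattern, as an illustration only and not the library's code:

```csharp
using System;
using System.Threading.Tasks;

// Non-async Task-returning method: synchronous failures are surfaced on the task
// instead of being thrown directly, mirroring what the `async` keyword would do.
static Task<int> ParseAsync(string? text)
{
    try
    {
        if (text is null)
        {
            throw new ArgumentNullException(nameof(text));
        }

        return Task.FromResult(int.Parse(text)); // fast path, no async state machine
    }
    catch (Exception ex)
    {
        return Task.FromException<int>(ex); // fault the returned task
    }
}

try
{
    await ParseAsync(null);
}
catch (ArgumentNullException)
{
    Console.WriteLine("Exception observed when awaiting, as with the async version.");
}
```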
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
/// <param name="modelEncoding">Encoder label</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
/// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
private static async Task<Tokenizer> CreateByEncoderNameAsync(
private static Task<Tokenizer> CreateByEncoderNameAsync(
ModelEncoding modelEncoding,
IReadOnlyDictionary<string, int>? extraSpecialTokens,
Normalizer? normalizer)
Normalizer? normalizer,
CancellationToken cancellationToken)
{
switch (modelEncoding)
{
case ModelEncoding.Cl100kBase:
var specialTokens = new Dictionary<string, int>
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.P50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.P50kEdit:
specialTokens = new Dictionary<string, int>
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.R50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.GPT2:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

default:
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
/// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
Regex regex,
string mergeableRanksFileUrl,
Dictionary<string, int> specialTokens,
IReadOnlyDictionary<string, int>? extraSpecialTokens,
Normalizer? normalizer)
Normalizer? normalizer,
CancellationToken cancellationToken)
{
if (extraSpecialTokens is not null)
{
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(

if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
{
using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
{
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
}

_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
27 changes: 21 additions & 6 deletions src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -1,26 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers.Text;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Threading.Tasks;
using System.Threading;
using System.Net.Http;

namespace Microsoft.ML.Tokenizers
{
internal static class Helpers
{
public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
reader.ReadLineAsync(cancellationToken);

public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
client.GetStreamAsync(url, cancellationToken);

public static byte[] FromBase64String(string base64String, int offset, int length)
{
Span<byte> bytes = stackalloc byte[300];
if (!Convert.TryFromBase64Chars(base64String.AsSpan().Slice(offset, length), bytes, out int bytesWritten))
if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))
{
throw new System.FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
throw new FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
}
return bytes.Slice(0, bytesWritten).ToArray();

byte[] bytes = new byte[decodedLength];
bool success = Convert.TryFromBase64Chars(base64String.AsSpan(offset, length), bytes, out int bytesWritten);
Debug.Assert(success);
Debug.Assert(bytes.Length == bytesWritten);
return bytes;
}

internal static bool TryParseInt32(string s, int offset, out int result)
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
}
}
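On the netcoreapp FromBase64String change above: the previous implementation decoded into a fixed 300-byte stackalloc buffer, so Convert.TryFromBase64Chars failed for any token whose decoded form was larger than 300 bytes (a base64 segment longer than 400 characters). The new version asks Base64.IsValid for the exact decoded length and allocates the result array to that size. A small sketch of the size arithmetic, as an illustration only; it assumes .NET 8+ for Base64.IsValid.

```csharp
using System;
using System.Buffers.Text;

// 301 raw bytes round-trip through base64 as 404 characters: ceil(301 / 3) * 4 = 404.
string longToken = Convert.ToBase64String(new byte[301]);

if (Base64.IsValid(longToken.AsSpan(), out int decodedLength))
{
    Console.WriteLine($"base64 chars: {longToken.Length}, decoded bytes: {decodedLength}");
    // decodedLength is 301, one byte more than the old stackalloc buffer could hold.
    byte[] bytes = new byte[decodedLength];
    Convert.TryFromBase64Chars(longToken.AsSpan(), bytes, out int bytesWritten);
    Console.WriteLine($"bytes written: {bytesWritten}");
}
```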

19 changes: 18 additions & 1 deletion src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -1,13 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
{
internal static class Helpers
{
public static ValueTask<string> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
return new ValueTask<string>(reader.ReadLineAsync());
}

public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
{
HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
}

public static byte[] FromBase64String(string base64String, int offset, int length) => Convert.FromBase64String(base64String.Substring(offset, length));

// Not support signed number
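One behavioral note on the netstandard shims above: StreamReader.ReadLineAsync and HttpContent.ReadAsStreamAsync have no cancellation-token overloads on netstandard2.0, so the ReadLineAsync shim can only check the token before each read, and GetStreamAsync passes the token to GetAsync with ResponseHeadersRead. Cancellation is therefore cooperative rather than immediate: an already-canceled token faults before the next line is read, but an in-flight read is not interrupted. A small demonstration of that behavior, as an illustration only; it assumes access to the internal Helpers class, for example from within the library or its test project.

```csharp
using System;
using System.IO;
using System.Text;
using System.Threading;
using Microsoft.ML.Tokenizers; // hypothetical access to the internal Helpers type

using var stream = new MemoryStream(Encoding.UTF8.GetBytes("IQ== 0\n"));
using var reader = new StreamReader(stream);
using var cts = new CancellationTokenSource();
cts.Cancel(); // token is already canceled before the read starts

try
{
    string? line = await Helpers.ReadLineAsync(reader, cts.Token);
    Console.WriteLine(line);
}
catch (OperationCanceledException)
{
    Console.WriteLine("Canceled before reading the next line.");
}
```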
