Tweak CreateByModelNameAsync #7015

Merged · 1 commit · Feb 20, 2024
13 changes: 8 additions & 5 deletions src/Microsoft.ML.Tokenizers/Model/Tiktoken.cs
@@ -9,6 +9,7 @@
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
@@ -104,9 +105,11 @@ private Tiktoken(int cacheSize)
/// </summary>
/// <param name="tikTokenBpeFileStream">Stream to the BPE rank file</param>
/// <param name="useAsync">Whether to perform I/O synchronously or asynchronously.</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>Map of byte[] to integer token id</returns>
/// <exception cref="InvalidOperationException"></exception>
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(Stream tikTokenBpeFileStream, bool useAsync)
internal static async ValueTask<(Dictionary<ReadOnlyMemory<byte>, int>, Dictionary<string, int>, IReadOnlyDictionary<int, byte[]>)> LoadTikTokenBpeAsync(
Stream tikTokenBpeFileStream, bool useAsync, CancellationToken cancellationToken = default)
{
var encoder = new Dictionary<ReadOnlyMemory<byte>, int>(ReadOnlyMemoryByteComparer.Instance);
var vocab = new Dictionary<string, int>();
@@ -119,7 +122,7 @@ private Tiktoken(int cacheSize)
while (true)
{
string? line = useAsync ?
await reader.ReadLineAsync().ConfigureAwait(false) :
await Helpers.ReadLineAsync(reader, cancellationToken).ConfigureAwait(false) :
reader.ReadLine();
if (string.IsNullOrWhiteSpace(line))
{
@@ -136,10 +139,10 @@ await reader.ReadLineAsync().ConfigureAwait(false) :
throw new FormatException($"Invalid format in the BPE encoder file stream");
}

byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);

if (Helpers.TryParseInt32(line, spaceIndex + 1, out int rank))
{
byte[] tokenBytes = Helpers.FromBase64String(line, 0, spaceIndex);

encoder[tokenBytes] = rank;
decoder[rank] = tokenBytes;

@@ -214,7 +217,7 @@ public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialTok
// cache miss
if (_vocab.TryGetValue(sequence, out int mappedId))
{
return new List<Token> { new(mappedId, sequence, (0, sequence.Length)) };
return new Token[1] { new(mappedId, sequence, (0, sequence.Length)) };
}

byte[] arrayPoolArray = ArrayPool<byte>.Shared.Rent(Encoding.UTF8.GetMaxByteCount(sequence.Length));
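For reference, the loader touched above consumes a tiktoken BPE rank file in which each non-empty line is a base64-encoded token, a single space, and an integer rank; the reordering in this hunk simply defers the base64 decode until the rank has parsed successfully. The following is a minimal standalone sketch of parsing one such line with plain BCL calls (a hypothetical helper for illustration, not the library's internal code, which also routes through per-target-framework `Helpers` and a `ReadOnlyMemory<byte>`-keyed encoder map):

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text;

static class BpeLineParser
{
    // Each non-empty line of a tiktoken BPE rank file is "<base64 token> <rank>".
    public static void ParseLine(string line, Dictionary<string, int> vocab, Dictionary<int, byte[]> decoder)
    {
        int spaceIndex = line.IndexOf(' ');
        if (spaceIndex <= 0 || spaceIndex >= line.Length - 1 || line.IndexOf(' ', spaceIndex + 1) >= 0)
        {
            throw new FormatException("Invalid format in the BPE encoder file stream");
        }

        // Decode the token bytes only after the rank is known to be parseable.
        if (!int.TryParse(line.AsSpan(spaceIndex + 1), NumberStyles.None, CultureInfo.InvariantCulture, out int rank))
        {
            throw new FormatException($"Can't parse '{line.Substring(spaceIndex + 1)}' to an integer rank");
        }

        byte[] tokenBytes = Convert.FromBase64String(line.Substring(0, spaceIndex));
        decoder[rank] = tokenBytes;
        vocab[Encoding.UTF8.GetString(tokenBytes)] = rank;
    }
}
```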
62 changes: 38 additions & 24 deletions src/Microsoft.ML.Tokenizers/Tokenizer.cs
@@ -9,6 +9,7 @@
using System.IO;
using System.Net.Http;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
@@ -346,32 +347,41 @@ private static readonly (string Prefix, ModelEncoding Encoding)[] _modelPrefixTo
/// <param name="modelName">Model name</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the model</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
public static async Task<Tokenizer> CreateByModelNameAsync(
public static Task<Tokenizer> CreateByModelNameAsync(
string modelName,
IReadOnlyDictionary<string, int>? extraSpecialTokens = null,
Normalizer? normalizer = null)
Normalizer? normalizer = null,
CancellationToken cancellationToken = default)
{
ModelEncoding encoder;

if (!_modelToEncoding.TryGetValue(modelName, out encoder))
try
{
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
ModelEncoding encoder;

if (!_modelToEncoding.TryGetValue(modelName, out encoder))
{
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
foreach ((string Prefix, ModelEncoding Encoding) in _modelPrefixToEncoding)
{
encoder = Encoding;
break;
if (modelName.StartsWith(Prefix, StringComparison.OrdinalIgnoreCase))
{
encoder = Encoding;
break;
}
}
}
}

if (encoder == ModelEncoding.None)
if (encoder == ModelEncoding.None)
{
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
}

return CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer, cancellationToken);
}
catch (Exception ex)
{
throw new NotImplementedException($"Doesn't support this model [{modelName}]");
return Task.FromException<Tokenizer>(ex);
}

return await CreateByEncoderNameAsync(encoder, extraSpecialTokens, normalizer).ConfigureAwait(false);
}

private const string Cl100kBaseRegexPattern = @"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
@@ -402,36 +412,38 @@ public static async Task<Tokenizer> CreateByModelNameAsync(
/// <param name="modelEncoding">Encoder label</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
/// <exception cref="NotImplementedException">Throws if the encoder is not supported</exception>
private static async Task<Tokenizer> CreateByEncoderNameAsync(
private static Task<Tokenizer> CreateByEncoderNameAsync(
ModelEncoding modelEncoding,
IReadOnlyDictionary<string, int>? extraSpecialTokens,
Normalizer? normalizer)
Normalizer? normalizer,
CancellationToken cancellationToken)
{
switch (modelEncoding)
{
case ModelEncoding.Cl100kBase:
var specialTokens = new Dictionary<string, int>
{ { EndOfText, 100257}, { FimPrefix, 100258}, { FimMiddle, 100259}, { FimSuffix, 100260}, { EndOfPrompt, 100276} };
return await CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(Cl100kBaseRegex(), Cl100kBaseVocabUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.P50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.P50kEdit:
specialTokens = new Dictionary<string, int>
{ { EndOfText, 50256 }, { FimPrefix, 50281 }, { FimMiddle, 50282 }, { FimSuffix, 50283 } };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), P50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.R50kBase:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 } };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), R50RanksUrl, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

case ModelEncoding.GPT2:
specialTokens = new Dictionary<string, int> { { EndOfText, 50256 }, };
return await CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer).ConfigureAwait(false);
return CreateTikTokenTokenizerAsync(P50kBaseRegex(), GPT2Url, specialTokens, extraSpecialTokens, normalizer, cancellationToken);

default:
Debug.Assert(false, $"Unexpected encoder [{modelEncoding}]");
@@ -449,13 +461,15 @@ private static async Task<Tokenizer> CreateByEncoderNameAsync(
/// <param name="specialTokens">Special tokens mapping. This may be mutated by the method.</param>
/// <param name="extraSpecialTokens">Extra special tokens other than the built-in ones for the encoder</param>
/// <param name="normalizer">To normalize the text before tokenization</param>
/// <param name="cancellationToken"><see cref="CancellationToken"/> used to request cancellation of the operation.</param>
/// <returns>The tokenizer</returns>
private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(
Regex regex,
string mergeableRanksFileUrl,
Dictionary<string, int> specialTokens,
IReadOnlyDictionary<string, int>? extraSpecialTokens,
Normalizer? normalizer)
Normalizer? normalizer,
CancellationToken cancellationToken)
{
if (extraSpecialTokens is not null)
{
@@ -467,9 +481,9 @@ private static async Task<Tokenizer> CreateTikTokenTokenizerAsync(

if (!_tiktokenCache.TryGetValue(mergeableRanksFileUrl, out (Dictionary<ReadOnlyMemory<byte>, int> encoder, Dictionary<string, int> vocab, IReadOnlyDictionary<int, byte[]> decoder) cache))
{
using (Stream stream = await _httpClient.GetStreamAsync(mergeableRanksFileUrl).ConfigureAwait(false))
using (Stream stream = await Helpers.GetStreamAsync(_httpClient, mergeableRanksFileUrl, cancellationToken).ConfigureAwait(false))
{
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true).ConfigureAwait(false);
cache = await Tiktoken.LoadTikTokenBpeAsync(stream, useAsync: true, cancellationToken).ConfigureAwait(false);
}

_tiktokenCache.TryAdd(mergeableRanksFileUrl, cache);
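Taken together, the Tokenizer.cs changes make CreateByModelNameAsync a non-async, Task-returning method that surfaces synchronous failures (such as an unsupported model name) as a faulted task via Task.FromException, and they thread the new CancellationToken through CreateByEncoderNameAsync, CreateTikTokenTokenizerAsync, and the vocabulary download. A usage sketch based on the signatures in this diff; the model name and timeout are illustrative assumptions, not part of the change:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.ML.Tokenizers;

class Example
{
    static async Task Main()
    {
        // Bound tokenizer creation (including the BPE vocabulary download) to 30 seconds.
        using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

        try
        {
            Tokenizer tokenizer = await Tokenizer.CreateByModelNameAsync(
                "gpt-4",                         // example model name
                extraSpecialTokens: null,
                normalizer: null,
                cancellationToken: cts.Token);
        }
        catch (NotImplementedException)
        {
            // An unsupported model name now arrives here as a faulted task
            // when awaited, rather than as a synchronous throw from the call site.
        }
        catch (OperationCanceledException)
        {
            // The download or BPE file parsing was canceled via the token.
        }
    }
}
```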
27 changes: 21 additions & 6 deletions src/Microsoft.ML.Tokenizers/Utils/Helpers.netcoreapp.cs
@@ -1,26 +1,41 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Buffers.Text;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Threading.Tasks;
using System.Threading;
using System.Net.Http;

namespace Microsoft.ML.Tokenizers
{
internal static class Helpers
{
public static ValueTask<string?> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken) =>
reader.ReadLineAsync(cancellationToken);

public static Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken) =>
client.GetStreamAsync(url, cancellationToken);

public static byte[] FromBase64String(string base64String, int offset, int length)
{
Span<byte> bytes = stackalloc byte[300];
if (!Convert.TryFromBase64Chars(base64String.AsSpan().Slice(offset, length), bytes, out int bytesWritten))
if (!Base64.IsValid(base64String.AsSpan(offset, length), out int decodedLength))
{
throw new System.FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
throw new FormatException($"Invalid base64 string '{base64String.Substring(offset, length)}'");
}
return bytes.Slice(0, bytesWritten).ToArray();

byte[] bytes = new byte[decodedLength];
bool success = Convert.TryFromBase64Chars(base64String.AsSpan(offset, length), bytes, out int bytesWritten);
Debug.Assert(success);
Debug.Assert(bytes.Length == bytesWritten);
return bytes;
}

internal static bool TryParseInt32(string s, int offset, out int result)
=> int.TryParse(s.AsSpan().Slice(offset), NumberStyles.None, CultureInfo.InvariantCulture, out result);
}
}

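The netcoreapp helper above replaces the previous fixed 300-byte stackalloc buffer with System.Buffers.Text.Base64.IsValid, which validates the slice and reports its exact decoded length up front so the result array can be allocated at the right size. A standalone sketch of the same idea, assuming a target where Base64.IsValid with an out decoded length is available (.NET 8 or later):

```csharp
using System;
using System.Buffers.Text;
using System.Diagnostics;

static class Base64Decode
{
    // Validate first, then allocate exactly decodedLength bytes and decode into them.
    public static byte[] DecodeSegment(string s, int offset, int length)
    {
        ReadOnlySpan<char> chars = s.AsSpan(offset, length);

        if (!Base64.IsValid(chars, out int decodedLength))
        {
            throw new FormatException($"Invalid base64 string '{s.Substring(offset, length)}'");
        }

        byte[] bytes = new byte[decodedLength];
        bool success = Convert.TryFromBase64Chars(chars, bytes, out int bytesWritten);
        Debug.Assert(success && bytesWritten == decodedLength);
        return bytes;
    }
}
```

For example, `DecodeSegment("QUJD 100264", 0, 4)` returns the three UTF-8 bytes of "ABC" without over-allocating or depending on a fixed-size scratch buffer.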
19 changes: 18 additions & 1 deletion src/Microsoft.ML.Tokenizers/Utils/Helpers.netstandard.cs
@@ -1,13 +1,30 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.IO;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.Tokenizers
{
internal static class Helpers
{
public static ValueTask<string> ReadLineAsync(StreamReader reader, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
return new ValueTask<string>(reader.ReadLineAsync());
}

public static async Task<Stream> GetStreamAsync(HttpClient client, string url, CancellationToken cancellationToken)
{
HttpResponseMessage response = await client.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
return await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
}

public static byte[] FromBase64String(string base64String, int offset, int length) => Convert.FromBase64String(base64String.Substring(offset, length));

// Does not support signed numbers
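On netstandard, StreamReader.ReadLineAsync and HttpClient.GetStreamAsync have no CancellationToken overloads, so the shims above check the token before each read and pass it to GetAsync with ResponseHeadersRead instead; the practical effect is that cancellation is observed between lines rather than mid-read. A small sketch of that per-line cancellation pattern under the same assumption (hypothetical helper, shown for illustration only):

```csharp
using System.IO;
using System.Threading;
using System.Threading.Tasks;

static class LineReader
{
    // netstandard's StreamReader.ReadLineAsync cannot itself be canceled, so the
    // token is checked once per iteration, mirroring the shim's behavior.
    public static async Task<int> CountLinesAsync(Stream stream, CancellationToken cancellationToken)
    {
        using var reader = new StreamReader(stream);
        int count = 0;
        while (true)
        {
            cancellationToken.ThrowIfCancellationRequested();
            string? line = await reader.ReadLineAsync().ConfigureAwait(false);
            if (line is null)
            {
                break;
            }
            count++;
        }
        return count;
    }
}
```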