Skip to content

Commit

Permalink
Cleanup normalization support
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekgh committed Feb 23, 2024
1 parent 62334c6 commit d48b32d
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 142 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ public LowerCaseNormalizer() { }
/// </summary>
/// <param name="original">The original string to normalize to lowercase form.</param>
/// <returns>The lower-cased normalized string.</returns>
public override NormalizedString Normalize(string original) => new NormalizedString(original, original.ToLowerInvariant(), normalizedToOriginalMapping: null, isOneToOneMapping: true);
public override string Normalize(string original) => original.ToLowerInvariant();
}
}
64 changes: 0 additions & 64 deletions src/Microsoft.ML.Tokenizers/Normalizer/NormalizedString.cs

This file was deleted.

4 changes: 2 additions & 2 deletions src/Microsoft.ML.Tokenizers/Normalizer/Normalizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public abstract class Normalizer
/// Process the original string to modify it and obtain a normalized string.
/// </summary>
/// <param name="original">The original string to normalize.</param>
/// <returns>The normalized string along with the mapping to the original string.</returns>
public abstract NormalizedString Normalize(string original);
/// <returns>The normalized string.</returns>
public abstract string Normalize(string original);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ public UpperCaseNormalizer() { }
/// </summary>
/// <param name="original">The original string to normalize to uppercase form.</param>
/// <returns>The upper-cased normalized string.</returns>
public override NormalizedString Normalize(string original) => new NormalizedString(original, original.ToUpperInvariant(), normalizedToOriginalMapping: null, isOneToOneMapping: true);
public override string Normalize(string original) => original.ToUpperInvariant();
}
}
52 changes: 8 additions & 44 deletions src/Microsoft.ML.Tokenizers/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,56 +67,20 @@ public EncodingResult Encode(string text, bool considerSpecialTokens = true)
throw new ArgumentNullException(nameof(text));
}

string normalized;
NormalizedString normalizedString = default;

string normalized = Normalizer is null ? text : Normalizer.Normalize(text);
bool offsetsMappedToOriginal = true;
if (Normalizer is not null)
{
normalizedString = Normalizer.Normalize(text);
normalized = normalizedString.Normalized;

offsetsMappedToOriginal = normalizedString.CanMapToOriginal;
}
else
{
normalized = text;
}

EncodingResult encoding = new(text, normalized, PreTokenizer.PreTokenize(normalized, considerSpecialTokens), offsetsMappedToOriginal);

if (Normalizer is null || !normalizedString.CanMapToOriginal || normalizedString.IsOneToOneMapping)
foreach (Split split in encoding.Splits)
{
// Optimize the case we don't have to map the offsets.
foreach (Split split in encoding.Splits)
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString, split.IsSpecialToken);
foreach (Token token in tokens)
{
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString, split.IsSpecialToken);
foreach (Token token in tokens)
{
token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length);
}

encoding.AddTokens(tokens);
token.Offset = (token.Offset.Index + split.Offset.Index, token.Offset.Length);
}
}
else
{
Debug.Assert(normalizedString.NormalizedToOriginalMapping is not null);

foreach (Split split in encoding.Splits)
{
IReadOnlyList<Token> tokens = Model.Encode(split.TokenString, split.IsSpecialToken);
foreach (Token token in tokens)
{
int index = normalizedString.NormalizedToOriginalMapping![token.Offset.Index + split.Offset.Index];

Debug.Assert(index >= 0);

token.Offset = (index, token.Offset.Length);
}

encoding.AddTokens(tokens);
}
encoding.AddTokens(tokens);
}

return encoding;
Expand All @@ -135,7 +99,7 @@ public IReadOnlyList<int> EncodeToIds(string text, bool considerSpecialTokens =
throw new ArgumentNullException(nameof(text));
}

string normalized = Normalizer is not null ? Normalizer.Normalize(text).Normalized : text;
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;
List<int> idsList = new();

foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
Expand All @@ -161,7 +125,7 @@ public int CountTokens(string text, bool considerSpecialTokens = true)
throw new ArgumentNullException(nameof(text));
}

string normalized = Normalizer is not null ? Normalizer.Normalize(text).Normalized : text;
string normalized = Normalizer is not null ? Normalizer.Normalize(text) : text;

int idsCount = 0;
foreach (Split split in PreTokenizer.PreTokenize(normalized, considerSpecialTokens))
Expand Down
41 changes: 11 additions & 30 deletions test/Microsoft.ML.Tokenizers.Tests/NormalizerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,78 +22,59 @@ public static IEnumerable<object?[]> NormalizerData
new LowerCaseNormalizer(),
"How Are You Doing?",
"how are you doing?",
true, // IsOneToOneMapping
true, // CanMapToOriginal
null, // NormalizedToOriginalMapping
};

yield return new object?[]
{
new UpperCaseNormalizer(),
"How Are You Doing?",
"HOW ARE YOU DOING?",
true, // IsOneToOneMapping
true, // CanMapToOriginal
null, // NormalizedToOriginalMapping
};

yield return new object?[]
{
new RemoveQuotesNormalizer(),
"This is already normalized string",
"This is already normalized string",
true, // IsOneToOneMapping
true, // CanMapToOriginal
null, // NormalizedToOriginalMapping
};

yield return new object?[]
{
new RemoveQuotesNormalizer(),
"String \"to\" normalize",
"String to normalize",
false, // IsOneToOneMapping
true, // CanMapToOriginal
new int[] { 0, 1, 2, 3, 4, 5, 6, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20 }, // NormalizedToOriginalMapping
};

yield return new object?[]
{
new UnicodeNormalizer(NormalizationForm.FormKD),
"\uFB01", // Composed form of the character 'fi' one character
"fi", // normalized in 2 characters 'f' and 'i'
false, // IsOneToOneMapping
false, // CanMapToOriginal
null, // NormalizedToOriginalMapping
};
}
}

[Theory]
[MemberData(nameof(NormalizerData))]
public void TestNormalizer(Normalizer normalizer, string sentence, string normalized, bool isOneToOneMapping, bool canMapToOriginal, int[] normalizedToOriginalMapping)
public void TestNormalizer(Normalizer normalizer, string text, string normalized)
{
NormalizedString ns = normalizer.Normalize(sentence);
Assert.Equal(normalized, ns.Normalized);
Assert.Equal(isOneToOneMapping, ns.IsOneToOneMapping);
Assert.Equal(canMapToOriginal, ns.CanMapToOriginal);
Assert.Equal(normalizedToOriginalMapping, ns.NormalizedToOriginalMapping);
string normalizedText = normalizer.Normalize(text);
Assert.Equal(normalized, normalizedText);

Tokenizer tokenizer = new Tokenizer(BpeTests.CreateEmptyBpe(), WhiteSpace.Instance, normalizer);
EncodingResult encoding = tokenizer.Encode(sentence);
Assert.Equal(canMapToOriginal, encoding.OffsetsMappedToOriginalString);
Assert.Equal(sentence, encoding.OriginalString);
EncodingResult encoding = tokenizer.Encode(text);
Assert.Equal(text, encoding.OriginalString);
Assert.Equal(normalized, encoding.NormalizedString);
}

public class RemoveQuotesNormalizer : Normalizer
{
public override NormalizedString Normalize(string original)
public override string Normalize(string original)
{
int index = original.IndexOf('"');
if (index <= 0)
{
return new NormalizedString(original, original, null, true);
return original;
}

StringBuilder sb = new StringBuilder(original.Length);
Expand Down Expand Up @@ -128,7 +109,7 @@ public override NormalizedString Normalize(string original)
}
} while (true);

return new NormalizedString(original, sb.ToString(), mapping.ToArray(), false);
return sb.ToString();
}
}

Expand All @@ -140,14 +121,14 @@ public UnicodeNormalizer(NormalizationForm form)
_normalizationForm = form;
}

public override NormalizedString Normalize(string original)
public override string Normalize(string original)
{
if (string.IsNullOrEmpty(original))
{
return new NormalizedString(original, "", null, true);
return string.Empty;
}

return new NormalizedString(original, original.Normalize(_normalizationForm), null, false);
return original.Normalize(_normalizationForm);
}
}
}
Expand Down

0 comments on commit d48b32d

Please sign in to comment.