Skip to content

TextNormalizing export to Onnx #4781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion src/Microsoft.ML.Transforms/Text/TextNormalizing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms.Text;

Expand Down Expand Up @@ -194,7 +195,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Dat

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);

private sealed class Mapper : OneToOneMapperBase
private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx
{
private readonly DataViewType[] _types;
private readonly TextNormalizingTransformer _parent;
Expand All @@ -212,6 +213,44 @@ public Mapper(TextNormalizingTransformer parent, DataViewSchema inputSchema)
}
}

public bool CanSaveOnnx(OnnxContext ctx) => (_parent._keepDiacritics && _parent._keepNumbers && _parent._keepPunctuations);

public void SaveAsOnnx(OnnxContext ctx)
{
Host.CheckValue(ctx, nameof(ctx));
for (int iinfo = 0; iinfo < _types.Length; ++iinfo)
{
string inputColumnName = _parent.ColumnPairs[iinfo].inputColumnName;
if (!ctx.ContainsColumn(inputColumnName))
continue;

string outputColumnName = _parent.ColumnPairs[iinfo].outputColumnName;
string srcVariableName = ctx.GetVariableName(inputColumnName);
string dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], outputColumnName, true);
SaveAsOnnxCore(ctx, srcVariableName, dstVariableName);
}
}

private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
// StringNormalizer only takes input of shapes [C] or [1,C],
// so the input is squeezed to support inferred shapes ( e.g. [-1,C] ).
var opType = "Squeeze";
var squeezeOutput = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
var node = ctx.CreateNode(opType, srcVariableName, squeezeOutput, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", new long[] { 0 });

opType = "StringNormalizer";
var normalizerOutput = ctx.AddIntermediateVariable(null, "NormalizerOutput", true);
node = ctx.CreateNode(opType, squeezeOutput, normalizerOutput, ctx.GetNodeName(opType), "");
var isCaseChange = (_parent._caseMode == TextNormalizingEstimator.CaseMode.Lower) ? "LOWER" :
(_parent._caseMode == TextNormalizingEstimator.CaseMode.Upper) ? "UPPER" : "NONE";
node.AddAttribute("case_change_action", isCaseChange);

opType = "Unsqueeze";
node = ctx.CreateNode(opType, normalizerOutput, dstVariableName, ctx.GetNodeName(opType), "");
node.AddAttribute("axes", new long[] { 0 });
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If its 'None' there is no need for the 'StringNormalizer' node at all

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, in the case that we are able to add the other options in the future, it would make sense to keep the option to pass "None". What do you think?

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var result = new DataViewSchema.DetachedColumn[_parent.ColumnPairs.Length];
Expand Down
36 changes: 36 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,42 @@ public void PlattCalibratorOnnxConversionTest2()
Done();
}

[Fact]
public void TextNormalizingOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);
var dataPath = GetDataPath("wikipedia-detox-250-line-test.tsv");
var dataView = ML.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("label", DataKind.Boolean, 0),
new TextLoader.Column("text", DataKind.String, 1)
}, hasHeader: true);
var pipeline = new TextNormalizingEstimator(mlContext, keepDiacritics: true, columns: new[] { ("NormText", "text") }).Append(
new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.Upper, columns: new[] { ("UpperText", "text") })).Append(
new TextNormalizingEstimator(mlContext, keepDiacritics: true, caseMode: TextNormalizingEstimator.CaseMode.None, columns: new[] { ("OriginalText", "text") }));
var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

// Compare model scores produced by ML.NET and ONNX's runtime.
// Skipping test in Linux platforms temporarily
if (IsOnnxRuntimeSupported() && !RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
var onnxFileName = $"TextNormalizing.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); //compare NormText
CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); //compare UpperText
CompareSelectedColumns<ReadOnlyMemory<char>>(transformedData.Schema[4].Name, outputNames[4], transformedData, onnxResult); //compare OriginalText
}
Done();
}

private class DataPoint
{
[VectorType(3)]
Expand Down