Skip to content

Commit

Permalink
Added onnx export support for KeyToValueMappingTransformer (dotnet#4455)
Browse files Browse the repository at this point in the history
  • Loading branch information
harishsk authored Nov 14, 2019
1 parent f3e0f6b commit b7db4fa
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
62 changes: 61 additions & 1 deletion src/Microsoft.ML.Data/Transforms/KeyToValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
using Microsoft.ML.Internal.Utilities;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Model.Pfa;
using Microsoft.ML.Runtime;
using Microsoft.ML.Transforms;
Expand Down Expand Up @@ -152,7 +153,7 @@ private protected override void SaveModel(ModelSaveContext ctx)

private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema) => new Mapper(this, inputSchema);

private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa
private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa, ISaveAsOnnx
{
private readonly KeyToValueMappingTransformer _parent;
private readonly DataViewType[] _types;
Expand Down Expand Up @@ -298,6 +299,8 @@ protected KeyToValueMap(Mapper mapper, PrimitiveDataViewType typeVal, int iinfo)
public abstract Delegate GetMappingGetter(DataViewRow input);

public abstract JToken SavePfa(BoundPfaContext ctx, JToken srcToken);

public abstract bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName);
}

private class KeyToValueMap<TKey, TValue> : KeyToValueMap
Expand Down Expand Up @@ -494,8 +497,65 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken)
}
return PfaUtils.If(PfaUtils.Call("<", srcToken, 0), defaultToken, PfaUtils.Index(cellRef, srcToken));
}

public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
string opType;

// Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
// may output a uint32. So cast it here to ensure that the data is treated correctly
opType = "Cast";
var castNodeOutput = ctx.AddIntermediateVariable(TypeOutput, "CastNodeOutput", true);
var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
castNode.AddAttribute("to", t);

opType = "LabelEncoder";
var node = ctx.CreateNode(opType, castNodeOutput, dstVariableName, ctx.GetNodeName(opType));
var keys = Array.ConvertAll<int, long>(Enumerable.Range(1, _values.Length).ToArray(), item => Convert.ToInt64(item));
node.AddAttribute("keys_int64s", keys);

if (TypeOutput == NumberDataViewType.Int64)
{
long[] values = Array.ConvertAll<TValue, long>(_values.GetValues().ToArray(), item => Convert.ToInt64(item));
node.AddAttribute("values_int64s", values);
}
else if (TypeOutput == NumberDataViewType.Single)
{
float[] values = Array.ConvertAll<TValue, float>(_values.GetValues().ToArray(), item => Convert.ToSingle(item));
node.AddAttribute("values_floats", values);
}
else if (TypeOutput == TextDataViewType.Instance)
{
string[] values = Array.ConvertAll<TValue, string>(_values.GetValues().ToArray(), item => Convert.ToString(item));
node.AddAttribute("values_strings", values);
}
else
return false;

return true;
}
}

public bool CanSaveOnnx(OnnxContext ctx) => true;

public void SaveAsOnnx(OnnxContext ctx)
{
for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo)
{
var info = _parent.ColumnPairs[iinfo];
var inputColumnName = info.inputColumnName;

if (!ctx.ContainsColumn(inputColumnName))
continue;

var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.outputColumnName, true);
if (!_kvMaps[iinfo].SaveOnnx(ctx, inputColumnName, dstVariableName))
{
ctx.RemoveColumn(inputColumnName, true);
}
}
}
}
}

Expand Down
64 changes: 64 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,42 @@ public void OptionalColumnOnnxTest()
Done();
}

[Fact]
private void KeyToValueOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);

string dataPath = GetDataPath("breast-cancer.txt");
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerMulticlassExample>(dataPath,
separatorChar: '\t',
hasHeader: true);

var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelKey", "Label").
Append(mlContext.Transforms.Conversion.MapKeyToValue("LabelValue", "LabelKey"));

var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

var onnxFileName = "KeyToValue.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);

SaveOnnxModel(onnxModel, onnxModelPath, null);

if (IsOnnxRuntimeSupported())
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedScalarColumns<ReadOnlyMemory<Char>>(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult);
}

Done();
}

private void CreateDummyExamplesToMakeComplierHappy()
{
var dummyExample = new BreastCancerFeatureVector() { Features = null };
Expand Down Expand Up @@ -1105,6 +1141,34 @@ private void CompareSelectedVectorColumns<T>(string leftColumnName, string right
}
}

private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
var leftColumn = left.Schema[leftColumnName];
var rightColumn = right.Schema[rightColumnName];

using (var expectedCursor = left.GetRowCursor(leftColumn))
using (var actualCursor = right.GetRowCursor(rightColumn))
{
T expected = default;
VBuffer<T> actual = default;
var expectedGetter = expectedCursor.GetGetter<T>(leftColumn);
var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);
while (expectedCursor.MoveNext() && actualCursor.MoveNext())
{
expectedGetter(ref expected);
actualGetter(ref actual);
var actualVal = actual.GetItemOrDefault(0);

Assert.Equal(1, actual.Length);

if (typeof(T) == typeof(ReadOnlyMemory<Char>))
Assert.Equal(expected.ToString(), actualVal.ToString());
else
Assert.Equal(expected, actualVal);
}
}
}

private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
{
var leftColumn = left.Schema[leftColumnName];
Expand Down

0 comments on commit b7db4fa

Please sign in to comment.