Skip to content

Commit b7db4fa

Browse files
authored
Added onnx export support for KeyToValueMappingTransformer (#4455)
1 parent f3e0f6b commit b7db4fa

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

src/Microsoft.ML.Data/Transforms/KeyToValue.cs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
using Microsoft.ML.CommandLine;
1212
using Microsoft.ML.Data;
1313
using Microsoft.ML.Internal.Utilities;
14+
using Microsoft.ML.Model.OnnxConverter;
1415
using Microsoft.ML.Model.Pfa;
1516
using Microsoft.ML.Runtime;
1617
using Microsoft.ML.Transforms;
@@ -152,7 +153,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
152153

153154
/// <summary>
/// Builds the row mapper for this transformer over the given input schema.
/// </summary>
/// <param name="inputSchema">Schema of the data the mapper will be applied to.</param>
/// <returns>A new <see cref="Mapper"/> bound to this transformer and <paramref name="inputSchema"/>.</returns>
private protected override IRowMapper MakeRowMapper(DataViewSchema inputSchema)
{
    return new Mapper(this, inputSchema);
}
154155

155-
private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa
156+
private sealed class Mapper : OneToOneMapperBase, ISaveAsPfa, ISaveAsOnnx
156157
{
157158
private readonly KeyToValueMappingTransformer _parent;
158159
private readonly DataViewType[] _types;
@@ -298,6 +299,8 @@ protected KeyToValueMap(Mapper mapper, PrimitiveDataViewType typeVal, int iinfo)
298299
public abstract Delegate GetMappingGetter(DataViewRow input);
299300

300301
public abstract JToken SavePfa(BoundPfaContext ctx, JToken srcToken);
302+
303+
/// <summary>
/// Writes the ONNX representation of this key-to-value mapping into <paramref name="ctx"/>.
/// </summary>
/// <param name="ctx">The ONNX context that receives the generated nodes.</param>
/// <param name="srcVariableName">Name passed to the generated nodes as their input (the key column).</param>
/// <param name="dstVariableName">Name passed to the generated nodes as their output (the value column).</param>
/// <returns>
/// <c>true</c> when the mapping was exported; <c>false</c> when the implementation cannot export it
/// (the caller treats <c>false</c> as a signal to drop a column from the context).
/// </returns>
public abstract bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName);
301304
}
302305

303306
private class KeyToValueMap<TKey, TValue> : KeyToValueMap
@@ -494,8 +497,65 @@ public override JToken SavePfa(BoundPfaContext ctx, JToken srcToken)
494497
}
495498
return PfaUtils.If(PfaUtils.Call("<", srcToken, 0), defaultToken, PfaUtils.Index(cellRef, srcToken));
496499
}
500+
501+
/// <summary>
/// Exports this key-to-value mapping as two ONNX nodes: a <c>Cast</c> that converts the incoming keys
/// to int64, followed by a <c>LabelEncoder</c> whose keys are 1.._values.Length and whose values are
/// the stored key values.
/// </summary>
/// <param name="ctx">The ONNX context that receives the generated nodes.</param>
/// <param name="srcVariableName">Name of the ONNX variable holding the input keys.</param>
/// <param name="dstVariableName">Name of the ONNX variable that receives the mapped values.</param>
/// <returns>
/// <c>true</c> if the nodes were emitted. <c>false</c> if <c>TypeOutput</c> is not Int64, Single or Text;
/// in that case nothing is added to <paramref name="ctx"/>.
/// </returns>
public override bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    // Decide whether the value type is exportable BEFORE touching the context, so that an
    // unsupported type does not leave orphaned Cast/LabelEncoder nodes behind.
    bool isInt64 = TypeOutput == NumberDataViewType.Int64;
    bool isSingle = TypeOutput == NumberDataViewType.Single;
    bool isText = TypeOutput == TextDataViewType.Instance;
    if (!isInt64 && !isSingle && !isText)
        return false;

    // Onnx expects the input keys to be int64s. But the input data can come from an ML.NET node that
    // may output a uint32. So cast it here to ensure that the data is treated correctly.
    // The Cast node produces int64 data, so its output variable is declared as Int64 rather than
    // as the (possibly text or float) value type.
    string opType = "Cast";
    var castNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "CastNodeOutput", true);
    var castNode = ctx.CreateNode(opType, srcVariableName, castNodeOutput, ctx.GetNodeName(opType), "");
    var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Int64).ToType();
    castNode.AddAttribute("to", t);

    opType = "LabelEncoder";
    var node = ctx.CreateNode(opType, castNodeOutput, dstVariableName, ctx.GetNodeName(opType));

    // Keys are 1-based: key i maps to the i-th stored value.
    long[] keys = Enumerable.Range(1, _values.Length).Select(key => (long)key).ToArray();
    node.AddAttribute("keys_int64s", keys);

    // Materialize one value per key. DenseValues() yields exactly _values.Length items, whereas
    // GetValues() only returns the explicitly stored entries and would come up short for a
    // sparse buffer, leaving the keys and values attributes with different lengths.
    TValue[] denseValues = _values.DenseValues().ToArray();

    if (isInt64)
    {
        long[] values = Array.ConvertAll<TValue, long>(denseValues, item => Convert.ToInt64(item));
        node.AddAttribute("values_int64s", values);
    }
    else if (isSingle)
    {
        float[] values = Array.ConvertAll<TValue, float>(denseValues, item => Convert.ToSingle(item));
        node.AddAttribute("values_floats", values);
    }
    else
    {
        // Only the text case remains after the guard at the top of the method.
        string[] values = Array.ConvertAll<TValue, string>(denseValues, item => Convert.ToString(item));
        node.AddAttribute("values_strings", values);
    }

    return true;
}
497538
}
498539

540+
/// <summary>
/// Reports whether this mapper can be exported to ONNX. Always returns <c>true</c>;
/// the <paramref name="ctx"/> argument is not inspected.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
541+
542+
/// <summary>
/// Exports every (input, output) column pair of the parent transformer to ONNX.
/// Pairs whose input column is not present in <paramref name="ctx"/> are skipped.
/// </summary>
/// <param name="ctx">The ONNX context that receives the generated variables and nodes.</param>
public void SaveAsOnnx(OnnxContext ctx)
{
    for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo)
    {
        var info = _parent.ColumnPairs[iinfo];
        var inputColumnName = info.inputColumnName;
        var outputColumnName = info.outputColumnName;

        if (!ctx.ContainsColumn(inputColumnName))
            continue;

        // The ONNX graph addresses data by variable name, which is not guaranteed to equal the
        // ML.NET column name, so resolve the variable instead of passing the column name through.
        var srcVariableName = ctx.GetVariableName(inputColumnName);
        var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], outputColumnName, true);
        if (!_kvMaps[iinfo].SaveOnnx(ctx, srcVariableName, dstVariableName))
        {
            // Export failed: withdraw the output column registered just above. The input column
            // is left alone so it stays available to whatever else reads it.
            ctx.RemoveColumn(outputColumnName, true);
        }
    }
}
499559
}
500560
}
501561

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,6 +1040,42 @@ public void OptionalColumnOnnxTest()
10401040
Done();
10411041
}
10421042

1043+
/// <summary>
/// Fits a value-to-key / key-to-value round trip on the breast-cancer data, exports it to ONNX,
/// and (when the ONNX runtime is available) checks that the exported model reproduces the
/// ML.NET output for the mapped-back label column.
/// </summary>
[Fact]
private void KeyToValueOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);

    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerMulticlassExample>(
        GetDataPath("breast-cancer.txt"),
        separatorChar: '\t',
        hasHeader: true);

    // Label -> LabelKey -> LabelValue.
    var valueToKey = mlContext.Transforms.Conversion.MapValueToKey("LabelKey", "Label");
    var pipeline = valueToKey.Append(mlContext.Transforms.Conversion.MapKeyToValue("LabelValue", "LabelKey"));

    var model = pipeline.Fit(dataView);
    var transformedData = model.Transform(dataView);
    var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

    var onnxModelPath = GetOutputPath("KeyToValue.onnx");
    SaveOnnxModel(onnxModel, onnxModelPath, null);

    if (IsOnnxRuntimeSupported())
    {
        // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
        string[] inputNames = onnxModel.Graph.Input.Select(input => input.Name).ToArray();
        string[] outputNames = onnxModel.Graph.Output.Select(output => output.Name).ToArray();

        var onnxResult = mlContext.Transforms
            .ApplyOnnxModel(outputNames, inputNames, onnxModelPath)
            .Fit(dataView)
            .Transform(dataView);

        // Column 3 is compared on both sides.
        var expectedColumnName = transformedData.Schema[3].Name;
        var actualColumnName = outputNames[3];
        CompareSelectedScalarColumns<ReadOnlyMemory<Char>>(expectedColumnName, actualColumnName, transformedData, onnxResult);
    }

    Done();
}
1078+
10431079
private void CreateDummyExamplesToMakeComplierHappy()
10441080
{
10451081
var dummyExample = new BreastCancerFeatureVector() { Features = null };
@@ -1105,6 +1141,34 @@ private void CompareSelectedVectorColumns<T>(string leftColumnName, string right
11051141
}
11061142
}
11071143

1144+
/// <summary>
/// Compares a scalar column of <paramref name="left"/> against a column of <paramref name="right"/>
/// that is read as a length-1 <see cref="VBuffer{T}"/>, row by row.
/// </summary>
/// <typeparam name="T">Item type of both columns.</typeparam>
/// <param name="leftColumnName">Name of the scalar (expected) column in <paramref name="left"/>.</param>
/// <param name="rightColumnName">Name of the vector (actual) column in <paramref name="right"/>.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];

    using (var expectedCursor = left.GetRowCursor(leftColumn))
    using (var actualCursor = right.GetRowCursor(rightColumn))
    {
        T expected = default;
        VBuffer<T> actual = default;
        var expectedGetter = expectedCursor.GetGetter<T>(leftColumn);
        var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);

        // Drive the loop from the expected side and require the actual side to keep up, so a
        // row-count mismatch fails the test instead of silently ending the comparison early.
        while (expectedCursor.MoveNext())
        {
            Assert.True(actualCursor.MoveNext(), "The actual data view has fewer rows than the expected one.");

            expectedGetter(ref expected);
            actualGetter(ref actual);

            // Verify the shape before reading the single element out of the buffer.
            Assert.Equal(1, actual.Length);
            var actualVal = actual.GetItemOrDefault(0);

            // Text values are compared through their string form.
            if (typeof(T) == typeof(ReadOnlyMemory<Char>))
                Assert.Equal(expected.ToString(), actualVal.ToString());
            else
                Assert.Equal(expected, actualVal);
        }

        Assert.False(actualCursor.MoveNext(), "The actual data view has more rows than the expected one.");
    }
}
1171+
11081172
private void CompareSelectedR8VectorColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
11091173
{
11101174
var leftColumn = left.Schema[leftColumnName];

0 commit comments

Comments
 (0)