Skip to content

Commit 275f4c2

Browse files
authored
Fix for KeytoValue transformer (#4866)
* fix for keytovalue * fix for keytovalue * resolving comments
1 parent b34c3b6 commit 275f4c2

File tree

2 files changed

+49
-41
lines changed

2 files changed

+49
-41
lines changed

src/Microsoft.ML.Data/Transforms/KeyToValue.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,9 @@ public void SaveAsOnnx(OnnxContext ctx)
568568

569569
if (!ctx.ContainsColumn(inputColumnName))
570570
continue;
571-
571+
string srcVariableName = ctx.GetVariableName(inputColumnName);
572572
var dstVariableName = ctx.AddIntermediateVariable(_types[iinfo], info.outputColumnName, true);
573-
if (!_kvMaps[iinfo].SaveOnnx(ctx, inputColumnName, dstVariableName))
573+
if (!_kvMaps[iinfo].SaveOnnx(ctx, srcVariableName, dstVariableName))
574574
{
575575
ctx.RemoveColumn(inputColumnName, true);
576576
}

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Runtime.InteropServices;
1010
using System.Text.RegularExpressions;
1111
using Google.Protobuf;
12+
using Google.Protobuf.WellKnownTypes;
1213
using Microsoft.ML.Data;
1314
using Microsoft.ML.EntryPoints;
1415
using Microsoft.ML.Model.OnnxConverter;
@@ -1150,7 +1151,7 @@ public void IndicateMissingValuesOnnxConversionTest()
11501151
[InlineData(DataKind.Int64)]
11511152
[InlineData(DataKind.Double)]
11521153
[InlineData(DataKind.String)]
1153-
public void ValueToKeyandKeyToValueMappingOnnxConversionTest(DataKind valueType)
1154+
public void ValueToKeyMappingOnnxConversionTest(DataKind valueType)
11541155
{
11551156
var mlContext = new MLContext(seed: 1);
11561157
string filePath = GetDataPath("type-conversion.txt");
@@ -1160,9 +1161,8 @@ public void ValueToKeyandKeyToValueMappingOnnxConversionTest(DataKind valueType)
11601161
new TextLoader.Column("Value", valueType, 0, 0)
11611162
};
11621163
var dataView = mlContext.Data.LoadFromTextFile(filePath, columns);
1164+
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Key", "Value");
11631165

1164-
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("Key", "Value").
1165-
Append(mlContext.Transforms.Conversion.MapKeyToValue("ValueOutput", "Key"));
11661166
var model = pipeline.Fit(dataView);
11671167
var mlnetResult = model.Transform(dataView);
11681168

@@ -1176,13 +1176,55 @@ public void ValueToKeyandKeyToValueMappingOnnxConversionTest(DataKind valueType)
11761176
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
11771177
var onnxTransformer = onnxEstimator.Fit(dataView);
11781178
var onnxResult = onnxTransformer.Transform(dataView);
1179-
1180-
CompareResults("ValueOutput", "ValueOutput", mlnetResult, onnxResult);
11811179
CompareSelectedColumns<uint>("Key", "Key", mlnetResult, onnxResult);
11821180
}
11831181
Done();
11841182
}
11851183

1184+
[Theory]
1185+
[InlineData(DataKind.Single)]
1186+
[InlineData(DataKind.Int64)]
1187+
[InlineData(DataKind.Double)]
1188+
[InlineData(DataKind.String)]
1189+
public void KeyToValueMappingOnnxConversionTest(DataKind valueType)
1190+
{
1191+
var mlContext = new MLContext(seed: 1);
1192+
string filePath = GetDataPath("type-conversion.txt");
1193+
1194+
TextLoader.Column[] columns = new[]
1195+
{
1196+
new TextLoader.Column("Value", valueType, 0, 0)
1197+
};
1198+
var dataView = mlContext.Data.LoadFromTextFile(filePath, columns);
1199+
IEstimator<ITransformer>[] pipelines =
1200+
{
1201+
mlContext.Transforms.Conversion.MapValueToKey("Key", "Value").
1202+
Append(mlContext.Transforms.Conversion.MapKeyToValue("Value", "Key")),
1203+
1204+
mlContext.Transforms.Conversion.MapValueToKey("Value").
1205+
Append(mlContext.Transforms.Conversion.MapKeyToValue("Value"))
1206+
};
1207+
for (int i = 0; i < pipelines.Length; i++)
1208+
{
1209+
var model = pipelines[i].Fit(dataView);
1210+
var mlnetResult = model.Transform(dataView);
1211+
1212+
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
1213+
var onnxFileName = "KeyToValue.onnx";
1214+
var onnxModelPath = GetOutputPath(onnxFileName);
1215+
SaveOnnxModel(onnxModel, onnxModelPath, null);
1216+
1217+
if (IsOnnxRuntimeSupported())
1218+
{
1219+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
1220+
var onnxTransformer = onnxEstimator.Fit(dataView);
1221+
var onnxResult = onnxTransformer.Transform(dataView);
1222+
CompareResults("Value", "Value", mlnetResult, onnxResult);
1223+
}
1224+
}
1225+
Done();
1226+
}
1227+
11861228
private class TextData
11871229
{
11881230
public string Text { get; set; }
@@ -1355,40 +1397,6 @@ public void OptionalColumnOnnxTest(DataKind dataKind)
13551397
Done();
13561398
}
13571399

1358-
[Fact]
1359-
public void KeyToValueOnnxConversionTest()
1360-
{
1361-
var mlContext = new MLContext(seed: 1);
1362-
1363-
string dataPath = GetDataPath("breast-cancer.txt");
1364-
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerMulticlassExample>(dataPath,
1365-
separatorChar: '\t',
1366-
hasHeader: true);
1367-
1368-
var pipeline = mlContext.Transforms.Conversion.MapValueToKey("LabelKey", "Label").
1369-
Append(mlContext.Transforms.Conversion.MapKeyToValue("LabelValue", "LabelKey"));
1370-
1371-
var model = pipeline.Fit(dataView);
1372-
var transformedData = model.Transform(dataView);
1373-
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);
1374-
1375-
var onnxFileName = "KeyToValue.onnx";
1376-
var onnxModelPath = GetOutputPath(onnxFileName);
1377-
1378-
SaveOnnxModel(onnxModel, onnxModelPath, null);
1379-
1380-
if (IsOnnxRuntimeSupported())
1381-
{
1382-
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
1383-
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
1384-
var onnxTransformer = onnxEstimator.Fit(dataView);
1385-
var onnxResult = onnxTransformer.Transform(dataView);
1386-
CompareSelectedColumns<ReadOnlyMemory<Char>>("LabelValue", "LabelValue", transformedData, onnxResult);
1387-
}
1388-
1389-
Done();
1390-
}
1391-
13921400
[Fact]
13931401
public void MulticlassTrainersOnnxConversionTest()
13941402
{

0 commit comments

Comments
 (0)