Skip to content

Commit 3476f04

Browse files
authored
Fix for ColumnConcatenatingTransformer (#4861)
* using concat for columnconcatenation * resolving comments
1 parent 275f4c2 commit 3476f04

File tree

5 files changed

+56
-46
lines changed

5 files changed

+56
-46
lines changed

src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ public void SaveAsOnnx(OnnxContext ctx)
896896
Host.CheckValue(ctx, nameof(ctx));
897897
Contracts.Assert(CanSaveOnnx(ctx));
898898

899-
string opType = "FeatureVectorizer";
899+
string opType = "Concat";
900900
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
901901
{
902902
var colInfo = _parent._columns[iinfo];
@@ -926,9 +926,9 @@ public void SaveAsOnnx(OnnxContext ctx)
926926
}
927927

928928
var node = ctx.CreateNode(opType, inputList.Select(t => t.Key),
929-
new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType));
929+
new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType), "");
930930

931-
node.AddAttribute("inputdimensions", inputList.Select(x => x.Value));
931+
node.AddAttribute("axis", 1);
932932
}
933933
}
934934
}

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,15 @@
179179
"output": [
180180
"Features"
181181
],
182-
"name": "FeatureVectorizer",
183-
"opType": "FeatureVectorizer",
182+
"name": "Concat",
183+
"opType": "Concat",
184184
"attribute": [
185185
{
186-
"name": "inputdimensions",
187-
"ints": [
188-
"1",
189-
"10"
190-
],
191-
"type": "INTS"
186+
"name": "axis",
187+
"i": "1",
188+
"type": "INT"
192189
}
193-
],
194-
"domain": "ai.onnx.ml"
190+
]
195191
},
196192
{
197193
"input": [

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,15 @@
128128
"output": [
129129
"Features"
130130
],
131-
"name": "FeatureVectorizer",
132-
"opType": "FeatureVectorizer",
131+
"name": "Concat",
132+
"opType": "Concat",
133133
"attribute": [
134134
{
135-
"name": "inputdimensions",
136-
"ints": [
137-
"8",
138-
"9"
139-
],
140-
"type": "INTS"
135+
"name": "axis",
136+
"i": "1",
137+
"type": "INT"
141138
}
142-
],
143-
"domain": "ai.onnx.ml"
139+
]
144140
},
145141
{
146142
"input": [
@@ -677,27 +673,27 @@
677673
"floats": [
678674
-0.9850374,
679675
-1,
680-
-0.42857143,
676+
-0.428571433,
681677
0.05882353,
682678
0.9655172,
683-
0.47826087,
684-
7E-45,
679+
0.478260875,
680+
7.006492E-45,
685681
0.9354839,
686682
-0.837172,
687-
-0.89662564,
683+
-0.896625638,
688684
-0.3455931,
689-
0.22312601,
685+
0.223126009,
690686
0.8040303,
691687
0.60825175,
692688
-0.06932944,
693-
-0.40204307,
689+
-0.402043074,
694690
-0.7417274,
695-
-0.40843493,
691+
-0.408434927,
696692
0.7105746,
697693
0.1875386,
698694
0.7631735,
699-
0.70617324,
700-
0.62590647,
695+
0.706173241,
696+
0.625906467,
701697
-0.35968104
702698
],
703699
"type": "FLOATS"

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,19 +179,15 @@
179179
"output": [
180180
"Features"
181181
],
182-
"name": "FeatureVectorizer",
183-
"opType": "FeatureVectorizer",
182+
"name": "Concat",
183+
"opType": "Concat",
184184
"attribute": [
185185
{
186-
"name": "inputdimensions",
187-
"ints": [
188-
"1",
189-
"10"
190-
],
191-
"type": "INTS"
186+
"name": "axis",
187+
"i": "1",
188+
"type": "INT"
192189
}
193-
],
194-
"domain": "ai.onnx.ml"
190+
]
195191
},
196192
{
197193
"input": [

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,7 @@ public void KeyToVectorWithBagOnnxConversionTest()
563563
.Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label", featureColumnName: "Features", numberOfLeaves: 2, numberOfTrees: 1, minimumExampleCountPerLeaf: 2));
564564

565565
var model = pipeline.Fit(data);
566+
var transformedData = model.Transform(data);
566567
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
567568

568569
// Check ONNX model's text format. We save the produced ONNX model as a text file and compare it against
@@ -572,8 +573,19 @@ public void KeyToVectorWithBagOnnxConversionTest()
572573
var onnxTextName = "OneHotBagPipeline.txt";
573574
var onnxFileName = "OneHotBagPipeline.onnx";
574575
var onnxTextPath = GetOutputPath(subDir, onnxTextName);
575-
var onnxFilePath = GetOutputPath(subDir, onnxFileName);
576-
SaveOnnxModel(onnxModel, onnxFilePath, onnxTextPath);
576+
var onnxModelPath = GetOutputPath(subDir, onnxFileName);
577+
SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath);
578+
// Compare results produced by ML.NET and ONNX's runtime.
579+
if (IsOnnxRuntimeSupported())
580+
{
581+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
582+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
583+
var onnxTransformer = onnxEstimator.Fit(data);
584+
var onnxResult = onnxTransformer.Transform(data);
585+
CompareSelectedColumns<float>("Score", "Score", transformedData, onnxResult);
586+
CompareSelectedColumns<float>("Probability", "Probability", transformedData, onnxResult);
587+
CompareSelectedColumns<bool>("PredictedLabel", "PredictedLabel", transformedData, onnxResult);
588+
}
577589
CheckEquality(subDir, onnxTextName);
578590
Done();
579591
}
@@ -905,8 +917,18 @@ public void RemoveVariablesInPipelineTest()
905917
var onnxTextName = "ExcludeVariablesInOnnxConversion.txt";
906918
var onnxFileName = "ExcludeVariablesInOnnxConversion.onnx";
907919
var onnxTextPath = GetOutputPath(subDir, onnxTextName);
908-
var onnxFilePath = GetOutputPath(subDir, onnxFileName);
909-
SaveOnnxModel(onnxModel, onnxFilePath, onnxTextPath);
920+
var onnxModelPath = GetOutputPath(subDir, onnxFileName);
921+
SaveOnnxModel(onnxModel, onnxModelPath, onnxTextPath);
922+
if (IsOnnxRuntimeSupported())
923+
{
924+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
925+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
926+
var onnxTransformer = onnxEstimator.Fit(data);
927+
var onnxResult = onnxTransformer.Transform(data);
928+
CompareSelectedColumns<float>("Score", "Score", transformedData, onnxResult);
929+
CompareSelectedColumns<float>("Probability", "Probability", transformedData, onnxResult);
930+
CompareSelectedColumns<bool>("PredictedLabel", "PredictedLabel", transformedData, onnxResult);
931+
}
910932
CheckEquality(subDir, onnxTextName, digitsOfPrecision: 3);
911933
}
912934
Done();

0 commit comments

Comments
 (0)