Skip to content

Commit 25ebe4f

Browse files
authored
Alternate solution for ColumnConcatenatingTransformer (#4875)
* alternate solution for concat * adding baselines
1 parent 4493397 commit 25ebe4f

File tree

5 files changed

+172
-17
lines changed

5 files changed

+172
-17
lines changed

src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -896,15 +896,14 @@ public void SaveAsOnnx(OnnxContext ctx)
896896
Host.CheckValue(ctx, nameof(ctx));
897897
Contracts.Assert(CanSaveOnnx(ctx));
898898

899-
string opType = "Concat";
900899
for (int iinfo = 0; iinfo < _columns.Length; ++iinfo)
901900
{
902901
var colInfo = _parent._columns[iinfo];
903902
var boundCol = _columns[iinfo];
904903

905904
string outName = colInfo.Name;
906905
var outColType = boundCol.OutputType;
907-
if (!outColType.IsKnownSize)
906+
if ((!outColType.IsKnownSize) || (!(outColType.GetItemType() is NumberDataViewType)))
908907
{
909908
ctx.RemoveColumn(outName, false);
910909
continue;
@@ -925,10 +924,19 @@ public void SaveAsOnnx(OnnxContext ctx)
925924
InputSchema[srcIndex].Type.GetValueCount()));
926925
}
927926

927+
string opType = "FeatureVectorizer";
928+
int outVectorSize = (int)inputList.Sum(x => x.Value);
929+
var vectorizerOutputType = new VectorDataViewType(NumberDataViewType.Single, outVectorSize);
930+
var vectorizerOutputName = ctx.AddIntermediateVariable(vectorizerOutputType, "VectorFeaturizerOutput");
928931
var node = ctx.CreateNode(opType, inputList.Select(t => t.Key),
929-
new[] { ctx.AddIntermediateVariable(outColType, outName) }, ctx.GetNodeName(opType), "");
930-
931-
node.AddAttribute("axis", 1);
932+
new[] { vectorizerOutputName }, ctx.GetNodeName(opType));
933+
node.AddAttribute("inputdimensions", inputList.Select(x => x.Value));
934+
935+
opType = "Cast";
936+
var dstVectorType = new VectorDataViewType(outColType.GetItemType() as PrimitiveDataViewType, outVectorSize);
937+
var dstVariableName = ctx.AddIntermediateVariable(dstVectorType, outName);
938+
var castNode = ctx.CreateNode(opType, vectorizerOutputName, dstVariableName, ctx.GetNodeName(opType), "");
939+
castNode.AddAttribute("to", outColType.ItemType.RawType);
932940
}
933941
}
934942
}

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,35 @@
176176
"F1",
177177
"F22"
178178
],
179+
"output": [
180+
"VectorFeaturizerOutput"
181+
],
182+
"name": "FeatureVectorizer",
183+
"opType": "FeatureVectorizer",
184+
"attribute": [
185+
{
186+
"name": "inputdimensions",
187+
"ints": [
188+
"1",
189+
"10"
190+
],
191+
"type": "INTS"
192+
}
193+
],
194+
"domain": "ai.onnx.ml"
195+
},
196+
{
197+
"input": [
198+
"VectorFeaturizerOutput"
199+
],
179200
"output": [
180201
"Features"
181202
],
182-
"name": "Concat",
183-
"opType": "Concat",
203+
"name": "Cast1",
204+
"opType": "Cast",
184205
"attribute": [
185206
{
186-
"name": "axis",
207+
"name": "to",
187208
"i": "1",
188209
"type": "INT"
189210
}
@@ -431,7 +452,7 @@
431452
"output": [
432453
"PredictedLabel"
433454
],
434-
"name": "Cast1",
455+
"name": "Cast2",
435456
"opType": "Cast",
436457
"attribute": [
437458
{
@@ -638,6 +659,24 @@
638659
}
639660
}
640661
},
662+
{
663+
"name": "VectorFeaturizerOutput",
664+
"type": {
665+
"tensorType": {
666+
"elemType": 1,
667+
"shape": {
668+
"dim": [
669+
{
670+
"dimValue": "-1"
671+
},
672+
{
673+
"dimValue": "11"
674+
}
675+
]
676+
}
677+
}
678+
}
679+
},
641680
{
642681
"name": "Features",
643682
"type": {

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ModelWithLessIO.txt

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,35 @@
125125
"F1",
126126
"F21"
127127
],
128+
"output": [
129+
"VectorFeaturizerOutput"
130+
],
131+
"name": "FeatureVectorizer",
132+
"opType": "FeatureVectorizer",
133+
"attribute": [
134+
{
135+
"name": "inputdimensions",
136+
"ints": [
137+
"8",
138+
"9"
139+
],
140+
"type": "INTS"
141+
}
142+
],
143+
"domain": "ai.onnx.ml"
144+
},
145+
{
146+
"input": [
147+
"VectorFeaturizerOutput"
148+
],
128149
"output": [
129150
"Features"
130151
],
131-
"name": "Concat",
132-
"opType": "Concat",
152+
"name": "Cast1",
153+
"opType": "Cast",
133154
"attribute": [
134155
{
135-
"name": "axis",
156+
"name": "to",
136157
"i": "1",
137158
"type": "INT"
138159
}
@@ -757,7 +778,7 @@
757778
"output": [
758779
"PredictedLabel"
759780
],
760-
"name": "Cast1",
781+
"name": "Cast2",
761782
"opType": "Cast",
762783
"attribute": [
763784
{
@@ -946,6 +967,24 @@
946967
}
947968
}
948969
},
970+
{
971+
"name": "VectorFeaturizerOutput",
972+
"type": {
973+
"tensorType": {
974+
"elemType": 1,
975+
"shape": {
976+
"dim": [
977+
{
978+
"dimValue": "-1"
979+
},
980+
{
981+
"dimValue": "17"
982+
}
983+
]
984+
}
985+
}
986+
}
987+
},
949988
{
950989
"name": "Features",
951990
"type": {

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/OneHotBagPipeline.txt

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,35 @@
176176
"F1",
177177
"F22"
178178
],
179+
"output": [
180+
"VectorFeaturizerOutput"
181+
],
182+
"name": "FeatureVectorizer",
183+
"opType": "FeatureVectorizer",
184+
"attribute": [
185+
{
186+
"name": "inputdimensions",
187+
"ints": [
188+
"1",
189+
"10"
190+
],
191+
"type": "INTS"
192+
}
193+
],
194+
"domain": "ai.onnx.ml"
195+
},
196+
{
197+
"input": [
198+
"VectorFeaturizerOutput"
199+
],
179200
"output": [
180201
"Features"
181202
],
182-
"name": "Concat",
183-
"opType": "Concat",
203+
"name": "Cast1",
204+
"opType": "Cast",
184205
"attribute": [
185206
{
186-
"name": "axis",
207+
"name": "to",
187208
"i": "1",
188209
"type": "INT"
189210
}
@@ -384,7 +405,7 @@
384405
"output": [
385406
"PredictedLabel"
386407
],
387-
"name": "Cast1",
408+
"name": "Cast2",
388409
"opType": "Cast",
389410
"attribute": [
390411
{
@@ -871,6 +892,24 @@
871892
}
872893
}
873894
},
895+
{
896+
"name": "VectorFeaturizerOutput",
897+
"type": {
898+
"tensorType": {
899+
"elemType": 1,
900+
"shape": {
901+
"dim": [
902+
{
903+
"dimValue": "-1"
904+
},
905+
{
906+
"dimValue": "11"
907+
}
908+
]
909+
}
910+
}
911+
}
912+
},
874913
{
875914
"name": "Features",
876915
"type": {

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,6 +878,36 @@ public void LoadingPredictorModelAndOnnxConversionTest()
878878
Done();
879879
}
880880

881+
[Fact]
882+
public void ConcatenateOnnxConversionTest()
883+
{
884+
var mlContext = new MLContext(seed: 1);
885+
string dataPath = GetDataPath("breast-cancer.txt");
886+
887+
var data = ML.Data.LoadFromTextFile(dataPath, new[] {
888+
new TextLoader.Column("VectorDouble2", DataKind.Double, 1),
889+
new TextLoader.Column("VectorDouble1", DataKind.Double, 4, 8),
890+
new TextLoader.Column("Label", DataKind.Boolean, 0)
891+
});
892+
var pipeline = mlContext.Transforms.Concatenate("Features", "VectorDouble1", "VectorDouble2");
893+
var model = pipeline.Fit(data);
894+
var transformedData = model.Transform(data);
895+
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
896+
897+
// Compare results produced by ML.NET and ONNX's runtime.
898+
if (IsOnnxRuntimeSupported())
899+
{
900+
var onnxModelName = "Concatenate.onnx";
901+
var onnxModelPath = GetOutputPath(onnxModelName);
902+
SaveOnnxModel(onnxModel, onnxModelPath, null);
903+
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
904+
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(onnxModelPath);
905+
var onnxTransformer = onnxEstimator.Fit(data);
906+
var onnxResult = onnxTransformer.Transform(data);
907+
CompareSelectedColumns<double>("Features", "Features", transformedData, onnxResult);
908+
}
909+
Done();
910+
}
881911

882912
[Fact]
883913
public void RemoveVariablesInPipelineTest()

0 commit comments

Comments
 (0)