Skip to content

Commit 1a3df98

Browse files
authored
Adding OneHotHashEncoding Test (#5098)
1 parent 59dbdea commit 1a3df98

File tree

2 files changed

+101
-4
lines changed

2 files changed

+101
-4
lines changed

test/BaselineOutput/Common/Onnx/BinaryClassification/BreastCancer/ExcludeVariablesInOnnxConversion.txt

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,14 +447,24 @@
447447
}
448448
]
449449
},
450+
{
451+
"input": [
452+
"F3"
453+
],
454+
"output": [
455+
"F3.output"
456+
],
457+
"name": "Identity",
458+
"opType": "Identity"
459+
},
450460
{
451461
"input": [
452462
"PredictedLabel"
453463
],
454464
"output": [
455465
"PredictedLabel.output"
456466
],
457-
"name": "Identity",
467+
"name": "Identity0",
458468
"opType": "Identity"
459469
},
460470
{
@@ -464,7 +474,7 @@
464474
"output": [
465475
"Score.output"
466476
],
467-
"name": "Identity0",
477+
"name": "Identity1",
468478
"opType": "Identity"
469479
},
470480
{
@@ -474,7 +484,7 @@
474484
"output": [
475485
"Probability.output"
476486
],
477-
"name": "Identity1",
487+
"name": "Identity2",
478488
"opType": "Identity"
479489
}
480490
],
@@ -531,9 +541,45 @@
531541
}
532542
}
533543
}
544+
},
545+
{
546+
"name": "F3",
547+
"type": {
548+
"tensorType": {
549+
"elemType": 8,
550+
"shape": {
551+
"dim": [
552+
{
553+
"dimValue": "-1"
554+
},
555+
{
556+
"dimValue": "5"
557+
}
558+
]
559+
}
560+
}
561+
}
534562
}
535563
],
536564
"output": [
565+
{
566+
"name": "F3.output",
567+
"type": {
568+
"tensorType": {
569+
"elemType": 8,
570+
"shape": {
571+
"dim": [
572+
{
573+
"dimValue": "-1"
574+
},
575+
{
576+
"dimValue": "5"
577+
}
578+
]
579+
}
580+
}
581+
}
582+
},
537583
{
538584
"name": "PredictedLabel.output",
539585
"type": {
@@ -806,6 +852,24 @@
806852
}
807853
}
808854
},
855+
{
856+
"name": "F3.output",
857+
"type": {
858+
"tensorType": {
859+
"elemType": 8,
860+
"shape": {
861+
"dim": [
862+
{
863+
"dimValue": "-1"
864+
},
865+
{
866+
"dimValue": "5"
867+
}
868+
]
869+
}
870+
}
871+
}
872+
},
809873
{
810874
"name": "PredictedLabel.output",
811875
"type": {

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
using System.Runtime.InteropServices;
1010
using System.Text.RegularExpressions;
1111
using Google.Protobuf;
12-
using Google.Protobuf.WellKnownTypes;
1312
using Microsoft.ML.Data;
1413
using Microsoft.ML.EntryPoints;
1514
using Microsoft.ML.Model.OnnxConverter;
@@ -125,6 +124,9 @@ private class BreastCancerCatFeatureExample
125124

126125
[LoadColumn(2)]
127126
public string F2;
127+
128+
// Text-file columns 3 through 7 (an inclusive range, i.e. five columns) loaded as a single
// string vector. The declared vector size matches the number of loaded columns.
[LoadColumn(3, 7), VectorType(5)]
public string[] F3;
128130
}
129131

130132
private class BreastCancerMulticlassExample
@@ -1162,6 +1164,37 @@ public void PcaOnnxConversionTest()
11621164
Done();
11631165
}
11641166

1167+
// Fits a OneHotHashEncoding transform on the "F3" column of the breast-cancer data set,
// exports the fitted model to ONNX, saves it to the test output folder and, when ONNX Runtime
// is supported on the current platform, checks that scoring the saved ONNX model yields the
// same "Output" column as the ML.NET transformer.
[Fact]
public void OneHotHashEncodingOnnxConversionTest()
{
    // A single context is used for loading, fitting, exporting and scoring so that every
    // step of the test runs against the same MLContext instance.
    var mlContext = new MLContext();
    string dataPath = GetDataPath("breast-cancer.txt");

    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
    var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding(new[]
    {
        // Ordered hashing is explicitly turned off for the exported column.
        new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing: false),
    });
    var model = pipe.Fit(dataView);
    var transformedData = model.Transform(dataView);
    var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

    var onnxFileName = "OneHotHashEncoding.onnx";
    var onnxModelPath = GetOutputPath(onnxFileName);
    // NOTE(review): the third argument is passed as null, presumably meaning "no text dump
    // of the model" - confirm against SaveOnnxModel's definition.
    SaveOnnxModel(onnxModel, onnxModelPath, null);

    if (IsOnnxRuntimeSupported())
    {
        // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
        string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
        string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
        var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
        var onnxTransformer = onnxEstimator.Fit(dataView);
        var onnxResult = onnxTransformer.Transform(dataView);

        // The ML.NET result and the ONNX Runtime result are compared on the "Output" column.
        CompareSelectedColumns<float>("Output", "Output", transformedData, onnxResult);
    }

    Done();
}
1197+
11651198
[Theory]
11661199
[CombinatorialData]
11671200
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.

0 commit comments

Comments
 (0)