Skip to content

Commit

Permalink
Adding OneHotHashEncoding Test (#5098)
Browse files Browse the repository at this point in the history
  • Loading branch information
Lynx1820 authored May 9, 2020
1 parent 59dbdea commit 1a3df98
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,24 @@
}
]
},
{
"input": [
"F3"
],
"output": [
"F3.output"
],
"name": "Identity",
"opType": "Identity"
},
{
"input": [
"PredictedLabel"
],
"output": [
"PredictedLabel.output"
],
"name": "Identity",
"name": "Identity0",
"opType": "Identity"
},
{
Expand All @@ -464,7 +474,7 @@
"output": [
"Score.output"
],
"name": "Identity0",
"name": "Identity1",
"opType": "Identity"
},
{
Expand All @@ -474,7 +484,7 @@
"output": [
"Probability.output"
],
"name": "Identity1",
"name": "Identity2",
"opType": "Identity"
}
],
Expand Down Expand Up @@ -531,9 +541,45 @@
}
}
}
},
{
"name": "F3",
"type": {
"tensorType": {
"elemType": 8,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "5"
}
]
}
}
}
}
],
"output": [
{
"name": "F3.output",
"type": {
"tensorType": {
"elemType": 8,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "5"
}
]
}
}
}
},
{
"name": "PredictedLabel.output",
"type": {
Expand Down Expand Up @@ -806,6 +852,24 @@
}
}
},
{
"name": "F3.output",
"type": {
"tensorType": {
"elemType": 8,
"shape": {
"dim": [
{
"dimValue": "-1"
},
{
"dimValue": "5"
}
]
}
}
}
},
{
"name": "PredictedLabel.output",
"type": {
Expand Down
35 changes: 34 additions & 1 deletion test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using Google.Protobuf;
using Google.Protobuf.WellKnownTypes;
using Microsoft.ML.Data;
using Microsoft.ML.EntryPoints;
using Microsoft.ML.Model.OnnxConverter;
Expand Down Expand Up @@ -125,6 +124,9 @@ private class BreastCancerCatFeatureExample

[LoadColumn(2)]
public string F2;

[LoadColumn(3, 7), VectorType(6)]
public string[] F3;
}

private class BreastCancerMulticlassExample
Expand Down Expand Up @@ -1162,6 +1164,37 @@ public void PcaOnnxConversionTest()
Done();
}

[Fact]
public void OneHotHashEncodingOnnxConversionTest()
{
var mlContext = new MLContext();
string dataPath = GetDataPath("breast-cancer.txt");

var dataView = ML.Data.LoadFromTextFile<BreastCancerCatFeatureExample>(dataPath);
var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
new OneHotHashEncodingEstimator.ColumnOptions("Output", "F3", useOrderedHashing:false),
});
var model = pipe.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

var onnxFileName = "OneHotHashEncoding.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

if (IsOnnxRuntimeSupported())
{
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedColumns<float>("Output", "Output", transformedData, onnxResult);
}
Done();
}

[Theory]
[CombinatorialData]
// Due to lack of Onnxruntime support, long/ulong, double, floats, and OrderedHashing are not supported.
Expand Down

0 comments on commit 1a3df98

Please sign in to comment.