Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Onnx Export to PlattCalibratorTransformer #4699

Merged
merged 10 commits into from
Jan 27, 2020
40 changes: 37 additions & 3 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,40 @@ public void BinaryClassificationTrainersOnnxConversionTest()
Done();
}

[Fact]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be part of your changes. Maybe you did your merge steps wrong?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think I messed up, but it's fixed now

public void TestVectorWhiteningOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
var dataView = mlContext.Data.LoadFromTextFile(dataPath, new[] {
new TextLoader.Column("label", DataKind.Single, 11),
new TextLoader.Column("features", DataKind.Single, 0, 10)
}, hasHeader: true, separatorChar: ';');

var pipeline = new VectorWhiteningEstimator(mlContext, "whitened1", "features")
.Append(new VectorWhiteningEstimator(mlContext, "whitened2", "features", kind: WhiteningKind.PrincipalComponentAnalysis, rank: 5));
var model = pipeline.Fit(dataView);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"VectorWhitening.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4VectorColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult); // whitened1
CompareSelectedR4VectorColumns(transformedData.Schema[3].Name, outputNames[3], transformedData, onnxResult); // whitened2
}
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
{
Expand Down Expand Up @@ -345,7 +379,7 @@ public void PlattCalibratorOnnxConversionTest()
var outputSchema = model.GetOutputSchema(dataView.Schema);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
Expand Down Expand Up @@ -374,7 +408,7 @@ class PlattModelInput

static IEnumerable<PlattModelInput> PlattGetData()
{
for(int i = 0; i < 100; i++)
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
}
Expand All @@ -394,7 +428,7 @@ public void PlattCalibratorOnnxConversionTest2()
var calibratorTransformer = calibratorEstimator.Fit(data);
var transformedData = calibratorTransformer.Transform(data);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(calibratorTransformer, data);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
Expand Down