From fab2e1d6c6fc9f2aa55f062da7a928e2380339c5 Mon Sep 17 00:00:00 2001
From: Mustafa Bal <5262061+mstfbl@users.noreply.github.com>
Date: Fri, 10 Jul 2020 19:21:06 -0700
Subject: [PATCH] Added ONNX export support and tests for {FixedPlatt,
Naive}CalibratorEstimators (#5289)
* Added ONNX export tests for other calibrators
* Consolidated testing, started ONNX model conversion
* Added ONNX export support for NaiveCalibrator
* Enable NaiveCalibratorOnnxConversionTest
* Work on Isotonic Calibrator ONNX export support
* Removed Isotonic work from this PR
* Nit correct spacing
* Nit renaming to CalibratorInput(NonStandard)
* Organized tests
* Nit
* Clean-up initialization of vars for CommonCalibratorOnnxConversionTest
* Replaced locally created MLContext instances with the test class's ML context
---
.../Prediction/Calibrator.cs | 47 ++++++-
test/Microsoft.ML.Tests/OnnxConversionTest.cs | 123 ++++++++++--------
2 files changed, 114 insertions(+), 56 deletions(-)
diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
index d1948c0cb4..375e2a8fb1 100644
--- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs
+++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch)
///
/// The naive binning-based calibrator.
///
- public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat
+ public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "NaiveCaliExec";
internal const string RegistrationName = "NaiveCalibrator";
@@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
}
+ /// <summary>
+ /// Bool required by the interface ISingleCanSaveOnnx, returns true if
+ /// and only if calibrator can be exported in ONNX. This implementation always returns true.
+ /// </summary>
+ bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
+
private readonly IHost _host;
/// The bin size.
@@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
return binIdx;
}
+ /// <summary>
+ /// Exports this calibrator as an ONNX sub-graph that computes
+ /// binProbs[clip((long)((score - Min) / BinSize), 0, binProbs.Length - 1)].
+ /// </summary>
+ /// <param name="ctx">The ONNX context that receives the initializers and nodes.</param>
+ /// <param name="outputNames">Exactly two variable names: [0] is the Score column (the input of this graph)
+ /// and [1] is the Probability column (the final output of this graph).</param>
+ /// <param name="featureColumn">Not used by this calibrator.</param>
+ /// <returns>Always true.</returns>
+ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
+ {
+ _host.CheckValue(ctx, nameof(ctx));
+ _host.CheckValue(outputNames, nameof(outputNames));
+ _host.Check(Utils.Size(outputNames) == 2);
+
+ // GatherElements and the form of Clip that takes its bounds as inputs both first appear in ONNX opset 11,
+ // so a graph emitted for an older opset would not be valid.
+ // NOTE(review): the ONNX spec lists integer tensors for Clip only from opset 12 - confirm whether 12 is required.
+ const int minimumOpSetVersion = 11;
+ ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator");
+
+ // score - Min
+ var minVar = ctx.AddInitializer(Min, "Min");
+ var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput");
+ ctx.CreateNode("Sub", new[] { outputNames[0], minVar }, new[] { subNodeOutput }, ctx.GetNodeName("Sub"), "");
+
+ // (score - Min) / BinSize
+ var binSizeVar = ctx.AddInitializer(BinSize, "BinSize");
+ var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "divNodeOutput");
+ ctx.CreateNode("Div", new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName("Div"), "");
+
+ // Convert the quotient to an integer so it can be used as a bin index.
+ var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput");
+ var castNode = ctx.CreateNode("Cast", divNodeOutput, castOutput, ctx.GetNodeName("Cast"), "");
+ castNode.AddAttribute("to", typeof(long));
+
+ // Clamp the index to [0, number of bins - 1].
+ var zeroVar = ctx.AddInitializer(0, "Zero");
+ var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length - 1, "NumBinsMinusOne");
+ var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput");
+ ctx.CreateNode("Clip", new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName("Clip"), "");
+
+ // Look up the probability stored for the selected bin.
+ var binProbabilitiesVar = ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities");
+ ctx.CreateNode("GatherElements", new[] { binProbabilitiesVar, binIndexOutput }, new[] { outputNames[1] }, ctx.GetNodeName("GatherElements"), "");
+
+ return true;
+ }
}
///
diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
index dfd6aa2136..56d50ac702 100644
--- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs
+++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}
- [Fact]
- public void PlattCalibratorOnnxConversionTest()
+ // Builds the inputs shared by the calibrator ONNX conversion tests: the data view loaded from
+ // "breast-cancer.txt", the list of binary classification trainers to try (LightGbm is added only in
+ // 64-bit processes), and the initial pipeline that applies ReplaceMissingValues and NormalizeMinMax to "Features".
+ private (IDataView, List>, EstimatorChain) GetEstimatorsForOnnxConversionTests()
{
- var mlContext = new MLContext(seed: 1);
- string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
+ string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
- var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: false);
+ var dataView = ML.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true);
List> estimators = new List>()
{
- mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
- mlContext.BinaryClassification.Trainers.FastForest(),
- mlContext.BinaryClassification.Trainers.FastTree(),
- mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
- mlContext.BinaryClassification.Trainers.LinearSvm(),
- mlContext.BinaryClassification.Trainers.Prior(),
- mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
- mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
- mlContext.BinaryClassification.Trainers.SgdCalibrated(),
- mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
- mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
+ ML.BinaryClassification.Trainers.AveragedPerceptron(),
+ ML.BinaryClassification.Trainers.FastForest(),
+ ML.BinaryClassification.Trainers.FastTree(),
+ ML.BinaryClassification.Trainers.LbfgsLogisticRegression(),
+ ML.BinaryClassification.Trainers.LinearSvm(),
+ ML.BinaryClassification.Trainers.Prior(),
+ ML.BinaryClassification.Trainers.SdcaLogisticRegression(),
+ ML.BinaryClassification.Trainers.SdcaNonCalibrated(),
+ ML.BinaryClassification.Trainers.SgdCalibrated(),
+ ML.BinaryClassification.Trainers.SgdNonCalibrated(),
+ ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
- estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
+ estimators.Add(ML.BinaryClassification.Trainers.LightGbm());
}
- var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
- Append(mlContext.Transforms.NormalizeMinMax("Features"));
+ var initialPipeline = ML.Transforms.ReplaceMissingValues("Features").
+ Append(ML.Transforms.NormalizeMinMax("Features"));
+ return (dataView, estimators, initialPipeline);
+ }
+
+ // Runs the ONNX conversion steps shared by every calibrator test. `calibrator` is exercised both appended to
+ // each binary trainer and on its own; `calibratorNonStandard` is exercised on its own against rows whose
+ // score property is named ScoreX instead of Score.
+ private void CommonCalibratorOnnxConversionTest(IEstimator calibrator, IEstimator calibratorNonStandard)
+ {
+ // Initialize variables needed for the ONNX conversion test
+ var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
+
+ // Step 1: Test calibrator with binary prediction trainer
foreach (var estimator in estimators)
{
- var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
- var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx";
-
- TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
+ var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator);
+ var onnxFileName = $"{estimator}-With-{calibrator}.onnx";
+ TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}
+
+ // Step 2: Test calibrator without any binary prediction trainer
+ IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData());
+ var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx";
+ TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
+
+ // Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
+ IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard());
+ var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx";
+ TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
+
Done();
}
- class PlattModelInput
+ // Runs the common calibrator ONNX conversion test with Platt calibrators created without explicit
+ // slope/offset: one with default arguments and one reading its score from the "ScoreX" column.
+ [Fact]
+ public void PlattCalibratorOnnxConversionTest()
+ {
+ CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(),
+ ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX"));
+ }
+
+ // Runs the common calibrator ONNX conversion test with Platt calibrators whose slope and offset are
+ // supplied explicitly: one with the default score column and one reading its score from "ScoreX".
+ [Fact]
+ public void FixedPlattCalibratorOnnxConversionTest()
+ {
+ // Below, FixedPlattCalibrator is utilized by defining slope and offset in Platt's constructor with sample values.
+ CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f),
+ ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX"));
+ }
+
+ // Runs the common calibrator ONNX conversion test with Naive calibrators: one with default arguments
+ // and one reading its score from the "ScoreX" column.
+ [Fact]
+ public void NaiveCalibratorOnnxConversionTest()
+ {
+ CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Naive(),
+ ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX"));
+ }
+
+ // Row type for the solo-calibrator tests: a boolean label and a float score exposed as "Score".
+ class CalibratorInput
{
public bool Label { get; set; }
public float Score { get; set; }
}
- class PlattModelInput2
+ // Row type for the solo-calibrator tests that use a non-default score column: the score is exposed as "ScoreX".
+ class CalibratorInputNonStandard
{
public bool Label { get; set; }
public float ScoreX { get; set; }
}
- static IEnumerable<PlattModelInput> PlattGetData()
+ // Yields 100 rows with Score = 0, 1, ..., 99; Label is true when the score is even.
+ static IEnumerable<CalibratorInput> GetCalibratorTestData()
{
for (int i = 0; i < 100; i++)
{
- yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
+ yield return new CalibratorInput { Score = i, Label = i % 2 == 0 };
}
}
- static IEnumerable<PlattModelInput2> PlattGetData2()
+ // Yields 100 rows with ScoreX = 0, 1, ..., 99; Label is true when the score is even.
+ static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard()
{
for (int i = 0; i < 100; i++)
{
- yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 };
+ yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 };
}
}
- [Fact]
- public void PlattCalibratorOnnxConversionTest2()
- {
- // Test PlattCalibrator without any binary prediction trainer
- var mlContext = new MLContext(seed: 0);
-
- IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());
-
- var pipeline = mlContext.BinaryClassification.Calibrators
- .Platt();
- var onnxFileName = $"{pipeline}.onnx";
-
- TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
-
- // Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
- IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2());
-
- var pipeline2 = mlContext.BinaryClassification.Calibrators
- .Platt(scoreColumnName: "ScoreX");
- var onnxFileName2 = $"{pipeline2}.onnx";
-
- TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
-
- Done();
- }
-
[Fact]
public void TextNormalizingOnnxConversionTest()
{