From fab2e1d6c6fc9f2aa55f062da7a928e2380339c5 Mon Sep 17 00:00:00 2001 From: Mustafa Bal <5262061+mstfbl@users.noreply.github.com> Date: Fri, 10 Jul 2020 19:21:06 -0700 Subject: [PATCH] Added ONNX export support and tests for {FixedPlatt, Naive}CalibratorEstimators (#5289) * Added ONNX export tests for other calibrators * Consolidated testing, started ONNX model conversion * Added ONNX export support for NaiveCalibrator * Enable NaiveCalibratorOnnxConversionTest * Work on Isotonic Calibrator ONNX export support * Removed Isotonic work from this PR * Nit correct spacing * Nit renaming to CalibratorInput(NonStandard) * Organized tests * Nit * Clean-up initialization of vars for CommonCalibratorOnnxConversionTest * Removed MLContexts for ML --- .../Prediction/Calibrator.cs | 47 ++++++- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 123 ++++++++++-------- 2 files changed, 114 insertions(+), 56 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index d1948c0cb4..375e2a8fb1 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch) /// /// The naive binning-based calibrator. /// - public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat + public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx { internal const string LoaderSignature = "NaiveCaliExec"; internal const string RegistrationName = "NaiveCalibrator"; @@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName); } + /// + /// Bool required by the interface ISingleCanSaveOnnx, returns true if + /// and only if calibrator can be exported in ONNX. + /// + bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true; + private readonly IHost _host; /// The bin size. 
@@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin return binIdx; } + bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn) + { + _host.CheckValue(ctx, nameof(ctx)); + _host.CheckValue(outputNames, nameof(outputNames)); + _host.Check(Utils.Size(outputNames) == 2); + // outputNames[0] refers to the name of the Score column, which is the input of this graph + // outputNames[1] refers to the name of the Probability column, which is the final output of this graph + + const int minimumOpSetVersion = 9; + ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator"); + + string opType = "Sub"; + var minVar = ctx.AddInitializer(Min, "Min"); + var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput"); + var node = ctx.CreateNode(opType, new[] { outputNames[0], minVar }, new[] { subNodeOutput }, ctx.GetNodeName(opType), ""); + + opType = "Div"; + var binSizeVar = ctx.AddInitializer(BinSize, "BinSize"); + var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "binIndexOutput"); + node = ctx.CreateNode(opType, new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName(opType), ""); + + opType = "Cast"; + var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput"); + node = ctx.CreateNode(opType, divNodeOutput, castOutput, ctx.GetNodeName(opType), ""); + var toTypeInt = typeof(long); + node.AddAttribute("to", toTypeInt); + + opType = "Clip"; + var zeroVar = ctx.AddInitializer(0, "Zero"); + var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length-1, "NumBinsMinusOne"); + var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput"); + node = ctx.CreateNode(opType, new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName(opType), ""); + + opType = "GatherElements"; + var binProbabilitiesVar = 
ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities"); + node = ctx.CreateNode(opType, new[] { binProbabilitiesVar, binIndexOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), ""); + + return true; + } } /// diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index dfd6aa2136..56d50ac702 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest() Done(); } - [Fact] - public void PlattCalibratorOnnxConversionTest() + private (IDataView, List>, EstimatorChain) GetEstimatorsForOnnxConversionTests() { - var mlContext = new MLContext(seed: 1); - string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); + string dataPath = GetDataPath("breast-cancer.txt"); // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). 
- var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: false); + var dataView = ML.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true); List> estimators = new List>() { - mlContext.BinaryClassification.Trainers.AveragedPerceptron(), - mlContext.BinaryClassification.Trainers.FastForest(), - mlContext.BinaryClassification.Trainers.FastTree(), - mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), - mlContext.BinaryClassification.Trainers.LinearSvm(), - mlContext.BinaryClassification.Trainers.Prior(), - mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(), - mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(), - mlContext.BinaryClassification.Trainers.SgdCalibrated(), - mlContext.BinaryClassification.Trainers.SgdNonCalibrated(), - mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), + ML.BinaryClassification.Trainers.AveragedPerceptron(), + ML.BinaryClassification.Trainers.FastForest(), + ML.BinaryClassification.Trainers.FastTree(), + ML.BinaryClassification.Trainers.LbfgsLogisticRegression(), + ML.BinaryClassification.Trainers.LinearSvm(), + ML.BinaryClassification.Trainers.Prior(), + ML.BinaryClassification.Trainers.SdcaLogisticRegression(), + ML.BinaryClassification.Trainers.SdcaNonCalibrated(), + ML.BinaryClassification.Trainers.SgdCalibrated(), + ML.BinaryClassification.Trainers.SgdNonCalibrated(), + ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), }; if (Environment.Is64BitProcess) { - estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm()); + estimators.Add(ML.BinaryClassification.Trainers.LightGbm()); } - var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features"). - Append(mlContext.Transforms.NormalizeMinMax("Features")); + var initialPipeline = ML.Transforms.ReplaceMissingValues("Features"). 
+ Append(ML.Transforms.NormalizeMinMax("Features")); + return (dataView, estimators, initialPipeline); + } + + private void CommonCalibratorOnnxConversionTest(IEstimator calibrator, IEstimator calibratorNonStandard) + { + // Initialize variables needed for the ONNX conversion test + var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests(); + + // Step 1: Test calibrator with binary prediction trainer foreach (var estimator in estimators) { - var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt()); - var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx"; - - TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) }); + var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator); + var onnxFileName = $"{estimator}-With-{calibrator}.onnx"; + TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) }); } + + // Step 2: Test calibrator without any binary prediction trainer + IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData()); + var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx"; + TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); + + // Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer + IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard()); + var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx"; + TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new 
ColumnComparison[] { new ColumnComparison("Probability", 3) }); + Done(); } - class PlattModelInput + [Fact] + public void PlattCalibratorOnnxConversionTest() + { + CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(), + ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX")); + } + + [Fact] + public void FixedPlattCalibratorOnnxConversionTest() + { + // Below, FixedPlattCalibrator is utilized by defining slope and offset in Platt's constructor with sample values. + CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f), + ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX")); + } + + [Fact] + public void NaiveCalibratorOnnxConversionTest() + { + CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Naive(), + ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX")); + } + + class CalibratorInput { public bool Label { get; set; } public float Score { get; set; } } - class PlattModelInput2 + class CalibratorInputNonStandard { public bool Label { get; set; } public float ScoreX { get; set; } } - static IEnumerable PlattGetData() + static IEnumerable GetCalibratorTestData() { for (int i = 0; i < 100; i++) { - yield return new PlattModelInput { Score = i, Label = i % 2 == 0 }; + yield return new CalibratorInput { Score = i, Label = i % 2 == 0 }; } } - static IEnumerable PlattGetData2() + static IEnumerable GetCalibratorTestDataNonStandard() { for (int i = 0; i < 100; i++) { - yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 }; + yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 }; } } - [Fact] - public void PlattCalibratorOnnxConversionTest2() - { - // Test PlattCalibrator without any binary prediction trainer - var mlContext = new MLContext(seed: 0); - - IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData()); - - var pipeline = 
mlContext.BinaryClassification.Calibrators - .Platt(); - var onnxFileName = $"{pipeline}.onnx"; - - TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); - - // Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer - IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2()); - - var pipeline2 = mlContext.BinaryClassification.Calibrators - .Platt(scoreColumnName: "ScoreX"); - var onnxFileName2 = $"{pipeline2}.onnx"; - - TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); - - Done(); - } - [Fact] public void TextNormalizingOnnxConversionTest() {