diff --git a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs index 9ef81f90474..0d866066275 100644 --- a/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs +++ b/src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs @@ -197,14 +197,32 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo) outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo)); - //Check if "Probability" column was generated by the base class, only then - //label can be predicted. + /* If the probability column was generated, then the classification threshold is set to 0.5. Otherwise, + the predicted label is based on the sign of the score. + REVIEW: Binarizer should always have at least two output columns? + */ + string opType = "Binarizer"; + var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true); + if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2])) { - string opType = "Binarizer"; - var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) }, - new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType)); + var node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[2]), binarizerOutput, ctx.GetNodeName(opType)); node.AddAttribute("threshold", 0.5); + + opType = "Cast"; + node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType(); + node.AddAttribute("to", t); + } + else if (Bindings.InfoCount == 2) + { + var node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType)); + node.AddAttribute("threshold", 0.0); + + opType = "Cast"; + node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); + var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType(); + node.AddAttribute("to", t); } } diff --git a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs index b7f6c4da871..1367249c480 100644 --- a/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs +++ b/src/Microsoft.ML.Data/Scorers/SchemaBindablePredictorWrapper.cs @@ -320,7 +320,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema if (!ctx.ContainsColumn(featName)) return false; Contracts.Assert(ctx.ContainsColumn(featName)); - return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName)); + return mapper.SaveAsOnnx(ctx, new[] { outputNames[1] }, ctx.GetVariableName(featName)); } private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) => diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 0a79681b0b7..3cd22a8e025 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -131,6 +131,14 @@ private class BreastCancerMulticlassExample [LoadColumn(2, 9), VectorType(8)] public float[] Features; } + private class BreastCancerBinaryClassification + { + [LoadColumn(0)] + public bool Label; + + [LoadColumn(2, 9), VectorType(8)] + public float[] Features; + } [LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087")] public void KmeansOnnxConversionTest() @@ -187,6 +195,55 @@ public void KmeansOnnxConversionTest() Done(); } + [Fact] + public void binaryClassificationTrainersOnnxConversionTest() + { + var mlContext = new MLContext(seed: 1); + string dataPath = GetDataPath("breast-cancer.txt"); + // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). + var dataView = mlContext.Data.LoadFromTextFile(dataPath, separatorChar: '\t', hasHeader: true); + IEstimator[] estimators = { + mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), + mlContext.BinaryClassification.Trainers.SgdCalibrated(), + mlContext.BinaryClassification.Trainers.AveragedPerceptron(), + mlContext.BinaryClassification.Trainers.FastForest(), + mlContext.BinaryClassification.Trainers.LinearSvm(), + mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(), + mlContext.BinaryClassification.Trainers.SgdNonCalibrated(), + mlContext.BinaryClassification.Trainers.FastTree(), + mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), + mlContext.BinaryClassification.Trainers.LightGbm(), + mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(), + mlContext.BinaryClassification.Trainers.SgdCalibrated(), + mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), + }; + var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features"). + Append(mlContext.Transforms.NormalizeMinMax("Features")); + foreach (var estimator in estimators) + { + var pipeline = initialPipeline.Append(estimator); + var model = pipeline.Fit(dataView); + var transformedData = model.Transform(dataView); + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); + // Compare model scores produced by ML.NET and ONNX's runtime. + if (IsOnnxRuntimeSupported()) + { + var onnxFileName = $"{estimator.ToString()}.onnx"; + var onnxModelPath = GetOutputPath(onnxFileName); + SaveOnnxModel(onnxModel, onnxModelPath, null); + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); + var onnxTransformer = onnxEstimator.Fit(dataView); + var onnxResult = onnxTransformer.Transform(dataView); + CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); + CompareSelectedScalarColumns(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); + } + + } + Done(); + } private class DataPoint { [VectorType(3)] @@ -853,7 +910,8 @@ private void CreateDummyExamplesToMakeComplierHappy() var dummyExample = new BreastCancerFeatureVector() { Features = null }; var dummyExample1 = new BreastCancerCatFeatureExample() { Label = false, F1 = 0, F2 = "Amy" }; var dummyExample2 = new BreastCancerMulticlassExample() { Label = "Amy", Features = null }; - var dummyExample3 = new SmallSentimentExample() { Tokens = null }; + var dummyExample3 = new BreastCancerBinaryClassification() { Label = false, Features = null }; + var dummyExample4 = new SmallSentimentExample() { Tokens = null }; } private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right) @@ -984,7 +1042,34 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC // Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consitency of making batch prediction. Assert.Equal(1, actual.Length); - Assert.Equal(expected, actual.GetItemOrDefault(0), precision); + CompareNumbersWithTolerance(expected, actual.GetItemOrDefault(0), null, precision); + } + } + } + private void CompareSelectedScalarColumns(string leftColumnName, string rightColumnName, IDataView left, IDataView right) + { + var leftColumn = left.Schema[leftColumnName]; + var rightColumn = right.Schema[rightColumnName]; + + using (var expectedCursor = left.GetRowCursor(leftColumn)) + using (var actualCursor = right.GetRowCursor(rightColumn)) + { + T expected = default; + VBuffer actual = default; + var expectedGetter = expectedCursor.GetGetter(leftColumn); + var actualGetter = actualCursor.GetGetter>(rightColumn); + while (expectedCursor.MoveNext() && actualCursor.MoveNext()) + { + expectedGetter(ref expected); + actualGetter(ref actual); + var actualVal = actual.GetItemOrDefault(0); + + Assert.Equal(1, actual.Length); + + if (typeof(T) == typeof(ReadOnlyMemory)) + Assert.Equal(expected.ToString(), actualVal.ToString()); + else + Assert.Equal(expected, actualVal); } } }