Skip to content

Commit

Permalink
fix for binary classification trainers export to onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
Lynx1820 committed Nov 11, 2019
1 parent 6fad293 commit ea71828
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 8 deletions.
28 changes: 23 additions & 5 deletions src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,32 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx)
for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo)
outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo));

//Check if "Probability" column was generated by the base class, only then
//label can be predicted.
/* If the probability column was generated, then the classification threshold is set to 0.5. Otherwise,
the predicted label is based on the sign of the score.
REVIEW: Binarizer should always have at least two output columns?
*/
string opType = "Binarizer";
var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);

if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]))
{
string opType = "Binarizer";
var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) },
new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType));
var node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[2]), binarizerOutput, ctx.GetNodeName(opType));
node.AddAttribute("threshold", 0.5);

opType = "Cast";
node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
node.AddAttribute("to", t);
}
else if (Bindings.InfoCount == 2)
{
var node = ctx.CreateNode(opType, ctx.GetVariableName(outColumnNames[1]), binarizerOutput, ctx.GetNodeName(opType));
node.AddAttribute("threshold", 0.0);

opType = "Cast";
node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), "");
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType();
node.AddAttribute("to", t);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema
if (!ctx.ContainsColumn(featName))
return false;
Contracts.Assert(ctx.ContainsColumn(featName));
return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featName));
return mapper.SaveAsOnnx(ctx, new[] { outputNames[1] }, ctx.GetVariableName(featName));
}

private protected override ISchemaBoundMapper BindCore(IChannel ch, RoleMappedSchema schema) =>
Expand Down
89 changes: 87 additions & 2 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ private class BreastCancerMulticlassExample
[LoadColumn(2, 9), VectorType(8)]
public float[] Features;
}
/// <summary>
/// Row schema used to load the breast-cancer text file for the binary classification
/// ONNX conversion test: one boolean label plus a fixed-size float feature vector.
/// </summary>
private class BreastCancerBinaryClassification
{
// Boolean label, read from column 0 of the input file.
[LoadColumn(0)]
public bool Label;

// Feature vector of 8 floats, read from the column range 2..9 of the input file.
[LoadColumn(2, 9), VectorType(8)]
public float[] Features;
}

[LessThanNetCore30OrNotNetCoreFact("netcoreapp3.0 output differs from Baseline. Tracked by https://github.com/dotnet/machinelearning/issues/2087")]
public void KmeansOnnxConversionTest()
Expand Down Expand Up @@ -187,6 +195,55 @@ public void KmeansOnnxConversionTest()
Done();
}

/// <summary>
/// For each binary classification trainer, fits a small pipeline on the breast-cancer data,
/// converts the fitted model to ONNX and, when ONNX Runtime is supported, verifies that the
/// ONNX model reproduces the float and boolean output columns produced by ML.NET.
/// </summary>
[Fact]
public void binaryClassificationTrainersOnnxConversionTest()
{
    var mlContext = new MLContext(seed: 1);
    string dataPath = GetDataPath("breast-cancer.txt");
    // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);

    // Each trainer is listed exactly once. Listing a trainer twice only repeats the same
    // training/export work and overwrites the same "<estimator>.onnx" output file.
    IEstimator<ITransformer>[] estimators = {
        mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
        mlContext.BinaryClassification.Trainers.SgdCalibrated(),
        mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
        mlContext.BinaryClassification.Trainers.FastForest(),
        mlContext.BinaryClassification.Trainers.LinearSvm(),
        mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
        mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
        mlContext.BinaryClassification.Trainers.FastTree(),
        mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
        mlContext.BinaryClassification.Trainers.LightGbm(),
        mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
    };

    // Shared featurization applied in front of every trainer.
    var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
        Append(mlContext.Transforms.NormalizeMinMax("Features"));

    foreach (var estimator in estimators)
    {
        var pipeline = initialPipeline.Append(estimator);
        var model = pipeline.Fit(dataView);
        var transformedData = model.Transform(dataView);
        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

        // Compare model scores produced by ML.NET and ONNX's runtime.
        if (IsOnnxRuntimeSupported())
        {
            var onnxFileName = $"{estimator}.onnx";
            var onnxModelPath = GetOutputPath(onnxFileName);
            SaveOnnxModel(onnxModel, onnxModelPath, null);

            // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
            string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
            var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
            var onnxTransformer = onnxEstimator.Fit(dataView);
            var onnxResult = onnxTransformer.Transform(dataView);

            // NOTE(review): the hard-coded indices presumably select the score column (schema index 5 /
            // ONNX output 3) and the predicted-label column (schema index 4 / ONNX output 2) - confirm
            // against the pipeline's output schema if the featurization steps change.
            CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3);
            CompareSelectedScalarColumns<bool>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult);
        }
    }
    Done();
}
private class DataPoint
{
[VectorType(3)]
Expand Down Expand Up @@ -853,7 +910,8 @@ private void CreateDummyExamplesToMakeComplierHappy()
var dummyExample = new BreastCancerFeatureVector() { Features = null };
var dummyExample1 = new BreastCancerCatFeatureExample() { Label = false, F1 = 0, F2 = "Amy" };
var dummyExample2 = new BreastCancerMulticlassExample() { Label = "Amy", Features = null };
var dummyExample3 = new SmallSentimentExample() { Tokens = null };
var dummyExample3 = new BreastCancerBinaryClassification() { Label = false, Features = null };
var dummyExample4 = new SmallSentimentExample() { Tokens = null };
}

private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
Expand Down Expand Up @@ -984,7 +1042,34 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC

// A scalar such as R4 (float) is converted to a [1, 1]-tensor in ONNX format, for consistency with batch prediction.
Assert.Equal(1, actual.Length);
Assert.Equal(expected, actual.GetItemOrDefault(0), precision);
CompareNumbersWithTolerance(expected, actual.GetItemOrDefault(0), null, precision);
}
}
}
/// <summary>
/// Asserts that a scalar column of <paramref name="left"/> equals, row by row, the first (and only)
/// element of a vector column of <paramref name="right"/>.
/// </summary>
/// <typeparam name="T">Item type of both columns.</typeparam>
/// <param name="leftColumnName">Name of the scalar column in <paramref name="left"/> holding the expected values.</param>
/// <param name="rightColumnName">Name of the length-1 vector column in <paramref name="right"/> holding the actual values.</param>
/// <param name="left">Data view providing the expected values.</param>
/// <param name="right">Data view providing the actual values.</param>
private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right)
{
    var leftColumn = left.Schema[leftColumnName];
    var rightColumn = right.Schema[rightColumnName];

    using (var expectedCursor = left.GetRowCursor(leftColumn))
    using (var actualCursor = right.GetRowCursor(rightColumn))
    {
        T expected = default;
        VBuffer<T> actual = default;
        var expectedGetter = expectedCursor.GetGetter<T>(leftColumn);
        var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn);

        while (true)
        {
            // Advance both cursors on every iteration so that a difference in row count
            // fails the test instead of silently ending the comparison at the shorter side.
            bool hasExpected = expectedCursor.MoveNext();
            bool hasActual = actualCursor.MoveNext();
            Assert.True(hasExpected == hasActual, "Compared data views do not have the same number of rows.");
            if (!hasExpected)
                break;

            expectedGetter(ref expected);
            actualGetter(ref actual);

            // Check the shape before reading the element, so an unexpected length is reported
            // as an assertion failure rather than surfacing from the element access.
            Assert.Equal(1, actual.Length);
            var actualVal = actual.GetItemOrDefault(0);

            // Text values are compared through their string form rather than as ReadOnlyMemory<char> values.
            if (typeof(T) == typeof(ReadOnlyMemory<char>))
                Assert.Equal(expected.ToString(), actualVal.ToString());
            else
                Assert.Equal(expected, actualVal);
        }
    }
}
Expand Down

0 comments on commit ea71828

Please sign in to comment.