Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ONNX export support and tests for {FixedPlatt, Naive}CalibratorEstimators #5289

Merged
merged 13 commits into from
Jul 11, 2020
61 changes: 56 additions & 5 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,7 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
private (MLContext, IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests()
mstfbl marked this conversation as resolved.
Show resolved Hide resolved
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath("breast-cancer.txt");
Expand All @@ -289,6 +288,13 @@ public void PlattCalibratorOnnxConversionTest()

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
return (mlContext, dataView, estimators, initialPipeline);
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
{
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
Expand All @@ -299,7 +305,52 @@ public void PlattCalibratorOnnxConversionTest()
Done();
}

class PlattModelInput
[Fact]
public void FixedPlattCalibratorOnnxConversionTest()
{
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
foreach (var estimator in estimators)
{
// Utilize FixedPlattCalibrator by defining slope and offset
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f));
var onnxFileName = $"{estimator}-WithFixedPlattCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}
Done();
}

[Fact]
[Trait("Category", "SkipInCI")]
public void NaiveCalibratorOnnxConversionTest()
{
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Naive());
var onnxFileName = $"{estimator}-WithNaiveCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}
Done();
}

[Fact]
[Trait("Category", "SkipInCI")]
public void IsotonicCalibratorOnnxConversionTest()
{
var (mlContext, dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Isotonic());
var onnxFileName = $"{estimator}-WithIsotonicCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}
Done();
}

class ModelInput
mstfbl marked this conversation as resolved.
Show resolved Hide resolved
{
mstfbl marked this conversation as resolved.
Show resolved Hide resolved
public bool Label { get; set; }
public float Score { get; set; }
Expand All @@ -311,11 +362,11 @@ class PlattModelInput2
public float ScoreX { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
static IEnumerable<ModelInput> PlattGetData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
yield return new ModelInput { Score = i, Label = i % 2 == 0 };
}
}

Expand Down