Skip to content

Added ONNX export support and tests for {FixedPlatt, Naive}CalibratorEstimators #5289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch)
/// <summary>
/// The naive binning-based calibrator.
/// </summary>
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "NaiveCaliExec";
internal const string RegistrationName = "NaiveCalibrator";
Expand All @@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
}

/// <summary>
/// Explicit implementation of <see cref="ICanSaveOnnx.CanSaveOnnx"/>.
/// Unconditionally reports that this calibrator can be exported to ONNX.
/// </summary>
/// <param name="ctx">The ONNX export context; it is not consulted.</param>
/// <returns>Always <c>true</c>.</returns>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return true;
}

private readonly IHost _host;

/// <summary> The bin size.</summary>
Expand Down Expand Up @@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
return binIdx;
}

/// <summary>
/// Emits this calibrator into <paramref name="ctx"/> as a chain of ONNX nodes:
/// Sub (score - Min), Div (by BinSize), Cast (to int64), Clip (to [0, number of bins - 1])
/// and GatherElements (look up the probability stored for that bin).
/// </summary>
/// <param name="ctx">The ONNX context that receives the initializers, intermediate variables and nodes.</param>
/// <param name="outputNames">
/// Exactly two names: <c>outputNames[0]</c> is the Score column, which is the input of this graph, and
/// <c>outputNames[1]</c> is the Probability column, which is the final output of this graph.
/// </param>
/// <param name="featureColumn">Not used by this implementation.</param>
/// <returns>Always <c>true</c>.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    _host.CheckValue(ctx, nameof(ctx));
    _host.CheckValue(outputNames, nameof(outputNames));
    _host.Check(Utils.Size(outputNames) == 2);
    // outputNames[0] refers to the name of the Score column, which is the input of this graph
    // outputNames[1] refers to the name of the Probability column, which is the final output of this graph

    // GatherElements and the form of Clip that takes min/max as inputs were introduced in ONNX opset 11,
    // and Clip only accepts integer tensors from opset 12 onwards. A graph built below therefore cannot be
    // a valid model under any older opset, so reject those up front instead of emitting a broken model.
    const int minimumOpSetVersion = 12;
    ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator");

    // subNodeOutput = Score - Min
    string opType = "Sub";
    var minVar = ctx.AddInitializer(Min, "Min");
    var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput");
    ctx.CreateNode(opType, new[] { outputNames[0], minVar }, new[] { subNodeOutput }, ctx.GetNodeName(opType), "");

    // divNodeOutput = subNodeOutput / BinSize
    // The variable is labelled after the node that produces it; the bin-index label is reserved for the
    // Clip output further down so the two intermediates do not share a requested name.
    opType = "Div";
    var binSizeVar = ctx.AddInitializer(BinSize, "BinSize");
    var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "divNodeOutput");
    ctx.CreateNode(opType, new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName(opType), "");

    // castOutput = (int64)divNodeOutput
    opType = "Cast";
    var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput");
    var castNode = ctx.CreateNode(opType, divNodeOutput, castOutput, ctx.GetNodeName(opType), "");
    castNode.AddAttribute("to", typeof(long));

    // binIndexOutput = castOutput limited to the valid index range [0, _binProbs.Length - 1]
    opType = "Clip";
    var zeroVar = ctx.AddInitializer(0, "Zero");
    var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length - 1, "NumBinsMinusOne");
    var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput");
    ctx.CreateNode(opType, new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName(opType), "");

    // Probability = BinProbabilities[binIndexOutput]; the probabilities are stored as a
    // (_binProbs.Length x 1) tensor so they can be indexed along the first axis.
    opType = "GatherElements";
    var binProbabilitiesVar = ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities");
    ctx.CreateNode(opType, new[] { binProbabilitiesVar, binIndexOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");

    return true;
}
}

/// <summary>
Expand Down
123 changes: 68 additions & 55 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
private (IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: false);
var dataView = ML.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
mlContext.BinaryClassification.Trainers.FastForest(),
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
ML.BinaryClassification.Trainers.AveragedPerceptron(),
ML.BinaryClassification.Trainers.FastForest(),
ML.BinaryClassification.Trainers.FastTree(),
ML.BinaryClassification.Trainers.LbfgsLogisticRegression(),
ML.BinaryClassification.Trainers.LinearSvm(),
ML.BinaryClassification.Trainers.Prior(),
ML.BinaryClassification.Trainers.SdcaLogisticRegression(),
ML.BinaryClassification.Trainers.SdcaNonCalibrated(),
ML.BinaryClassification.Trainers.SgdCalibrated(),
ML.BinaryClassification.Trainers.SgdNonCalibrated(),
ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
estimators.Add(ML.BinaryClassification.Trainers.LightGbm());
}

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
var initialPipeline = ML.Transforms.ReplaceMissingValues("Features").
Append(ML.Transforms.NormalizeMinMax("Features"));
return (dataView, estimators, initialPipeline);
}

private void CommonCalibratorOnnxConversionTest(IEstimator<ITransformer> calibrator, IEstimator<ITransformer> calibratorNonStandard)
{
// Initialize variables needed for the ONNX conversion test
var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, you might have just introduced a bug here. The mlContext returned from GetEstimatorsForOnnxConversionTests is now different from the MLContext used to create the calibrator. Either you should fix GetEstimatorsForOnnxConversionTests to use the MLContext from the base class or pass it in as a param to this function like I mentioned in my earlier comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the heads-up. I have removed the usage of `MLContext` in these tests in favor of the base class's `ML` context.

// Step 1: Test calibrator with binary prediction trainer
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator);
var onnxFileName = $"{estimator}-With-{calibrator}.onnx";
TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}

// Step 2: Test calibrator without any binary prediction trainer
IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData());
var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx";
TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard());
var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx";
TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

class PlattModelInput
/// <summary>
/// Runs the shared calibrator ONNX conversion checks for the Platt calibrator estimator,
/// once with its default score column and once reading the score from "ScoreX".
/// </summary>
[Fact]
public void PlattCalibratorOnnxConversionTest()
{
    var defaultScoreCalibrator = ML.BinaryClassification.Calibrators.Platt();
    var customScoreCalibrator = ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(defaultScoreCalibrator, customScoreCalibrator);
}

/// <summary>
/// Runs the shared calibrator ONNX conversion checks for the fixed Platt calibrator,
/// once with its default score column and once reading the score from "ScoreX".
/// </summary>
[Fact]
public void FixedPlattCalibratorOnnxConversionTest()
{
    // Supplying slope and offset (sample values) to Platt(...) is what selects the FixedPlattCalibrator.
    var defaultScoreCalibrator = ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f);
    var customScoreCalibrator = ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(defaultScoreCalibrator, customScoreCalibrator);
}

/// <summary>
/// Runs the shared calibrator ONNX conversion checks for the naive calibrator estimator,
/// once with its default score column and once reading the score from "ScoreX".
/// </summary>
[Fact]
public void NaiveCalibratorOnnxConversionTest()
{
    var defaultScoreCalibrator = ML.BinaryClassification.Calibrators.Naive();
    var customScoreCalibrator = ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(defaultScoreCalibrator, customScoreCalibrator);
}

/// <summary>
/// Row type produced by GetCalibratorTestData for the calibrator-only ONNX tests:
/// a boolean label together with a score exposed under the name "Score".
/// </summary>
class CalibratorInput
{
// Label of the example.
public bool Label { get; set; }
// Score of the example; this is the value the calibrator under test consumes.
public float Score { get; set; }
}

class PlattModelInput2
/// <summary>
/// Row type produced by GetCalibratorTestDataNonStandard: same data as the standard input row,
/// but the score is exposed as "ScoreX" to exercise calibrators created with scoreColumnName: "ScoreX".
/// </summary>
class CalibratorInputNonStandard
{
// Label of the example.
public bool Label { get; set; }
// Score of the example, deliberately published under a non-default column name.
public float ScoreX { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
static IEnumerable<CalibratorInput> GetCalibratorTestData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
yield return new CalibratorInput { Score = i, Label = i % 2 == 0 };
}
}

static IEnumerable<PlattModelInput2> PlattGetData2()
static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 };
yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 };
}
}

/// <summary>
/// Exports a stand-alone Platt calibrator (no binary prediction trainer in front of it) to ONNX and
/// compares the "Probability" column, first with the default score column and then with "ScoreX".
/// </summary>
[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
    var context = new MLContext(seed: 0);

    // Case 1: calibrator alone, score read from the default column.
    IDataView defaultScoreData = context.Data.LoadFromEnumerable(PlattGetData());
    var defaultScoreCalibrator = context.BinaryClassification.Calibrators.Platt();
    var defaultScoreModelFile = $"{defaultScoreCalibrator}.onnx";
    TestPipeline(
        defaultScoreCalibrator,
        defaultScoreData,
        defaultScoreModelFile,
        new ColumnComparison[] { new ColumnComparison("Probability", 3) });

    // Case 2: calibrator alone, score read from the non-default "ScoreX" column.
    IDataView customScoreData = context.Data.LoadFromEnumerable(PlattGetData2());
    var customScoreCalibrator = context.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX");
    var customScoreModelFile = $"{customScoreCalibrator}.onnx";
    TestPipeline(
        customScoreCalibrator,
        customScoreData,
        customScoreModelFile,
        new ColumnComparison[] { new ColumnComparison("Probability", 3) });

    Done();
}

[Fact]
public void TextNormalizingOnnxConversionTest()
{
Expand Down