Added ONNX export support and tests for {FixedPlatt, Naive}CalibratorEstimators #5289

Merged · 13 commits · Jul 11, 2020
47 changes: 46 additions & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
@@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch)
/// <summary>
/// The naive binning-based calibrator.
/// </summary>
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "NaiveCaliExec";
internal const string RegistrationName = "NaiveCalibrator";
@@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
}

/// <summary>
/// Required by the ISingleCanSaveOnnx interface; returns true if and only if
/// this calibrator can be exported to ONNX.
/// </summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;

private readonly IHost _host;

/// <summary> The bin size.</summary>
@@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
return binIdx;
}

bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
_host.CheckValue(ctx, nameof(ctx));
_host.CheckValue(outputNames, nameof(outputNames));
_host.Check(Utils.Size(outputNames) == 2);
// outputNames[0] refers to the name of the Score column, which is the input of this graph
// outputNames[1] refers to the name of the Probability column, which is the final output of this graph

const int minimumOpSetVersion = 9;
ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator");

string opType = "Sub";
var minVar = ctx.AddInitializer(Min, "Min");
var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput");
var node = ctx.CreateNode(opType, new[] { outputNames[0], minVar }, new[] { subNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Div";
var binSizeVar = ctx.AddInitializer(BinSize, "BinSize");
var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "binIndexOutput");
node = ctx.CreateNode(opType, new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Cast";
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput");
node = ctx.CreateNode(opType, divNodeOutput, castOutput, ctx.GetNodeName(opType), "");
var toTypeInt = typeof(long);
node.AddAttribute("to", toTypeInt);

opType = "Clip";
var zeroVar = ctx.AddInitializer(0, "Zero");
var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length-1, "NumBinsMinusOne");
var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput");
node = ctx.CreateNode(opType, new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName(opType), "");

opType = "GatherElements";
var binProbabilitiesVar = ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities");
node = ctx.CreateNode(opType, new[] { binProbabilitiesVar, binIndexOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");

return true;
}
}
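
For context, the Sub → Div → Cast → Clip → GatherElements chain added above mirrors the bin lookup that NaiveCalibrator already performs in managed code (see GetBinIdx). Below is a minimal C# sketch of that per-row computation; min, binSize, and binProbs stand in for the calibrator's Min, BinSize, and _binProbs members, and the class and method names are illustrative rather than ML.NET API.

using System;

// Sketch of the per-row computation the exported ONNX graph reproduces.
internal static class NaiveCalibrationSketch
{
    internal static float Calibrate(float score, float min, float binSize, float[] binProbs)
    {
        // Sub + Div: shift by the minimum and divide by the bin width.
        float fractionalIndex = (score - min) / binSize;

        // Cast: convert to an integer bin index (C#'s (int) cast truncates toward zero).
        int binIdx = (int)fractionalIndex;

        // Clip: clamp the index to [0, binProbs.Length - 1].
        binIdx = Math.Max(0, Math.Min(binIdx, binProbs.Length - 1));

        // GatherElements: look up the calibrated probability for the bin.
        return binProbs[binIdx];
    }
}

The Clip bounds (0 and _binProbs.Length - 1) and the GatherElements initializer of shape [_binProbs.Length, 1] in the export correspond to the clamp and array lookup here.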

/// <summary>
123 changes: 68 additions & 55 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
@@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
private (IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: false);
var dataView = ML.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
mlContext.BinaryClassification.Trainers.FastForest(),
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
ML.BinaryClassification.Trainers.AveragedPerceptron(),
ML.BinaryClassification.Trainers.FastForest(),
ML.BinaryClassification.Trainers.FastTree(),
ML.BinaryClassification.Trainers.LbfgsLogisticRegression(),
ML.BinaryClassification.Trainers.LinearSvm(),
ML.BinaryClassification.Trainers.Prior(),
ML.BinaryClassification.Trainers.SdcaLogisticRegression(),
ML.BinaryClassification.Trainers.SdcaNonCalibrated(),
ML.BinaryClassification.Trainers.SgdCalibrated(),
ML.BinaryClassification.Trainers.SgdNonCalibrated(),
ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
estimators.Add(ML.BinaryClassification.Trainers.LightGbm());
}

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
var initialPipeline = ML.Transforms.ReplaceMissingValues("Features").
Append(ML.Transforms.NormalizeMinMax("Features"));
return (dataView, estimators, initialPipeline);
}

private void CommonCalibratorOnnxConversionTest(IEstimator<ITransformer> calibrator, IEstimator<ITransformer> calibratorNonStandard)
{
// Initialize variables needed for the ONNX conversion test
var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();

Contributor (review comment): Sorry, you might have just introduced a bug here. The mlContext returned from GetEstimatorsForOnnxConversionTests is now different from the MLContext used to create the calibrator. Either fix GetEstimatorsForOnnxConversionTests to use the MLContext from the base class, or pass it in as a parameter to this function, as mentioned in my earlier comment.

Contributor Author (reply): Thank you for the heads up; I have removed the usage of MLContext in these tests in favor of ML.

// Step 1: Test calibrator with binary prediction trainer
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator);
var onnxFileName = $"{estimator}-With-{calibrator}.onnx";
TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}

// Step 2: Test calibrator without any binary prediction trainer
IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData());
var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx";
TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard());
var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx";
TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

class PlattModelInput
[Fact]
public void PlattCalibratorOnnxConversionTest()
{
CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(),
ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX"));
}

[Fact]
public void FixedPlattCalibratorOnnxConversionTest()
{
// FixedPlattCalibrator is exercised here by passing sample slope and offset values to the Platt constructor.
CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f),
ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX"));
}
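
For context on the sample values above: a fixed Platt calibrator maps the raw score through a sigmoid with a preset slope and offset instead of learning them from data. The sketch below assumes the usual Platt form p = 1 / (1 + exp(slope * score + offset)); the class and method names are illustrative, not ML.NET API.

using System;

// Hedged sketch of fixed Platt scaling with the sample values used in the test.
internal static class FixedPlattSketch
{
    internal static float Calibrate(float score, float slope = -1f, float offset = -0.05f)
    {
        // A fixed (non-trained) sigmoid that maps a raw score to a probability.
        return (float)(1.0 / (1.0 + Math.Exp(slope * score + offset)));
    }
}

With slope = -1 and offset = -0.05 as in the test, larger scores yield probabilities closer to 1.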

[Fact]
public void NaiveCalibratorOnnxConversionTest()
{
CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Naive(),
ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX"));
}

class CalibratorInput
{
public bool Label { get; set; }
public float Score { get; set; }
}

class PlattModelInput2
class CalibratorInputNonStandard
{
public bool Label { get; set; }
public float ScoreX { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
static IEnumerable<CalibratorInput> GetCalibratorTestData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
yield return new CalibratorInput { Score = i, Label = i % 2 == 0 };
}
}

static IEnumerable<PlattModelInput2> PlattGetData2()
static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 };
yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 };
}
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
// Test PlattCalibrator without any binary prediction trainer
var mlContext = new MLContext(seed: 0);

IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());

var pipeline = mlContext.BinaryClassification.Calibrators
.Platt();
var onnxFileName = $"{pipeline}.onnx";

TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2());

var pipeline2 = mlContext.BinaryClassification.Calibrators
.Platt(scoreColumnName: "ScoreX");
var onnxFileName2 = $"{pipeline2}.onnx";

TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

[Fact]
public void TextNormalizingOnnxConversionTest()
{