Skip to content

Commit

Permalink
Added ONNX export support and tests for {FixedPlatt, Naive}Calibrator…
Browse files Browse the repository at this point in the history
…Estimators (#5289)

* Added ONNX export tests for other calibrators

* Consolidated testing, started ONNX model conversion

* Added ONNX export support for NaiveCalibrator

* Enable NaiveCalibratorOnnxConversionTest

* Work on Isotonic Calibrator ONNX export support

* Removed Isotonic work from this PR

* Nit correct spacing

* Nit renaming to CalibratorInput(NonStandard)

* Organized tests

* Nit

* Clean-up initialization of vars for CommonCalibratorOnnxConversionTest

* Replaced locally created MLContexts with the shared ML context
  • Loading branch information
mstfbl authored Jul 11, 2020
1 parent 7879849 commit fab2e1d
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 56 deletions.
47 changes: 46 additions & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch)
/// <summary>
/// The naive binning-based calibrator.
/// </summary>
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
{
internal const string LoaderSignature = "NaiveCaliExec";
internal const string RegistrationName = "NaiveCalibrator";
Expand All @@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
}

/// <summary>
/// Explicit implementation of <see cref="ICanSaveOnnx"/>: reports whether this
/// calibrator can be exported to ONNX. The answer is always <c>true</c>,
/// regardless of the supplied context.
/// </summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return true;
}

private readonly IHost _host;

/// <summary> The bin size.</summary>
Expand Down Expand Up @@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
return binIdx;
}

/// <summary>
/// Exports the calibrator as a chain of ONNX nodes that maps a score to the
/// probability stored for the bin the score falls into:
/// Sub (score - Min) -> Div (by BinSize) -> Cast (to int64) -> Clip (to
/// [0, number of bins - 1]) -> GatherElements (look up the bin probability).
/// </summary>
/// <param name="ctx">The ONNX context the nodes and initializers are added to.</param>
/// <param name="outputNames">
/// Exactly two names: element 0 is the Score column (the input of this graph) and
/// element 1 is the Probability column (the final output of this graph).
/// </param>
/// <param name="featureColumn">Not used by this calibrator.</param>
/// <returns>Always <c>true</c> once the nodes have been added.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    _host.CheckValue(ctx, nameof(ctx));
    _host.CheckValue(outputNames, nameof(outputNames));
    _host.Check(Utils.Size(outputNames) == 2);
    // outputNames[0] refers to the name of the Score column, which is the input of this graph
    // outputNames[1] refers to the name of the Probability column, which is the final output of this graph

    // GatherElements, and Clip taking its bounds as inputs rather than as attributes,
    // were both introduced in ONNX opset 11, so older opsets cannot represent this graph.
    // NOTE(review): the ONNX operator spec lists integer tensor support for Clip only
    // from opset 12 - confirm whether the floor should be 12 instead.
    const int minimumOpSetVersion = 11;
    ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator");

    // score - Min
    string opType = "Sub";
    var minVar = ctx.AddInitializer(Min, "Min");
    var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput");
    ctx.CreateNode(opType, new[] { outputNames[0], minVar }, new[] { subNodeOutput }, ctx.GetNodeName(opType), "");

    // (score - Min) / BinSize
    opType = "Div";
    var binSizeVar = ctx.AddInitializer(BinSize, "BinSize");
    var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "divNodeOutput");
    ctx.CreateNode(opType, new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName(opType), "");

    // Convert the quotient to an integer so it can serve as an index.
    opType = "Cast";
    var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput");
    var castNode = ctx.CreateNode(opType, divNodeOutput, castOutput, ctx.GetNodeName(opType), "");
    castNode.AddAttribute("to", typeof(long));

    // Keep the index inside [0, number of bins - 1].
    opType = "Clip";
    var zeroVar = ctx.AddInitializer(0, "Zero");
    var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length - 1, "NumBinsMinusOne");
    var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput");
    ctx.CreateNode(opType, new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName(opType), "");

    // Look up the probability stored for the selected bin.
    opType = "GatherElements";
    var binProbabilitiesVar = ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities");
    ctx.CreateNode(opType, new[] { binProbabilitiesVar, binIndexOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");

    return true;
}
}

/// <summary>
Expand Down
123 changes: 68 additions & 55 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
private (IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: false);
var dataView = ML.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
mlContext.BinaryClassification.Trainers.FastForest(),
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
ML.BinaryClassification.Trainers.AveragedPerceptron(),
ML.BinaryClassification.Trainers.FastForest(),
ML.BinaryClassification.Trainers.FastTree(),
ML.BinaryClassification.Trainers.LbfgsLogisticRegression(),
ML.BinaryClassification.Trainers.LinearSvm(),
ML.BinaryClassification.Trainers.Prior(),
ML.BinaryClassification.Trainers.SdcaLogisticRegression(),
ML.BinaryClassification.Trainers.SdcaNonCalibrated(),
ML.BinaryClassification.Trainers.SgdCalibrated(),
ML.BinaryClassification.Trainers.SgdNonCalibrated(),
ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
estimators.Add(ML.BinaryClassification.Trainers.LightGbm());
}

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
var initialPipeline = ML.Transforms.ReplaceMissingValues("Features").
Append(ML.Transforms.NormalizeMinMax("Features"));
return (dataView, estimators, initialPipeline);
}

private void CommonCalibratorOnnxConversionTest(IEstimator<ITransformer> calibrator, IEstimator<ITransformer> calibratorNonStandard)
{
// Initialize variables needed for the ONNX conversion test
var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();

// Step 1: Test calibrator with binary prediction trainer
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx";

TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator);
var onnxFileName = $"{estimator}-With-{calibrator}.onnx";
TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
}

// Step 2: Test calibrator without any binary prediction trainer
IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData());
var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx";
TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard());
var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx";
TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

class PlattModelInput
[Fact]
public void PlattCalibratorOnnxConversionTest()
{
    // One Platt calibrator reading the default score column, and one reading "ScoreX".
    var defaultCalibrator = ML.BinaryClassification.Calibrators.Platt();
    var scoreXCalibrator = ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(defaultCalibrator, scoreXCalibrator);
}

[Fact]
public void FixedPlattCalibratorOnnxConversionTest()
{
    // Supplying slope and offset (sample values) to Platt's constructor selects the FixedPlattCalibrator.
    var defaultCalibrator = ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f);
    var scoreXCalibrator = ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(defaultCalibrator, scoreXCalibrator);
}

[Fact]
public void NaiveCalibratorOnnxConversionTest()
{
    // One Naive calibrator reading the default score column, and one reading "ScoreX".
    var defaultCalibrator = ML.BinaryClassification.Calibrators.Naive();
    var scoreXCalibrator = ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(defaultCalibrator, scoreXCalibrator);
}

// Row type for the calibrator-only test data produced by GetCalibratorTestData.
// The score is exposed under the property name "Score"
// (presumably the calibrators' default score column name - confirm).
class CalibratorInput
{
// Label paired with the score.
public bool Label { get; set; }
// Score value the calibrator is run on.
public float Score { get; set; }
}

class PlattModelInput2
class CalibratorInputNonStandard
{
public bool Label { get; set; }
public float ScoreX { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
static IEnumerable<CalibratorInput> GetCalibratorTestData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
yield return new CalibratorInput { Score = i, Label = i % 2 == 0 };
}
}

static IEnumerable<PlattModelInput2> PlattGetData2()
static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 };
yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 };
}
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
// Test PlattCalibrator without any binary prediction trainer
var mlContext = new MLContext(seed: 0);

IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());

var pipeline = mlContext.BinaryClassification.Calibrators
.Platt();
var onnxFileName = $"{pipeline}.onnx";

TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

// Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2());

var pipeline2 = mlContext.BinaryClassification.Calibrators
.Platt(scoreColumnName: "ScoreX");
var onnxFileName2 = $"{pipeline2}.onnx";

TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) });

Done();
}

[Fact]
public void TextNormalizingOnnxConversionTest()
{
Expand Down

0 comments on commit fab2e1d

Please sign in to comment.