Skip to content

Commit fab2e1d

Browse files
authored
Added ONNX export support and tests for {FixedPlatt, Naive}CalibratorEstimators (#5289)
* Added ONNX export tests for other calibrators * Consolidated testing, started ONNX model conversion * Added ONNX export support for NaiveCalibrator * Enable NaiveCalibratorOnnxConversionTest * Work on Isotonic Calibrator ONNX export support * Removed Isotonic work from this PR * Nit correct spacing * Nit renaming to CalibratorInput(NonStandard) * Organized tests * Nit * Clean-up initialization of vars for CommonCalibratorOnnxConversionTest * Removed MLContexts for ML
1 parent 7879849 commit fab2e1d

File tree

2 files changed

+114
-56
lines changed

2 files changed

+114
-56
lines changed

src/Microsoft.ML.Data/Prediction/Calibrator.cs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1158,7 +1158,7 @@ ICalibrator ICalibratorTrainer.FinishTraining(IChannel ch)
11581158
/// <summary>
11591159
/// The naive binning-based calibrator.
11601160
/// </summary>
1161-
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat
1161+
public sealed class NaiveCalibrator : ICalibrator, ICanSaveInBinaryFormat, ISingleCanSaveOnnx
11621162
{
11631163
internal const string LoaderSignature = "NaiveCaliExec";
11641164
internal const string RegistrationName = "NaiveCalibrator";
@@ -1174,6 +1174,12 @@ private static VersionInfo GetVersionInfo()
11741174
loaderAssemblyName: typeof(NaiveCalibrator).Assembly.FullName);
11751175
}
11761176

1177+
/// <summary>
1178+
/// Bool required by the interface ISingleCanSaveOnnx, returns true if
1179+
/// and only if calibrator can be exported in ONNX.
1180+
/// </summary>
1181+
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => true;
1182+
11771183
private readonly IHost _host;
11781184

11791185
/// <summary> The bin size.</summary>
@@ -1280,6 +1286,45 @@ internal static int GetBinIdx(float output, float min, float binSize, int numBin
12801286
return binIdx;
12811287
}
12821288

1289+
/// <summary>
/// Exports this calibrator as an ONNX sub-graph that turns a score into the probability of its bin.
/// </summary>
/// <remarks>
/// The emitted node chain is Sub -> Div -> Cast -> Clip -> GatherElements, i.e.
/// <c>BinProbabilities[Clip((long)((score - Min) / BinSize), 0, numBins - 1)]</c>.
/// </remarks>
/// <param name="ctx">The ONNX context that receives the initializers, intermediate variables and nodes.</param>
/// <param name="outputNames">
/// Exactly two names: <c>outputNames[0]</c> is the Score column (the input of this graph) and
/// <c>outputNames[1]</c> is the Probability column (the final output of this graph).
/// </param>
/// <param name="featureColumn">Not read by this implementation.</param>
/// <returns><see langword="true"/> once all nodes have been added; failed checks throw instead.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    _host.CheckValue(ctx, nameof(ctx));
    _host.CheckValue(outputNames, nameof(outputNames));
    _host.Check(Utils.Size(outputNames) == 2);

    // GatherElements, and the form of Clip that takes min/max as input tensors (used below),
    // were both introduced in ONNX opset 11, so older targets cannot represent this graph.
    const int minimumOpSetVersion = 11;
    ctx.CheckOpSetVersion(minimumOpSetVersion, "NaiveCalibrator");

    string scoreName = outputNames[0];
    string probabilityName = outputNames[1];

    // score - Min
    var minVar = ctx.AddInitializer(Min, "Min");
    var subNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "subNodeOutput");
    ctx.CreateNode("Sub", new[] { scoreName, minVar }, new[] { subNodeOutput }, ctx.GetNodeName("Sub"), "");

    // (score - Min) / BinSize
    var binSizeVar = ctx.AddInitializer(BinSize, "BinSize");
    var divNodeOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "divNodeOutput");
    ctx.CreateNode("Div", new[] { subNodeOutput, binSizeVar }, new[] { divNodeOutput }, ctx.GetNodeName("Div"), "");

    // Convert the fractional bin position to an integer index.
    var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "castOutput");
    var castNode = ctx.CreateNode("Cast", divNodeOutput, castOutput, ctx.GetNodeName("Cast"), "");
    castNode.AddAttribute("to", typeof(long));

    // Clamp the index into [0, numBins - 1] so out-of-range scores map to the first/last bin.
    var zeroVar = ctx.AddInitializer(0, "Zero");
    var numBinsMinusOneVar = ctx.AddInitializer(_binProbs.Length - 1, "NumBinsMinusOne");
    var binIndexOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "binIndexOutput");
    ctx.CreateNode("Clip", new[] { castOutput, zeroVar, numBinsMinusOneVar }, new[] { binIndexOutput }, ctx.GetNodeName("Clip"), "");

    // Look up the probability stored for the selected bin; the table is registered as a [numBins, 1] tensor.
    var binProbabilitiesVar = ctx.AddInitializer(_binProbs, new long[] { _binProbs.Length, 1 }, "BinProbabilities");
    ctx.CreateNode("GatherElements", new[] { binProbabilitiesVar, binIndexOutput }, new[] { probabilityName }, ctx.GetNodeName("GatherElements"), "");

    return true;
}
12831328
}
12841329

12851330
/// <summary>

test/Microsoft.ML.Tests/OnnxConversionTest.cs

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
261261
Done();
262262
}
263263

264-
[Fact]
265-
public void PlattCalibratorOnnxConversionTest()
264+
private (IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests()
266265
{
267-
var mlContext = new MLContext(seed: 1);
268-
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename);
266+
string dataPath = GetDataPath("breast-cancer.txt");
269267
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
270-
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: false);
268+
var dataView = ML.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
271269
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
272270
{
273-
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
274-
mlContext.BinaryClassification.Trainers.FastForest(),
275-
mlContext.BinaryClassification.Trainers.FastTree(),
276-
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
277-
mlContext.BinaryClassification.Trainers.LinearSvm(),
278-
mlContext.BinaryClassification.Trainers.Prior(),
279-
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
280-
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
281-
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
282-
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
283-
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
271+
ML.BinaryClassification.Trainers.AveragedPerceptron(),
272+
ML.BinaryClassification.Trainers.FastForest(),
273+
ML.BinaryClassification.Trainers.FastTree(),
274+
ML.BinaryClassification.Trainers.LbfgsLogisticRegression(),
275+
ML.BinaryClassification.Trainers.LinearSvm(),
276+
ML.BinaryClassification.Trainers.Prior(),
277+
ML.BinaryClassification.Trainers.SdcaLogisticRegression(),
278+
ML.BinaryClassification.Trainers.SdcaNonCalibrated(),
279+
ML.BinaryClassification.Trainers.SgdCalibrated(),
280+
ML.BinaryClassification.Trainers.SgdNonCalibrated(),
281+
ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
284282
};
285283
if (Environment.Is64BitProcess)
286284
{
287-
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
285+
estimators.Add(ML.BinaryClassification.Trainers.LightGbm());
288286
}
289287

290-
var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
291-
Append(mlContext.Transforms.NormalizeMinMax("Features"));
288+
var initialPipeline = ML.Transforms.ReplaceMissingValues("Features").
289+
Append(ML.Transforms.NormalizeMinMax("Features"));
290+
return (dataView, estimators, initialPipeline);
291+
}
292+
293+
private void CommonCalibratorOnnxConversionTest(IEstimator<ITransformer> calibrator, IEstimator<ITransformer> calibratorNonStandard)
294+
{
295+
// Initialize variables needed for the ONNX conversion test
296+
var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests();
297+
298+
// Step 1: Test calibrator with binary prediction trainer
292299
foreach (var estimator in estimators)
293300
{
294-
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
295-
var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx";
296-
297-
TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
301+
var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator);
302+
var onnxFileName = $"{estimator}-With-{calibrator}.onnx";
303+
TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) });
298304
}
305+
306+
// Step 2: Test calibrator without any binary prediction trainer
307+
IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData());
308+
var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx";
309+
TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
310+
311+
// Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
312+
IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard());
313+
var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx";
314+
TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
315+
299316
Done();
300317
}
301318

302-
class PlattModelInput
319+
/// <summary>
/// Runs the shared calibrator ONNX conversion checks for the Platt calibrator.
/// </summary>
[Fact]
public void PlattCalibratorOnnxConversionTest()
{
    // One estimator with default arguments, one that reads the score from the "ScoreX" column.
    var plattCalibrator = ML.BinaryClassification.Calibrators.Platt();
    var plattCalibratorNonStandard = ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(plattCalibrator, plattCalibratorNonStandard);
}
325+
326+
/// <summary>
/// Runs the shared calibrator ONNX conversion checks for the fixed Platt calibrator.
/// </summary>
[Fact]
public void FixedPlattCalibratorOnnxConversionTest()
{
    // Supplying slope and offset to Platt(...) is what selects the FixedPlattCalibrator; the values are samples.
    var fixedPlattCalibrator = ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f);
    var fixedPlattCalibratorNonStandard = ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(fixedPlattCalibrator, fixedPlattCalibratorNonStandard);
}
333+
334+
/// <summary>
/// Runs the shared calibrator ONNX conversion checks for the naive (binning-based) calibrator.
/// </summary>
[Fact]
public void NaiveCalibratorOnnxConversionTest()
{
    // One estimator with default arguments, one that reads the score from the "ScoreX" column.
    var naiveCalibrator = ML.BinaryClassification.Calibrators.Naive();
    var naiveCalibratorNonStandard = ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX");

    CommonCalibratorOnnxConversionTest(naiveCalibrator, naiveCalibratorNonStandard);
}
340+
341+
class CalibratorInput
303342
{
304343
public bool Label { get; set; }
305344
public float Score { get; set; }
306345
}
307346

308-
class PlattModelInput2
347+
class CalibratorInputNonStandard
309348
{
310349
public bool Label { get; set; }
311350
public float ScoreX { get; set; }
312351
}
313352

314-
static IEnumerable<PlattModelInput> PlattGetData()
353+
static IEnumerable<CalibratorInput> GetCalibratorTestData()
315354
{
316355
for (int i = 0; i < 100; i++)
317356
{
318-
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
357+
yield return new CalibratorInput { Score = i, Label = i % 2 == 0 };
319358
}
320359
}
321360

322-
static IEnumerable<PlattModelInput2> PlattGetData2()
361+
static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard()
323362
{
324363
for (int i = 0; i < 100; i++)
325364
{
326-
yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 };
365+
yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 };
327366
}
328367
}
329368

330-
[Fact]
331-
public void PlattCalibratorOnnxConversionTest2()
332-
{
333-
// Test PlattCalibrator without any binary prediction trainer
334-
var mlContext = new MLContext(seed: 0);
335-
336-
IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());
337-
338-
var pipeline = mlContext.BinaryClassification.Calibrators
339-
.Platt();
340-
var onnxFileName = $"{pipeline}.onnx";
341-
342-
TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
343-
344-
// Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
345-
IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2());
346-
347-
var pipeline2 = mlContext.BinaryClassification.Calibrators
348-
.Platt(scoreColumnName: "ScoreX");
349-
var onnxFileName2 = $"{pipeline2}.onnx";
350-
351-
TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) });
352-
353-
Done();
354-
}
355-
356369
[Fact]
357370
public void TextNormalizingOnnxConversionTest()
358371
{

0 commit comments

Comments
 (0)