Skip to content
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0], slopVar }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Add";
var betaVar = ctx.AddInitializer(-0.0000001f, "Slope");
var betaVar = ctx.AddInitializer((float)(-Offset), "Offset");
var linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
node = ctx.CreateNode(opType, new[] { mulNodeOutput, betaVar }, new[] { linearOutput }, ctx.GetNodeName(opType), "");

Expand Down
19 changes: 18 additions & 1 deletion src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;

Expand Down Expand Up @@ -220,13 +221,15 @@ private protected VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(CalibratorTransformer<>).Assembly.FullName);
}

private sealed class Mapper<TCalibrator> : MapperBase
private sealed class Mapper<TCalibrator> : MapperBase, ISaveAsOnnx
where TCalibrator : class, ICalibrator
{
private TCalibrator _calibrator;
private readonly int _scoreColIndex;
private CalibratorTransformer<TCalibrator> _parent;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;

internal Mapper(CalibratorTransformer<TCalibrator> parent, TCalibrator calibrator, DataViewSchema inputSchema) :
base(parent.Host, inputSchema, parent)
{
Expand All @@ -243,6 +246,20 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a

private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
{
var scoreName = InputSchema[_scoreColIndex].Name;
var probabilityName = GetOutputColumnsCore()[0].Name;
Host.CheckValue(ctx, nameof(ctx));
if (_calibrator is ISingleCanSaveOnnx onnx)
{
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
scoreName = ctx.GetVariableName(scoreName);
probabilityName = ctx.AddIntermediateVariable(NumberDataViewType.Single, probabilityName);
onnx.SaveAsOnnx(ctx, new[] { scoreName, probabilityName }, ""); // No need for featureColumn
}
}

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var builder = new DataViewSchema.Annotations.Builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@
{
"input": [
"MulNodeOutput",
"Slope0"
"Offset"
],
"output": [
"linearOutput"
Expand Down Expand Up @@ -489,9 +489,9 @@
{
"dataType": 1,
"floatData": [
-1E-07
0
],
"name": "Slope0"
"name": "Offset"
}
],
"input": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@
{
"input": [
"MulNodeOutput",
"Slope0"
"Offset"
],
"output": [
"linearOutput"
Expand Down Expand Up @@ -815,9 +815,9 @@
{
"dataType": 1,
"floatData": [
-1E-07
0
],
"name": "Slope0"
"name": "Offset"
}
],
"input": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
{
"input": [
"MulNodeOutput",
"Slope0"
"Offset"
],
"output": [
"linearOutput"
Expand Down Expand Up @@ -482,9 +482,9 @@
{
"dataType": 1,
"floatData": [
-1E-07
0
],
"name": "Slope0"
"name": "Offset"
}
],
"input": [
Expand Down
103 changes: 103 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,109 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
mlContext.BinaryClassification.Trainers.FastForest(),
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
}

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var model = pipeline.Fit(dataView);
var outputSchema = model.GetOutputSchema(dataView.Schema);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"{estimator.ToString()}-WithPlattCalibrator.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels
CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities
}
}
Done();
}

class PlattModelInput
{
public bool Label { get; set; }
public float Score { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
}
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
// Test PlattCalibrator without any binary prediction trainer
var mlContext = new MLContext(seed: 0);

IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());

var calibratorEstimator = mlContext.BinaryClassification.Calibrators
.Platt();

var calibratorTransformer = calibratorEstimator.Fit(data);
var transformedData = calibratorTransformer.Transform(data);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(calibratorTransformer, data);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"{calibratorTransformer.ToString()}.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities
}
Done();
}

private class DataPoint
{
[VectorType(3)]
Expand Down