Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Onnx Export to PlattCalibratorTransformer #4699

Merged
merged 10 commits into from
Jan 27, 2020
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0], slopVar }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Add";
var betaVar = ctx.AddInitializer(-0.0000001f, "Slope");
var betaVar = ctx.AddInitializer((float)(-Offset), "Offset");
var linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
node = ctx.CreateNode(opType, new[] { mulNodeOutput, betaVar }, new[] { linearOutput }, ctx.GetNodeName(opType), "");

Expand Down
19 changes: 18 additions & 1 deletion src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;

Expand Down Expand Up @@ -220,13 +221,15 @@ private protected VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(CalibratorTransformer<>).Assembly.FullName);
}

private sealed class Mapper<TCalibrator> : MapperBase
private sealed class Mapper<TCalibrator> : MapperBase, ISaveAsOnnx
where TCalibrator : class, ICalibrator
{
private TCalibrator _calibrator;
private readonly int _scoreColIndex;
private CalibratorTransformer<TCalibrator> _parent;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;

internal Mapper(CalibratorTransformer<TCalibrator> parent, TCalibrator calibrator, DataViewSchema inputSchema) :
base(parent.Host, inputSchema, parent)
{
Expand All @@ -243,6 +246,20 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a

private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
{
var scoreName = InputSchema[_scoreColIndex].Name;
var probabilityName = GetOutputColumnsCore()[0].Name;
Host.CheckValue(ctx, nameof(ctx));
if (_calibrator is ISingleCanSaveOnnx onnx)
{
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
scoreName = ctx.GetVariableName(scoreName);
probabilityName = ctx.AddIntermediateVariable(NumberDataViewType.Single, probabilityName);
onnx.SaveAsOnnx(ctx, new[] { scoreName, probabilityName }, ""); // No need for featureColumn
}
}

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var builder = new DataViewSchema.Annotations.Builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@
{
"input": [
"MulNodeOutput",
"Slope0"
"Offset"
],
"output": [
"linearOutput"
Expand Down Expand Up @@ -489,9 +489,9 @@
{
"dataType": 1,
"floatData": [
-1E-07
0
],
"name": "Slope0"
"name": "Offset"
}
],
"input": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@
{
"input": [
"MulNodeOutput",
"Slope0"
"Offset"
],
"output": [
"linearOutput"
Expand Down Expand Up @@ -815,9 +815,9 @@
{
"dataType": 1,
"floatData": [
-1E-07
0
],
"name": "Slope0"
"name": "Offset"
}
],
"input": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@
{
"input": [
"MulNodeOutput",
"Slope0"
"Offset"
],
"output": [
"linearOutput"
Expand Down Expand Up @@ -482,9 +482,9 @@
{
"dataType": 1,
"floatData": [
-1E-07
0
],
"name": "Slope0"
"name": "Offset"
}
],
"input": [
Expand Down
103 changes: 103 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,109 @@ public void TestVectorWhiteningOnnxConversionTest()
Done();
}

[Fact]
public void PlattCalibratorOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
mlContext.BinaryClassification.Trainers.FastForest(),
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
}

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
foreach (var estimator in estimators)
{
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var model = pipeline.Fit(dataView);
var outputSchema = model.GetOutputSchema(dataView.Schema);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"{estimator.ToString()}-WithPlattCalibrator.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels
CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities
}
}
Done();
}

class PlattModelInput
{
public bool Label { get; set; }
public float Score { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
{
for (int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
}
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
// Test PlattCalibrator without any binary prediction trainer
var mlContext = new MLContext(seed: 0);

IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());

var calibratorEstimator = mlContext.BinaryClassification.Calibrators
.Platt();

var calibratorTransformer = calibratorEstimator.Fit(data);
var transformedData = calibratorTransformer.Transform(data);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(calibratorTransformer, data);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"{calibratorTransformer.ToString()}.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities
}
Done();
}

private class DataPoint
{
[VectorType(3)]
Expand Down