Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Onnx Export to PlattCalibratorTransformer #4699

Merged
merged 10 commits into from
Jan 27, 2020
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] scoreProbablityColu
var node = ctx.CreateNode(opType, new[] { scoreProbablityColumnNames[0], slopVar }, new[] { mulNodeOutput }, ctx.GetNodeName(opType), "");

opType = "Add";
var betaVar = ctx.AddInitializer(-0.0000001f, "Slope");
var betaVar = ctx.AddInitializer((float)(-Offset), "Offset");
var linearOutput = ctx.AddIntermediateVariable(null, "linearOutput", true);
node = ctx.CreateNode(opType, new[] { mulNodeOutput, betaVar }, new[] { linearOutput }, ctx.GetNodeName(opType), "");

Expand Down
19 changes: 18 additions & 1 deletion src/Microsoft.ML.Data/Prediction/CalibratorCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Model.OnnxConverter;
using Microsoft.ML.Runtime;
using Microsoft.ML.Trainers;

Expand Down Expand Up @@ -220,13 +221,15 @@ private protected VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(CalibratorTransformer<>).Assembly.FullName);
}

private sealed class Mapper<TCalibrator> : MapperBase
private sealed class Mapper<TCalibrator> : MapperBase, ISaveAsOnnx
where TCalibrator : class, ICalibrator
{
private TCalibrator _calibrator;
private readonly int _scoreColIndex;
private CalibratorTransformer<TCalibrator> _parent;

bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _calibrator is ICanSaveOnnx onnxMapper ? onnxMapper.CanSaveOnnx(ctx) : false;

internal Mapper(CalibratorTransformer<TCalibrator> parent, TCalibrator calibrator, DataViewSchema inputSchema) :
base(parent.Host, inputSchema, parent)
{
Expand All @@ -243,6 +246,20 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a

private protected override void SaveModel(ModelSaveContext ctx) => _parent.SaveModel(ctx);

void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
{
var scoreName = InputSchema[_scoreColIndex].Name;
var probabilityName = GetOutputColumnsCore()[0].Name;
Host.CheckValue(ctx, nameof(ctx));
if (_calibrator is ISingleCanSaveOnnx onnx)
{
Host.Check(onnx.CanSaveOnnx(ctx), "Cannot be saved as ONNX.");
scoreName = ctx.GetVariableName(scoreName);
probabilityName = ctx.AddIntermediateVariable(NumberDataViewType.Single, probabilityName);
onnx.SaveAsOnnx(ctx, new[] { scoreName, probabilityName }, ""); // No need for featureColumn
}
}

protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var builder = new DataViewSchema.Annotations.Builder();
Expand Down
106 changes: 106 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,112 @@ public void BinaryClassificationTrainersOnnxConversionTest()
Done();
}

[Fact]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't be part of your changes. Maybe you did your merge steps wrong?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think I messed up, but it's fixed now

public void PlattCalibratorOnnxConversionTest()
{
var mlContext = new MLContext(seed: 1);
string dataPath = GetDataPath("breast-cancer.txt");
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true);
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>()
{
mlContext.BinaryClassification.Trainers.AveragedPerceptron(),
mlContext.BinaryClassification.Trainers.FastForest(),
mlContext.BinaryClassification.Trainers.FastTree(),
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(),
mlContext.BinaryClassification.Trainers.LinearSvm(),
mlContext.BinaryClassification.Trainers.Prior(),
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(),
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(),
mlContext.BinaryClassification.Trainers.SgdCalibrated(),
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(),
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(),
};
if (Environment.Is64BitProcess)
{
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm());
}

var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features").
Append(mlContext.Transforms.NormalizeMinMax("Features"));
foreach (var estimator in estimators)
{
// var pipeline = initialPipeline.Append(estimator);
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt());
var model = pipeline.Fit(dataView);
var outputSchema = model.GetOutputSchema(dataView.Schema);
var transformedData = model.Transform(dataView);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"{estimator.ToString()}-WithPlattCalibrator.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);
// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(dataView);
var onnxResult = onnxTransformer.Transform(dataView);
CompareSelectedR4ScalarColumns(transformedData.Schema[5].Name, outputNames[3], transformedData, onnxResult, 3); //compare scores
CompareSelectedScalarColumns<Boolean>(transformedData.Schema[4].Name, outputNames[2], transformedData, onnxResult); //compare predicted labels
CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities
}
}
Done();
}

class PlattModelInput
{
public bool Label { get; set; }
public float Score { get; set; }
}

static IEnumerable<PlattModelInput> PlattGetData()
{
for(int i = 0; i < 100; i++)
{
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 };
}
}

[Fact]
public void PlattCalibratorOnnxConversionTest2()
{
// Test PlattCalibrator without any binary prediction trainer
var mlContext = new MLContext(seed: 0);

IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData());

var calibratorEstimator = mlContext.BinaryClassification.Calibrators
.Platt();

var calibratorTransformer = calibratorEstimator.Fit(data);
var transformedData = calibratorTransformer.Transform(data);
var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(calibratorTransformer, data);

// Compare model scores produced by ML.NET and ONNX's runtime.
if (IsOnnxRuntimeSupported())
{
var onnxFileName = $"{calibratorTransformer.ToString()}.onnx";
var onnxModelPath = GetOutputPath(onnxFileName);
SaveOnnxModel(onnxModel, onnxModelPath, null);

// Evaluate the saved ONNX model using the data used to train the ML.NET pipeline.
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();
var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
var onnxTransformer = onnxEstimator.Fit(data);
var onnxResult = onnxTransformer.Transform(data);
CompareSelectedR4ScalarColumns(transformedData.Schema.Last().Name, outputNames.Last(), transformedData, onnxResult, 3); //compare probabilities
}

Done();

antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
}

private class DataPoint
{
[VectorType(3)]
Expand Down