Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniform onnx conversion method when using non-default column names #5146

Merged
merged 6 commits into from
May 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/Microsoft.ML.Data/Scorers/BinaryClassifierScorer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,6 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx)
Host.Assert(Bindable is IBindableCanSaveOnnx);
Host.Assert(Bindings.InfoCount >= 2);

if (!ctx.ContainsColumn(DefaultColumnNames.Features))
return;

base.SaveAsOnnxCore(ctx);
int delta = Bindings.DerivedColumnCount;

Expand Down
138 changes: 138 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1959,6 +1959,144 @@ public void SelectColumnsOnnxTest()
Done();
}

/// <summary>
/// Input row used by the multiclass ONNX conversion test. The feature vector is
/// exposed as "MyFeatureVector" so the test exercises a custom feature column name.
/// </summary>
private class BreastCancerMulticlassExampleNonDefaultColNames
{
    /// <summary>Label value, loaded as text from column index 1.</summary>
    [LoadColumn(1)]
    public string Label;

    /// <summary>Eight-element feature vector loaded from column indices 2 through 9.</summary>
    [LoadColumn(2, 9)]
    [VectorType(8)]
    public float[] MyFeatureVector;
}

/// <summary>
/// Input row used by the binary classification ONNX conversion test. The feature
/// vector is exposed as "MyFeatureVector" so the test exercises a custom feature column name.
/// </summary>
private class BreastCancerBinaryClassificationNonDefaultColNames
{
    /// <summary>Boolean label, loaded from column index 0.</summary>
    [LoadColumn(0)]
    public bool Label;

    /// <summary>Eight-element feature vector loaded from column indices 2 through 9.</summary>
    [LoadColumn(2, 9)]
    [VectorType(8)]
    public float[] MyFeatureVector;
}

/// <summary>
/// Trains a set of binary classifiers on a feature column named "MyFeatureVector",
/// exports each fitted pipeline to ONNX, and (when the ONNX runtime is supported)
/// checks that the ONNX model reproduces ML.NET's Score and PredictedLabel columns.
/// </summary>
[Fact]
public void NonDefaultColNamesBinaryClassificationOnnxConversionTest()
{
    const string labelColumn = "Label";
    const string featureColumn = "MyFeatureVector";

    var mlContext = new MLContext(seed: 1);

    // The loader is lazy: nothing is read from disk until the data view is consumed.
    string dataPath = GetDataPath("breast-cancer.txt");
    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassificationNonDefaultColNames>(
        dataPath,
        separatorChar: '\t',
        hasHeader: true);

    var trainers = mlContext.BinaryClassification.Trainers;
    var estimators = new List<IEstimator<ITransformer>>
    {
        trainers.AveragedPerceptron(labelColumn, featureColumn),
        trainers.FastForest(labelColumn, featureColumn),
        trainers.FastTree(labelColumn, featureColumn),
        trainers.LbfgsLogisticRegression(labelColumn, featureColumn),
        trainers.LinearSvm(labelColumn, featureColumn),
        trainers.Prior(),
        trainers.SdcaLogisticRegression(labelColumn, featureColumn),
        trainers.SdcaNonCalibrated(labelColumn, featureColumn),
        trainers.SgdCalibrated(labelColumn, featureColumn),
        trainers.SgdNonCalibrated(labelColumn, featureColumn),
        trainers.SymbolicSgdLogisticRegression(labelColumn, featureColumn),
    };

    // LightGbm is only added to the list when running as a 64-bit process.
    if (Environment.Is64BitProcess)
    {
        estimators.Add(trainers.LightGbm(labelColumn, featureColumn));
    }

    // Shared preprocessing applied in front of every trainer.
    var featurization = mlContext.Transforms.ReplaceMissingValues(featureColumn)
        .Append(mlContext.Transforms.NormalizeMinMax(featureColumn));

    foreach (var trainer in estimators)
    {
        var fittedModel = featurization.Append(trainer).Fit(dataView);
        var transformedData = fittedModel.Transform(dataView);
        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(fittedModel, dataView);

        // The file is named after the estimator's string representation.
        var onnxModelPath = GetOutputPath($"{trainer.ToString()}.onnx");
        SaveOnnxModel(onnxModel, onnxModelPath, null);

        // Without a supported ONNX runtime only the export itself is exercised.
        if (!IsOnnxRuntimeSupported())
        {
            continue;
        }

        // Score the training data with the saved ONNX model and compare against ML.NET.
        var onnxTransformer = mlContext.Transforms.ApplyOnnxModel(onnxModelPath).Fit(dataView);
        var onnxResult = onnxTransformer.Transform(dataView);
        CompareSelectedColumns<float>("Score", "Score", transformedData, onnxResult, 3);
        CompareSelectedColumns<bool>("PredictedLabel", "PredictedLabel", transformedData, onnxResult);
    }

    Done();
}

/// <summary>
/// Trains a set of multiclass classifiers on a feature column named "MyFeatureVector",
/// exports each fitted pipeline to ONNX, and (when the ONNX runtime is supported)
/// checks that the ONNX model reproduces ML.NET's PredictedLabel and Score columns.
/// </summary>
[Fact]
public void NonDefaultColNamesMultiClassificationOnnxConversionTest()
{
    const string labelColumn = "Label";
    const string featureColumn = "MyFeatureVector";

    var mlContext = new MLContext(seed: 1);

    string dataPath = GetDataPath("breast-cancer.txt");
    var dataView = mlContext.Data.LoadFromTextFile<BreastCancerMulticlassExampleNonDefaultColNames>(
        dataPath,
        separatorChar: '\t',
        hasHeader: true);

    var binaryTrainers = mlContext.BinaryClassification.Trainers;
    var multiclassTrainers = mlContext.MulticlassClassification.Trainers;

    // Each one-versus-all learner is listed twice: once with the default settings
    // and once with useProbabilities explicitly set to false.
    var estimators = new List<IEstimator<ITransformer>>
    {
        multiclassTrainers.LbfgsMaximumEntropy(labelColumn, featureColumn),
        multiclassTrainers.NaiveBayes(labelColumn, featureColumn),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.AveragedPerceptron(labelColumn, featureColumn)),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.AveragedPerceptron(labelColumn, featureColumn), useProbabilities: false),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.LbfgsLogisticRegression(labelColumn, featureColumn)),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.LbfgsLogisticRegression(labelColumn, featureColumn), useProbabilities: false),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.LinearSvm(labelColumn, featureColumn)),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.LinearSvm(labelColumn, featureColumn), useProbabilities: false),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.FastForest(labelColumn, featureColumn)),
        multiclassTrainers.OneVersusAll(
            binaryTrainers.FastForest(labelColumn, featureColumn), useProbabilities: false),
        multiclassTrainers.SdcaMaximumEntropy(labelColumn, featureColumn),
        multiclassTrainers.SdcaNonCalibrated(labelColumn, featureColumn),
    };

    // LightGbm is only added to the list when running as a 64-bit process.
    if (Environment.Is64BitProcess)
    {
        estimators.Add(multiclassTrainers.LightGbm(labelColumn, featureColumn));
    }

    // Shared preprocessing applied in front of every trainer; the label is mapped to a key.
    var featurization = mlContext.Transforms.ReplaceMissingValues(featureColumn)
        .Append(mlContext.Transforms.NormalizeMinMax(featureColumn))
        .Append(mlContext.Transforms.Conversion.MapValueToKey(labelColumn));

    foreach (var trainer in estimators)
    {
        var fittedModel = featurization.Append(trainer).Fit(dataView);
        var transformedData = fittedModel.Transform(dataView);
        var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(fittedModel, dataView);

        // The file is named after the estimator's string representation.
        var onnxModelPath = GetOutputPath($"{trainer.ToString()}.onnx");
        SaveOnnxModel(onnxModel, onnxModelPath, null);

        // Without a supported ONNX runtime only the export itself is exercised.
        if (!IsOnnxRuntimeSupported())
        {
            continue;
        }

        // Score the training data with the saved ONNX model and compare against ML.NET.
        var onnxTransformer = mlContext.Transforms.ApplyOnnxModel(onnxModelPath).Fit(dataView);
        var onnxResult = onnxTransformer.Transform(dataView);
        CompareSelectedColumns<uint>("PredictedLabel", "PredictedLabel", transformedData, onnxResult);
        CompareSelectedColumns<float>("Score", "Score", transformedData, onnxResult, 4);
    }

    Done();
}

private void CompareResults(string leftColumnName, string rightColumnName, IDataView left, IDataView right, int precision = 6)
{
var leftColumn = left.Schema[leftColumnName];
Expand Down