-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added ONNX export support and tests for {FixedPlatt, Naive}CalibratorEstimators #5289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f53e98f
b27d278
0a9a115
ef769d6
df4af98
8211621
a245b43
711a9a5
5acc93b
5d2e930
3a8889b
25ce2da
89b6203
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest() | |
Done(); | ||
} | ||
|
||
[Fact] | ||
public void PlattCalibratorOnnxConversionTest() | ||
private (IDataView, List<IEstimator<ITransformer>>, EstimatorChain<NormalizingTransformer>) GetEstimatorsForOnnxConversionTests() | ||
{ | ||
var mlContext = new MLContext(seed: 1); | ||
string dataPath = GetDataPath(TestDatasets.breastCancer.trainFilename); | ||
string dataPath = GetDataPath("breast-cancer.txt"); | ||
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). | ||
var dataView = mlContext.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: false); | ||
var dataView = ML.Data.LoadFromTextFile<BreastCancerBinaryClassification>(dataPath, separatorChar: '\t', hasHeader: true); | ||
List<IEstimator<ITransformer>> estimators = new List<IEstimator<ITransformer>>() | ||
{ | ||
mlContext.BinaryClassification.Trainers.AveragedPerceptron(), | ||
mlContext.BinaryClassification.Trainers.FastForest(), | ||
mlContext.BinaryClassification.Trainers.FastTree(), | ||
mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(), | ||
mlContext.BinaryClassification.Trainers.LinearSvm(), | ||
mlContext.BinaryClassification.Trainers.Prior(), | ||
mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(), | ||
mlContext.BinaryClassification.Trainers.SdcaNonCalibrated(), | ||
mlContext.BinaryClassification.Trainers.SgdCalibrated(), | ||
mlContext.BinaryClassification.Trainers.SgdNonCalibrated(), | ||
mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), | ||
ML.BinaryClassification.Trainers.AveragedPerceptron(), | ||
ML.BinaryClassification.Trainers.FastForest(), | ||
ML.BinaryClassification.Trainers.FastTree(), | ||
ML.BinaryClassification.Trainers.LbfgsLogisticRegression(), | ||
ML.BinaryClassification.Trainers.LinearSvm(), | ||
ML.BinaryClassification.Trainers.Prior(), | ||
ML.BinaryClassification.Trainers.SdcaLogisticRegression(), | ||
ML.BinaryClassification.Trainers.SdcaNonCalibrated(), | ||
ML.BinaryClassification.Trainers.SgdCalibrated(), | ||
ML.BinaryClassification.Trainers.SgdNonCalibrated(), | ||
ML.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(), | ||
}; | ||
if (Environment.Is64BitProcess) | ||
{ | ||
estimators.Add(mlContext.BinaryClassification.Trainers.LightGbm()); | ||
estimators.Add(ML.BinaryClassification.Trainers.LightGbm()); | ||
} | ||
|
||
var initialPipeline = mlContext.Transforms.ReplaceMissingValues("Features"). | ||
Append(mlContext.Transforms.NormalizeMinMax("Features")); | ||
var initialPipeline = ML.Transforms.ReplaceMissingValues("Features"). | ||
Append(ML.Transforms.NormalizeMinMax("Features")); | ||
return (dataView, estimators, initialPipeline); | ||
} | ||
|
||
private void CommonCalibratorOnnxConversionTest(IEstimator<ITransformer> calibrator, IEstimator<ITransformer> calibratorNonStandard) | ||
{ | ||
// Initialize variables needed for the ONNX conversion test | ||
var (dataView, estimators, initialPipeline) = GetEstimatorsForOnnxConversionTests(); | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, you might have just introduced a bug here. The mlContext returned from […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thank you for the heads up, I have removed the usage of […] |
||
// Step 1: Test calibrator with binary prediction trainer | ||
foreach (var estimator in estimators) | ||
{ | ||
var pipeline = initialPipeline.Append(estimator).Append(mlContext.BinaryClassification.Calibrators.Platt()); | ||
var onnxFileName = $"{estimator}-WithPlattCalibrator.onnx"; | ||
|
||
TestPipeline(pipeline, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) }); | ||
var pipelineEstimators = initialPipeline.Append(estimator).Append(calibrator); | ||
var onnxFileName = $"{estimator}-With-{calibrator}.onnx"; | ||
TestPipeline(pipelineEstimators, dataView, onnxFileName, new ColumnComparison[] { new ColumnComparison("Score", 3), new ColumnComparison("PredictedLabel"), new ColumnComparison("Probability", 3) }); | ||
} | ||
|
||
// Step 2: Test calibrator without any binary prediction trainer | ||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
IDataView dataSoloCalibrator = ML.Data.LoadFromEnumerable(GetCalibratorTestData()); | ||
var onnxFileNameSoloCalibrator = $"{calibrator}-SoloCalibrator.onnx"; | ||
TestPipeline(calibrator, dataSoloCalibrator, onnxFileNameSoloCalibrator, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); | ||
|
||
// Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer | ||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
IDataView dataSoloCalibratorNonStandard = ML.Data.LoadFromEnumerable(GetCalibratorTestDataNonStandard()); | ||
var onnxFileNameSoloCalibratorNonStandard = $"{calibratorNonStandard}-SoloCalibrator-NonStandard.onnx"; | ||
TestPipeline(calibratorNonStandard, dataSoloCalibratorNonStandard, onnxFileNameSoloCalibratorNonStandard, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); | ||
|
||
Done(); | ||
} | ||
|
||
class PlattModelInput | ||
[Fact] | ||
public void PlattCalibratorOnnxConversionTest() | ||
{ | ||
CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(), | ||
ML.BinaryClassification.Calibrators.Platt(scoreColumnName: "ScoreX")); | ||
} | ||
|
||
[Fact] | ||
public void FixedPlattCalibratorOnnxConversionTest() | ||
{ | ||
// Below, FixedPlattCalibrator is utilized by defining slope and offset in Platt's constructor with sample values. | ||
CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f), | ||
ML.BinaryClassification.Calibrators.Platt(slope: -1f, offset: -0.05f, scoreColumnName: "ScoreX")); | ||
} | ||
|
||
[Fact] | ||
public void NaiveCalibratorOnnxConversionTest() | ||
{ | ||
CommonCalibratorOnnxConversionTest(ML.BinaryClassification.Calibrators.Naive(), | ||
ML.BinaryClassification.Calibrators.Naive(scoreColumnName: "ScoreX")); | ||
} | ||
|
||
class CalibratorInput | ||
{ | ||
mstfbl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
public bool Label { get; set; } | ||
public float Score { get; set; } | ||
} | ||
|
||
class PlattModelInput2 | ||
class CalibratorInputNonStandard | ||
{ | ||
public bool Label { get; set; } | ||
public float ScoreX { get; set; } | ||
} | ||
|
||
static IEnumerable<PlattModelInput> PlattGetData() | ||
static IEnumerable<CalibratorInput> GetCalibratorTestData() | ||
{ | ||
for (int i = 0; i < 100; i++) | ||
{ | ||
yield return new PlattModelInput { Score = i, Label = i % 2 == 0 }; | ||
yield return new CalibratorInput { Score = i, Label = i % 2 == 0 }; | ||
} | ||
} | ||
|
||
static IEnumerable<PlattModelInput2> PlattGetData2() | ||
static IEnumerable<CalibratorInputNonStandard> GetCalibratorTestDataNonStandard() | ||
{ | ||
for (int i = 0; i < 100; i++) | ||
{ | ||
yield return new PlattModelInput2 { ScoreX = i, Label = i % 2 == 0 }; | ||
yield return new CalibratorInputNonStandard { ScoreX = i, Label = i % 2 == 0 }; | ||
} | ||
} | ||
|
||
[Fact] | ||
public void PlattCalibratorOnnxConversionTest2() | ||
{ | ||
// Test PlattCalibrator without any binary prediction trainer | ||
var mlContext = new MLContext(seed: 0); | ||
|
||
IDataView data = mlContext.Data.LoadFromEnumerable(PlattGetData()); | ||
|
||
var pipeline = mlContext.BinaryClassification.Calibrators | ||
.Platt(); | ||
var onnxFileName = $"{pipeline}.onnx"; | ||
|
||
TestPipeline(pipeline, data, onnxFileName, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); | ||
|
||
// Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer | ||
IDataView data2 = mlContext.Data.LoadFromEnumerable(PlattGetData2()); | ||
|
||
var pipeline2 = mlContext.BinaryClassification.Calibrators | ||
.Platt(scoreColumnName: "ScoreX"); | ||
var onnxFileName2 = $"{pipeline2}.onnx"; | ||
|
||
TestPipeline(pipeline2, data2, onnxFileName2, new ColumnComparison[] { new ColumnComparison("Probability", 3) }); | ||
|
||
Done(); | ||
} | ||
|
||
[Fact] | ||
public void TextNormalizingOnnxConversionTest() | ||
{ | ||
|
Uh oh!
There was an error while loading. Please reload this page.