@@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
261261 Done ( ) ;
262262 }
263263
// Diff hunk: the former [Fact] test header (PlattCalibratorOnnxConversionTest) is removed and the shared
// body below becomes a private helper. The helper returns a tuple of (data view, list of binary
// classification trainers, preprocessing pipeline) that the calibrator ONNX conversion tests consume.
// Added lines use `ML` (declared outside this hunk) in place of the removed local MLContext(seed: 1).
264- [ Fact ]
265- public void PlattCalibratorOnnxConversionTest ( )
264+ private ( IDataView , List < IEstimator < ITransformer > > , EstimatorChain < NormalizingTransformer > ) GetEstimatorsForOnnxConversionTests ( )
266265 {
267- var mlContext = new MLContext ( seed : 1 ) ;
268- string dataPath = GetDataPath ( TestDatasets . breastCancer . trainFilename ) ;
266+ string dataPath = GetDataPath ( "breast-cancer.txt" ) ;
269267 // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
// NOTE(review): hasHeader changes from false to true together with the path change — confirm the
// file named here really starts with a header row, otherwise the first data row is skipped.
270- var dataView = mlContext . Data . LoadFromTextFile < BreastCancerBinaryClassification > ( dataPath , separatorChar : '\t ' , hasHeader : false ) ;
268+ var dataView = ML . Data . LoadFromTextFile < BreastCancerBinaryClassification > ( dataPath , separatorChar : '\t ' , hasHeader : true ) ;
// Trainers exercised by the tests; the removed and added lists name the same trainers in the same order.
271269 List < IEstimator < ITransformer > > estimators = new List < IEstimator < ITransformer > > ( )
272270 {
273- mlContext . BinaryClassification . Trainers . AveragedPerceptron ( ) ,
274- mlContext . BinaryClassification . Trainers . FastForest ( ) ,
275- mlContext . BinaryClassification . Trainers . FastTree ( ) ,
276- mlContext . BinaryClassification . Trainers . LbfgsLogisticRegression ( ) ,
277- mlContext . BinaryClassification . Trainers . LinearSvm ( ) ,
278- mlContext . BinaryClassification . Trainers . Prior ( ) ,
279- mlContext . BinaryClassification . Trainers . SdcaLogisticRegression ( ) ,
280- mlContext . BinaryClassification . Trainers . SdcaNonCalibrated ( ) ,
281- mlContext . BinaryClassification . Trainers . SgdCalibrated ( ) ,
282- mlContext . BinaryClassification . Trainers . SgdNonCalibrated ( ) ,
283- mlContext . BinaryClassification . Trainers . SymbolicSgdLogisticRegression ( ) ,
271+ ML . BinaryClassification . Trainers . AveragedPerceptron ( ) ,
272+ ML . BinaryClassification . Trainers . FastForest ( ) ,
273+ ML . BinaryClassification . Trainers . FastTree ( ) ,
274+ ML . BinaryClassification . Trainers . LbfgsLogisticRegression ( ) ,
275+ ML . BinaryClassification . Trainers . LinearSvm ( ) ,
276+ ML . BinaryClassification . Trainers . Prior ( ) ,
277+ ML . BinaryClassification . Trainers . SdcaLogisticRegression ( ) ,
278+ ML . BinaryClassification . Trainers . SdcaNonCalibrated ( ) ,
279+ ML . BinaryClassification . Trainers . SgdCalibrated ( ) ,
280+ ML . BinaryClassification . Trainers . SgdNonCalibrated ( ) ,
281+ ML . BinaryClassification . Trainers . SymbolicSgdLogisticRegression ( ) ,
284282 } ;
// LightGbm is appended only when the test process is 64-bit.
285283 if ( Environment . Is64BitProcess )
286284 {
287- estimators . Add ( mlContext . BinaryClassification . Trainers . LightGbm ( ) ) ;
285+ estimators . Add ( ML . BinaryClassification . Trainers . LightGbm ( ) ) ;
288286 }
289287
// Preprocessing applied ahead of every trainer: ReplaceMissingValues on "Features", then NormalizeMinMax on "Features".
290- var initialPipeline = mlContext . Transforms . ReplaceMissingValues ( "Features" ) .
291- Append ( mlContext . Transforms . NormalizeMinMax ( "Features" ) ) ;
288+ var initialPipeline = ML . Transforms . ReplaceMissingValues ( "Features" ) .
289+ Append ( ML . Transforms . NormalizeMinMax ( "Features" ) ) ;
290+ return ( dataView , estimators , initialPipeline ) ;
291+ }
292+
// Shared driver for the calibrator ONNX conversion tests.
//   calibrator            - appended after each trainer (Step 1) and also run on its own (Step 2).
//   calibratorNonStandard - run on its own against data whose score column is not the default (Step 3);
//                           the visible callers build it with scoreColumnName: "ScoreX".
// Each step hands a pipeline, its data, an ONNX file name and the columns to compare to TestPipeline;
// Done() is called once after all three steps. Lines marked '-' are the removed Platt-only loop body.
293+ private void CommonCalibratorOnnxConversionTest ( IEstimator < ITransformer > calibrator , IEstimator < ITransformer > calibratorNonStandard )
294+ {
295+ // Initialize variables needed for the ONNX conversion test
296+ var ( dataView , estimators , initialPipeline ) = GetEstimatorsForOnnxConversionTests ( ) ;
297+
298+ // Step 1: Test calibrator with binary prediction trainer
292299 foreach ( var estimator in estimators )
293300 {
294- var pipeline = initialPipeline . Append ( estimator ) . Append ( mlContext . BinaryClassification . Calibrators . Platt ( ) ) ;
295- var onnxFileName = $ "{ estimator } -WithPlattCalibrator.onnx";
296-
297- TestPipeline ( pipeline , dataView , onnxFileName , new ColumnComparison [ ] { new ColumnComparison ( "Score" , 3 ) , new ColumnComparison ( "PredictedLabel" ) , new ColumnComparison ( "Probability" , 3 ) } ) ;
301+ var pipelineEstimators = initialPipeline . Append ( estimator ) . Append ( calibrator ) ;
// NOTE(review): the file name interpolates the estimator objects themselves, so it depends on what
// their string conversion yields — confirm the names stay distinct across the calibrator tests.
302+ var onnxFileName = $ "{ estimator } -With-{ calibrator } .onnx";
303+ TestPipeline ( pipelineEstimators , dataView , onnxFileName , new ColumnComparison [ ] { new ColumnComparison ( "Score" , 3 ) , new ColumnComparison ( "PredictedLabel" ) , new ColumnComparison ( "Probability" , 3 ) } ) ;
298304 }
305+
306+ // Step 2: Test calibrator without any binary prediction trainer
// Only "Probability" is compared in Steps 2 and 3; the input rows come from the in-memory generators.
307+ IDataView dataSoloCalibrator = ML . Data . LoadFromEnumerable ( GetCalibratorTestData ( ) ) ;
308+ var onnxFileNameSoloCalibrator = $ "{ calibrator } -SoloCalibrator.onnx";
309+ TestPipeline ( calibrator , dataSoloCalibrator , onnxFileNameSoloCalibrator , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
310+
311+ // Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
312+ IDataView dataSoloCalibratorNonStandard = ML . Data . LoadFromEnumerable ( GetCalibratorTestDataNonStandard ( ) ) ;
313+ var onnxFileNameSoloCalibratorNonStandard = $ "{ calibratorNonStandard } -SoloCalibrator-NonStandard.onnx";
314+ TestPipeline ( calibratorNonStandard , dataSoloCalibratorNonStandard , onnxFileNameSoloCalibratorNonStandard , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
315+
299316 Done ( ) ;
300317 }
301318
302- class PlattModelInput
// Test: runs CommonCalibratorOnnxConversionTest with a default Platt calibrator and a second Platt
// calibrator whose score column is named "ScoreX".
319+ [ Fact ]
320+ public void PlattCalibratorOnnxConversionTest ( )
321+ {
322+ CommonCalibratorOnnxConversionTest ( ML . BinaryClassification . Calibrators . Platt ( ) ,
323+ ML . BinaryClassification . Calibrators . Platt ( scoreColumnName : "ScoreX" ) ) ;
324+ }
325+
// Test: runs CommonCalibratorOnnxConversionTest with Platt calibrators created from explicit
// slope (-1) and offset (-0.05) values; the second one additionally reads its score from "ScoreX".
326+ [ Fact ]
327+ public void FixedPlattCalibratorOnnxConversionTest ( )
328+ {
329+ // Below, FixedPlattCalibrator is utilized by defining slope and offset in Platt's constructor with sample values.
330+ CommonCalibratorOnnxConversionTest ( ML . BinaryClassification . Calibrators . Platt ( slope : - 1f , offset : - 0.05f ) ,
331+ ML . BinaryClassification . Calibrators . Platt ( slope : - 1f , offset : - 0.05f , scoreColumnName : "ScoreX" ) ) ;
332+ }
333+
// Test: runs CommonCalibratorOnnxConversionTest with a default Naive calibrator and a second Naive
// calibrator whose score column is named "ScoreX".
334+ [ Fact ]
335+ public void NaiveCalibratorOnnxConversionTest ( )
336+ {
337+ CommonCalibratorOnnxConversionTest ( ML . BinaryClassification . Calibrators . Naive ( ) ,
338+ ML . BinaryClassification . Calibrators . Naive ( scoreColumnName : "ScoreX" ) ) ;
339+ }
340+
// Row type produced by GetCalibratorTestData: a boolean Label and a float Score.
341+ class CalibratorInput
303342 {
304343 public bool Label { get ; set ; }
305344 public float Score { get ; set ; }
306345 }
307346
// Renamed in this diff from PlattModelInput2. Row type produced by GetCalibratorTestDataNonStandard:
// a boolean Label and a float score held in a property named ScoreX instead of Score.
308- class PlattModelInput2
347+ class CalibratorInputNonStandard
309348 {
310349 public bool Label { get ; set ; }
311350 public float ScoreX { get ; set ; }
312351 }
313352
// Renamed in this diff from PlattGetData. Lazily yields 100 rows, i = 0..99, with Score = i and
// Label = true when i is even.
314- static IEnumerable < PlattModelInput > PlattGetData ( )
353+ static IEnumerable < CalibratorInput > GetCalibratorTestData ( )
315354 {
316355 for ( int i = 0 ; i < 100 ; i ++ )
317356 {
318- yield return new PlattModelInput { Score = i , Label = i % 2 == 0 } ;
357+ yield return new CalibratorInput { Score = i , Label = i % 2 == 0 } ;
319358 }
320359 }
321360
// Renamed in this diff from PlattGetData2. Lazily yields 100 rows, i = 0..99, with ScoreX = i and
// Label = true when i is even.
322- static IEnumerable < PlattModelInput2 > PlattGetData2 ( )
361+ static IEnumerable < CalibratorInputNonStandard > GetCalibratorTestDataNonStandard ( )
323362 {
324363 for ( int i = 0 ; i < 100 ; i ++ )
325364 {
326- yield return new PlattModelInput2 { ScoreX = i , Label = i % 2 == 0 } ;
365+ yield return new CalibratorInputNonStandard { ScoreX = i , Label = i % 2 == 0 } ;
326366 }
327367 }
329368
// Removed in this diff (every line is marked '-'). This test ran a Platt calibrator on its own, first
// with the default score column and then with scoreColumnName "ScoreX", using a local MLContext(seed: 0).
// The same two checks appear as Step 2 and Step 3 of CommonCalibratorOnnxConversionTest.
330- [ Fact ]
331- public void PlattCalibratorOnnxConversionTest2 ( )
332- {
333- // Test PlattCalibrator without any binary prediction trainer
334- var mlContext = new MLContext ( seed : 0 ) ;
335-
336- IDataView data = mlContext . Data . LoadFromEnumerable ( PlattGetData ( ) ) ;
337-
338- var pipeline = mlContext . BinaryClassification . Calibrators
339- . Platt ( ) ;
340- var onnxFileName = $ "{ pipeline } .onnx";
341-
342- TestPipeline ( pipeline , data , onnxFileName , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
343-
344- // Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
345- IDataView data2 = mlContext . Data . LoadFromEnumerable ( PlattGetData2 ( ) ) ;
346-
347- var pipeline2 = mlContext . BinaryClassification . Calibrators
348- . Platt ( scoreColumnName : "ScoreX" ) ;
349- var onnxFileName2 = $ "{ pipeline2 } .onnx";
350-
351- TestPipeline ( pipeline2 , data2 , onnxFileName2 , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
352-
353- Done ( ) ;
354- }
356369 [ Fact ]
357370 public void TextNormalizingOnnxConversionTest ( )
358371 {
0 commit comments