@@ -261,98 +261,111 @@ public void TestVectorWhiteningOnnxConversionTest()
261
261
Done ( ) ;
262
262
}
263
263
264
- [ Fact ]
265
- public void PlattCalibratorOnnxConversionTest ( )
264
+ private ( IDataView , List < IEstimator < ITransformer > > , EstimatorChain < NormalizingTransformer > ) GetEstimatorsForOnnxConversionTests ( )
266
265
{
267
- var mlContext = new MLContext ( seed : 1 ) ;
268
- string dataPath = GetDataPath ( TestDatasets . breastCancer . trainFilename ) ;
266
+ string dataPath = GetDataPath ( "breast-cancer.txt" ) ;
269
267
// Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed).
270
- var dataView = mlContext . Data . LoadFromTextFile < BreastCancerBinaryClassification > ( dataPath , separatorChar : '\t ' , hasHeader : false ) ;
268
+ var dataView = ML . Data . LoadFromTextFile < BreastCancerBinaryClassification > ( dataPath , separatorChar : '\t ' , hasHeader : true ) ;
271
269
List < IEstimator < ITransformer > > estimators = new List < IEstimator < ITransformer > > ( )
272
270
{
273
- mlContext . BinaryClassification . Trainers . AveragedPerceptron ( ) ,
274
- mlContext . BinaryClassification . Trainers . FastForest ( ) ,
275
- mlContext . BinaryClassification . Trainers . FastTree ( ) ,
276
- mlContext . BinaryClassification . Trainers . LbfgsLogisticRegression ( ) ,
277
- mlContext . BinaryClassification . Trainers . LinearSvm ( ) ,
278
- mlContext . BinaryClassification . Trainers . Prior ( ) ,
279
- mlContext . BinaryClassification . Trainers . SdcaLogisticRegression ( ) ,
280
- mlContext . BinaryClassification . Trainers . SdcaNonCalibrated ( ) ,
281
- mlContext . BinaryClassification . Trainers . SgdCalibrated ( ) ,
282
- mlContext . BinaryClassification . Trainers . SgdNonCalibrated ( ) ,
283
- mlContext . BinaryClassification . Trainers . SymbolicSgdLogisticRegression ( ) ,
271
+ ML . BinaryClassification . Trainers . AveragedPerceptron ( ) ,
272
+ ML . BinaryClassification . Trainers . FastForest ( ) ,
273
+ ML . BinaryClassification . Trainers . FastTree ( ) ,
274
+ ML . BinaryClassification . Trainers . LbfgsLogisticRegression ( ) ,
275
+ ML . BinaryClassification . Trainers . LinearSvm ( ) ,
276
+ ML . BinaryClassification . Trainers . Prior ( ) ,
277
+ ML . BinaryClassification . Trainers . SdcaLogisticRegression ( ) ,
278
+ ML . BinaryClassification . Trainers . SdcaNonCalibrated ( ) ,
279
+ ML . BinaryClassification . Trainers . SgdCalibrated ( ) ,
280
+ ML . BinaryClassification . Trainers . SgdNonCalibrated ( ) ,
281
+ ML . BinaryClassification . Trainers . SymbolicSgdLogisticRegression ( ) ,
284
282
} ;
285
283
if ( Environment . Is64BitProcess )
286
284
{
287
- estimators . Add ( mlContext . BinaryClassification . Trainers . LightGbm ( ) ) ;
285
+ estimators . Add ( ML . BinaryClassification . Trainers . LightGbm ( ) ) ;
288
286
}
289
287
290
- var initialPipeline = mlContext . Transforms . ReplaceMissingValues ( "Features" ) .
291
- Append ( mlContext . Transforms . NormalizeMinMax ( "Features" ) ) ;
288
+ var initialPipeline = ML . Transforms . ReplaceMissingValues ( "Features" ) .
289
+ Append ( ML . Transforms . NormalizeMinMax ( "Features" ) ) ;
290
+ return ( dataView , estimators , initialPipeline ) ;
291
+ }
292
+
293
+ private void CommonCalibratorOnnxConversionTest ( IEstimator < ITransformer > calibrator , IEstimator < ITransformer > calibratorNonStandard )
294
+ {
295
+ // Initialize variables needed for the ONNX conversion test
296
+ var ( dataView , estimators , initialPipeline ) = GetEstimatorsForOnnxConversionTests ( ) ;
297
+
298
+ // Step 1: Test calibrator with binary prediction trainer
292
299
foreach ( var estimator in estimators )
293
300
{
294
- var pipeline = initialPipeline . Append ( estimator ) . Append ( mlContext . BinaryClassification . Calibrators . Platt ( ) ) ;
295
- var onnxFileName = $ "{ estimator } -WithPlattCalibrator.onnx";
296
-
297
- TestPipeline ( pipeline , dataView , onnxFileName , new ColumnComparison [ ] { new ColumnComparison ( "Score" , 3 ) , new ColumnComparison ( "PredictedLabel" ) , new ColumnComparison ( "Probability" , 3 ) } ) ;
301
+ var pipelineEstimators = initialPipeline . Append ( estimator ) . Append ( calibrator ) ;
302
+ var onnxFileName = $ "{ estimator } -With-{ calibrator } .onnx";
303
+ TestPipeline ( pipelineEstimators , dataView , onnxFileName , new ColumnComparison [ ] { new ColumnComparison ( "Score" , 3 ) , new ColumnComparison ( "PredictedLabel" ) , new ColumnComparison ( "Probability" , 3 ) } ) ;
298
304
}
305
+
306
+ // Step 2: Test calibrator without any binary prediction trainer
307
+ IDataView dataSoloCalibrator = ML . Data . LoadFromEnumerable ( GetCalibratorTestData ( ) ) ;
308
+ var onnxFileNameSoloCalibrator = $ "{ calibrator } -SoloCalibrator.onnx";
309
+ TestPipeline ( calibrator , dataSoloCalibrator , onnxFileNameSoloCalibrator , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
310
+
311
+ // Step 3: Test calibrator with a non-default Score column name and without any binary prediction trainer
312
+ IDataView dataSoloCalibratorNonStandard = ML . Data . LoadFromEnumerable ( GetCalibratorTestDataNonStandard ( ) ) ;
313
+ var onnxFileNameSoloCalibratorNonStandard = $ "{ calibratorNonStandard } -SoloCalibrator-NonStandard.onnx";
314
+ TestPipeline ( calibratorNonStandard , dataSoloCalibratorNonStandard , onnxFileNameSoloCalibratorNonStandard , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
315
+
299
316
Done ( ) ;
300
317
}
301
318
302
- class PlattModelInput
319
+ [ Fact ]
320
+ public void PlattCalibratorOnnxConversionTest ( )
321
+ {
322
+ CommonCalibratorOnnxConversionTest ( ML . BinaryClassification . Calibrators . Platt ( ) ,
323
+ ML . BinaryClassification . Calibrators . Platt ( scoreColumnName : "ScoreX" ) ) ;
324
+ }
325
+
326
+ [ Fact ]
327
+ public void FixedPlattCalibratorOnnxConversionTest ( )
328
+ {
329
+ // Below, FixedPlattCalibrator is utilized by defining slope and offset in Platt's constructor with sample values.
330
+ CommonCalibratorOnnxConversionTest ( ML . BinaryClassification . Calibrators . Platt ( slope : - 1f , offset : - 0.05f ) ,
331
+ ML . BinaryClassification . Calibrators . Platt ( slope : - 1f , offset : - 0.05f , scoreColumnName : "ScoreX" ) ) ;
332
+ }
333
+
334
+ [ Fact ]
335
+ public void NaiveCalibratorOnnxConversionTest ( )
336
+ {
337
+ CommonCalibratorOnnxConversionTest ( ML . BinaryClassification . Calibrators . Naive ( ) ,
338
+ ML . BinaryClassification . Calibrators . Naive ( scoreColumnName : "ScoreX" ) ) ;
339
+ }
340
+
341
+ class CalibratorInput
303
342
{
304
343
public bool Label { get ; set ; }
305
344
public float Score { get ; set ; }
306
345
}
307
346
308
- class PlattModelInput2
347
+ class CalibratorInputNonStandard
309
348
{
310
349
public bool Label { get ; set ; }
311
350
public float ScoreX { get ; set ; }
312
351
}
313
352
314
- static IEnumerable < PlattModelInput > PlattGetData ( )
353
+ static IEnumerable < CalibratorInput > GetCalibratorTestData ( )
315
354
{
316
355
for ( int i = 0 ; i < 100 ; i ++ )
317
356
{
318
- yield return new PlattModelInput { Score = i , Label = i % 2 == 0 } ;
357
+ yield return new CalibratorInput { Score = i , Label = i % 2 == 0 } ;
319
358
}
320
359
}
321
360
322
- static IEnumerable < PlattModelInput2 > PlattGetData2 ( )
361
+ static IEnumerable < CalibratorInputNonStandard > GetCalibratorTestDataNonStandard ( )
323
362
{
324
363
for ( int i = 0 ; i < 100 ; i ++ )
325
364
{
326
- yield return new PlattModelInput2 { ScoreX = i , Label = i % 2 == 0 } ;
365
+ yield return new CalibratorInputNonStandard { ScoreX = i , Label = i % 2 == 0 } ;
327
366
}
328
367
}
329
368
330
- [ Fact ]
331
- public void PlattCalibratorOnnxConversionTest2 ( )
332
- {
333
- // Test PlattCalibrator without any binary prediction trainer
334
- var mlContext = new MLContext ( seed : 0 ) ;
335
-
336
- IDataView data = mlContext . Data . LoadFromEnumerable ( PlattGetData ( ) ) ;
337
-
338
- var pipeline = mlContext . BinaryClassification . Calibrators
339
- . Platt ( ) ;
340
- var onnxFileName = $ "{ pipeline } .onnx";
341
-
342
- TestPipeline ( pipeline , data , onnxFileName , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
343
-
344
- // Test PlattCalibrator with a non-default Score column name, and without any binary prediction trainer
345
- IDataView data2 = mlContext . Data . LoadFromEnumerable ( PlattGetData2 ( ) ) ;
346
-
347
- var pipeline2 = mlContext . BinaryClassification . Calibrators
348
- . Platt ( scoreColumnName : "ScoreX" ) ;
349
- var onnxFileName2 = $ "{ pipeline2 } .onnx";
350
-
351
- TestPipeline ( pipeline2 , data2 , onnxFileName2 , new ColumnComparison [ ] { new ColumnComparison ( "Probability" , 3 ) } ) ;
352
-
353
- Done ( ) ;
354
- }
355
-
356
369
[ Fact ]
357
370
public void TextNormalizingOnnxConversionTest ( )
358
371
{
0 commit comments