Skip to content

Commit 5538ccf

Browse files
authored
Binary classification samples update (#3311)
1 parent 2e99197 commit 5538ccf

File tree

61 files changed

+2760
-702
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+2760
-702
lines changed
Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,110 @@
1-
using Microsoft.ML;
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Data;
26

37
namespace Samples.Dynamic.Trainers.BinaryClassification
48
{
59
public static class AveragedPerceptron
610
{
7-
// In this examples we will use the adult income dataset. The goal is to predict
8-
// if a person's income is above $50K or not, based on demographic information about that person.
9-
// For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult.
1011
public static void Example()
1112
{
1213
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
1314
// as a catalog of available operations and as the source of randomness.
1415
// Setting the seed to a fixed number in this example to make outputs deterministic.
1516
var mlContext = new MLContext(seed: 0);
1617

17-
// Download and featurize the dataset.
18-
var data = Microsoft.ML.SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
18+
// Create a list of training data points.
19+
var dataPoints = GenerateRandomDataPoints(1000);
1920

20-
// Leave out 10% of data for testing.
21-
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);
21+
// Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
22+
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
2223

23-
// Create data training pipeline.
24-
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(numberOfIterations: 10);
24+
// Define the trainer.
25+
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron();
2526

26-
// Fit this pipeline to the training data.
27-
var model = pipeline.Fit(trainTestData.TrainSet);
27+
// Train the model.
28+
var model = pipeline.Fit(trainingData);
2829

29-
// Evaluate how the model is doing on the test data.
30-
var dataWithPredictions = model.Transform(trainTestData.TestSet);
31-
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions);
32-
Microsoft.ML.SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
30+
// Create testing data. Use a different random seed so that it differs from the training data.
31+
var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123));
3332

33+
// Run the model on test data set.
34+
var transformedTestData = model.Transform(testData);
35+
36+
// Convert IDataView object to a list.
37+
var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();
38+
39+
// Print 5 predictions.
40+
foreach (var p in predictions.Take(5))
41+
Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");
42+
43+
// Expected output:
44+
// Label: True, Prediction: True
45+
// Label: False, Prediction: False
46+
// Label: True, Prediction: True
47+
// Label: True, Prediction: False
48+
// Label: False, Prediction: False
49+
50+
// Evaluate the overall metrics.
51+
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(transformedTestData);
52+
PrintMetrics(metrics);
53+
3454
// Expected output:
35-
// Accuracy: 0.86
36-
// AUC: 0.91
37-
// F1 Score: 0.68
38-
// Negative Precision: 0.90
39-
// Negative Recall: 0.91
40-
// Positive Precision: 0.70
41-
// Positive Recall: 0.66
55+
// Accuracy: 0.72
56+
// AUC: 0.79
57+
// F1 Score: 0.68
58+
// Negative Precision: 0.71
59+
// Negative Recall: 0.80
60+
// Positive Precision: 0.74
61+
// Positive Recall: 0.63
62+
}
63+
64+
// Lazily yields `count` synthetic data points. Each point gets a random boolean
// label and 50 random feature values that are correlated with that label.
// The random generator is created from `seed` inside the iterator body, so every
// enumeration of the returned sequence starts from a freshly seeded generator.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
{
    var rng = new Random(seed);
    for (int produced = 0; produced < count; produced++)
    {
        // The label is drawn first; the 50 feature draws follow in index order.
        bool isPositive = (float)rng.NextDouble() > 0.5f;

        var features = new float[50];
        for (int slot = 0; slot < features.Length; slot++)
        {
            float draw = (float)rng.NextDouble();
            // For data points with a false label, the feature value is slightly
            // increased by adding a constant; this is what ties features to the label.
            features[slot] = isPositive ? draw : draw + 0.1f;
        }

        yield return new DataPoint { Label = isPositive, Features = features };
    }
}
80+
81+
// One example: a boolean label together with 50 feature values.
// A data set is a collection of such examples.
private class DataPoint
{
    // Label of this example.
    public bool Label { get; set; }

    // Feature values of this example; the attribute declares a vector of size 50.
    [VectorType(50)] public float[] Features { get; set; }
}
88+
89+
// Row type used to read predictions back out of the transformed data.
private class Prediction
{
    // Original label.
    public bool Label { get; set; }

    // Predicted label from the trainer.
    public bool PredictedLabel { get; set; }
}
97+
98+
// Writes the binary classification metrics to the console, one "name: value"
// line per metric, each value formatted with two decimal places.
private static void PrintMetrics(BinaryClassificationMetrics metrics)
{
    Console.WriteLine("Accuracy: {0:F2}", metrics.Accuracy);
    Console.WriteLine("AUC: {0:F2}", metrics.AreaUnderRocCurve);
    Console.WriteLine("F1 Score: {0:F2}", metrics.F1Score);
    Console.WriteLine("Negative Precision: {0:F2}", metrics.NegativePrecision);
    Console.WriteLine("Negative Recall: {0:F2}", metrics.NegativeRecall);
    Console.WriteLine("Positive Precision: {0:F2}", metrics.PositivePrecision);
    Console.WriteLine("Positive Recall: {0:F2}", metrics.PositiveRecall);
}
43109
}
44-
}
110+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
<#@ include file="BinaryClassification.ttinclude"#>
2+
<#+
3+
string ClassName = "AveragedPerceptron";
4+
string Trainer = "AveragedPerceptron";
5+
string TrainerOptions = null;
6+
bool IsCalibrated = false;
7+
bool CacheData = false;
8+
9+
string LabelThreshold = "0.5f";
10+
string DataSepValue = "0.1f";
11+
string OptionsInclude = "";
12+
string Comments= "";
13+
14+
string ExpectedOutputPerInstance = @"// Expected output:
15+
// Label: True, Prediction: True
16+
// Label: False, Prediction: False
17+
// Label: True, Prediction: True
18+
// Label: True, Prediction: False
19+
// Label: False, Prediction: False";
20+
21+
string ExpectedOutput = @"// Expected output:
22+
// Accuracy: 0.72
23+
// AUC: 0.79
24+
// F1 Score: 0.68
25+
// Negative Precision: 0.71
26+
// Negative Recall: 0.80
27+
// Positive Precision: 0.74
28+
// Positive Recall: 0.63";
29+
#>
Lines changed: 91 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
1-
using Microsoft.ML;
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Data;
26
using Microsoft.ML.Trainers;
37

48
namespace Samples.Dynamic.Trainers.BinaryClassification
59
{
610
public static class AveragedPerceptronWithOptions
711
{
8-
// In this examples we will use the adult income dataset. The goal is to predict
9-
// if a person's income is above $50K or not, based on demographic information about that person.
10-
// For more details about this dataset, please see https://archive.ics.uci.edu/ml/datasets/adult.
1112
public static void Example()
1213
{
1314
// Create a new context for ML.NET operations. It can be used for exception tracking and logging,
1415
// as a catalog of available operations and as the source of randomness.
1516
// Setting the seed to a fixed number in this example to make outputs deterministic.
1617
var mlContext = new MLContext(seed: 0);
1718

18-
// Download and featurize the dataset.
19-
var data = Microsoft.ML.SamplesUtils.DatasetUtils.LoadFeaturizedAdultDataset(mlContext);
19+
// Create a list of training data points.
20+
var dataPoints = GenerateRandomDataPoints(1000);
2021

21-
// Leave out 10% of data for testing.
22-
var trainTestData = mlContext.Data.TrainTestSplit(data, testFraction: 0.1);
22+
// Convert the list of data points to an IDataView object, which is consumable by the ML.NET API.
23+
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
2324

24-
// Define the trainer options.
25-
var options = new AveragedPerceptronTrainer.Options()
25+
// Define trainer options.
26+
var options = new AveragedPerceptronTrainer.Options
2627
{
2728
LossFunction = new SmoothedHingeLoss(),
2829
LearningRate = 0.1f,
@@ -31,25 +32,90 @@ public static void Example()
3132
NumberOfIterations = 10
3233
};
3334

34-
// Create data training pipeline.
35+
// Define the trainer.
3536
var pipeline = mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
3637

37-
// Fit this pipeline to the training data.
38-
var model = pipeline.Fit(trainTestData.TrainSet);
38+
// Train the model.
39+
var model = pipeline.Fit(trainingData);
3940

40-
// Evaluate how the model is doing on the test data.
41-
var dataWithPredictions = model.Transform(trainTestData.TestSet);
42-
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(dataWithPredictions);
43-
Microsoft.ML.SamplesUtils.ConsoleUtils.PrintMetrics(metrics);
41+
// Create testing data. Use a different random seed so that it differs from the training data.
42+
var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123));
43+
44+
// Run the model on test data set.
45+
var transformedTestData = model.Transform(testData);
46+
47+
// Convert IDataView object to a list.
48+
var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();
49+
50+
// Print 5 predictions.
51+
foreach (var p in predictions.Take(5))
52+
Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");
4453

4554
// Expected output:
46-
// Accuracy: 0.86
47-
// AUC: 0.90
48-
// F1 Score: 0.66
49-
// Negative Precision: 0.89
50-
// Negative Recall: 0.93
51-
// Positive Precision: 0.72
52-
// Positive Recall: 0.61
55+
// Label: True, Prediction: True
56+
// Label: False, Prediction: False
57+
// Label: True, Prediction: True
58+
// Label: True, Prediction: True
59+
// Label: False, Prediction: False
60+
61+
// Evaluate the overall metrics.
62+
var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(transformedTestData);
63+
PrintMetrics(metrics);
64+
65+
// Expected output:
66+
// Accuracy: 0.89
67+
// AUC: 0.96
68+
// F1 Score: 0.88
69+
// Negative Precision: 0.87
70+
// Negative Recall: 0.92
71+
// Positive Precision: 0.91
72+
// Positive Recall: 0.85
73+
}
74+
75+
// Produces `count` synthetic data points on demand. Every point carries a random
// boolean label and a 50-element feature vector whose values depend on the label.
// The generator is seeded inside the iterator body, so each enumeration of the
// returned sequence begins with a newly seeded generator.
private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
{
    var generator = new Random(seed);
    float NextSingle() => (float)generator.NextDouble();

    for (int index = 0; index < count; ++index)
    {
        // The label draw happens before any feature draw for this point.
        var point = new DataPoint
        {
            Label = NextSingle() > 0.5f,
            Features = new float[50]
        };

        for (int k = 0; k < point.Features.Length; ++k)
        {
            float value = NextSingle();
            // Points with a false label get their feature values slightly
            // increased by a constant, which correlates features with the label.
            if (!point.Label)
                value += 0.1f;
            point.Features[k] = value;
        }

        yield return point;
    }
}
91+
92+
// One example: a boolean label together with 50 feature values.
// A data set is a collection of such examples.
private class DataPoint
{
    // Label of this example.
    public bool Label { get; set; }

    // Feature values of this example; the attribute declares a vector of size 50.
    [VectorType(50)] public float[] Features { get; set; }
}
99+
100+
// Row type used to read predictions back out of the transformed data.
private class Prediction
{
    // Original label.
    public bool Label { get; set; }

    // Predicted label from the trainer.
    public bool PredictedLabel { get; set; }
}
108+
109+
// Prints each binary classification metric on its own console line as
// "name: value", with the value rendered to two decimal places.
private static void PrintMetrics(BinaryClassificationMetrics metrics)
{
    // Single place that defines the line layout shared by all metrics.
    void Emit(string name, object value) => Console.WriteLine($"{name}: {value:F2}");

    Emit("Accuracy", metrics.Accuracy);
    Emit("AUC", metrics.AreaUnderRocCurve);
    Emit("F1 Score", metrics.F1Score);
    Emit("Negative Precision", metrics.NegativePrecision);
    Emit("Negative Recall", metrics.NegativeRecall);
    Emit("Positive Precision", metrics.PositivePrecision);
    Emit("Positive Recall", metrics.PositiveRecall);
}
54120
}
55-
}
121+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
<#@ include file="BinaryClassification.ttinclude"#>
2+
<#+
3+
string ClassName="AveragedPerceptronWithOptions";
4+
string Trainer = "AveragedPerceptron";
5+
bool IsCalibrated = false;
6+
7+
string LabelThreshold = "0.5f";
8+
string DataSepValue = "0.1f";
9+
string OptionsInclude = "using Microsoft.ML.Trainers;";
10+
string Comments= "";
11+
bool CacheData = false;
12+
13+
string TrainerOptions = @"AveragedPerceptronTrainer.Options
14+
{
15+
LossFunction = new SmoothedHingeLoss(),
16+
LearningRate = 0.1f,
17+
LazyUpdate = false,
18+
RecencyGain = 0.1f,
19+
NumberOfIterations = 10
20+
}";
21+
22+
string ExpectedOutputPerInstance= @"// Expected output:
23+
// Label: True, Prediction: True
24+
// Label: False, Prediction: False
25+
// Label: True, Prediction: True
26+
// Label: True, Prediction: True
27+
// Label: False, Prediction: False";
28+
29+
string ExpectedOutput = @"// Expected output:
30+
// Accuracy: 0.89
31+
// AUC: 0.96
32+
// F1 Score: 0.88
33+
// Negative Precision: 0.87
34+
// Negative Recall: 0.92
35+
// Positive Precision: 0.91
36+
// Positive Recall: 0.85";
37+
#>

0 commit comments

Comments
 (0)