Skip to content

Commit 7f94445

Browse files
add AutoMLExperiment example doc (#6594)
* add AutoMLExperiment example doc * Update AutoMLExperiment.cs * fix formatting
1 parent 6002aa8 commit 7f94445

File tree

3 files changed

+156
-0
lines changed

3 files changed

+156
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
using Microsoft.ML.Data;
7+
8+
namespace Microsoft.ML.AutoML.Samples
9+
{
10+
public static class AutoMLExperiment
11+
{
12+
public static async Task RunAsync()
13+
{
14+
var seed = 0;
15+
16+
// Create a new context for ML.NET operations. It can be used for
17+
// exception tracking and logging, as a catalog of available operations
18+
// and as the source of randomness. Setting the seed to a fixed number
19+
// in this example to make outputs deterministic.
20+
var context = new MLContext(seed);
21+
22+
// Create a list of training data points and convert it to IDataView.
23+
var data = GenerateRandomBinaryClassificationDataPoints(100, seed);
24+
var dataView = context.Data.LoadFromEnumerable(data);
25+
26+
var trainTestSplit = context.Data.TrainTestSplit(dataView);
27+
28+
// Define the sweepable pipeline using predefined binary trainers and search space.
29+
var pipeline = context.Auto().BinaryClassification(labelColumnName: "Label", featureColumnName: "Features");
30+
31+
// Create an AutoML experiment
32+
var experiment = context.Auto().CreateExperiment();
33+
34+
// Redirect AutoML log to console
35+
context.Log += (object o, LoggingEventArgs e) =>
36+
{
37+
if (e.Source == nameof(AutoMLExperiment) && e.Kind > Runtime.ChannelMessageKind.Trace)
38+
{
39+
Console.WriteLine(e.RawMessage);
40+
}
41+
};
42+
43+
// Config experiment to optimize "Accuracy" metric on given dataset.
44+
// This experiment will run hyper-parameter optimization on given pipeline
45+
experiment.SetPipeline(pipeline)
46+
.SetDataset(trainTestSplit.TrainSet, fold: 5) // use 5-fold cross validation to evaluate each trial
47+
.SetBinaryClassificationMetric(BinaryClassificationMetric.Accuracy, "Label")
48+
.SetMaxModelToExplore(100); // explore 100 trials
49+
50+
// start automl experiment
51+
var result = await experiment.RunAsync();
52+
53+
// Expected output samples during training:
54+
// Update Running Trial - Id: 0
55+
// Update Completed Trial - Id: 0 - Metric: 0.5536912515402218 - Pipeline: FastTreeBinary - Duration: 595 - Peak CPU: 0.00 % -Peak Memory in MB: 35.81
56+
// Update Best Trial - Id: 0 - Metric: 0.5536912515402218 - Pipeline: FastTreeBinary
57+
58+
// evaluate test dataset on best model.
59+
var bestModel = result.Model;
60+
var eval = bestModel.Transform(trainTestSplit.TestSet);
61+
var metrics = context.BinaryClassification.Evaluate(eval);
62+
63+
PrintMetrics(metrics);
64+
65+
// Expected output:
66+
// Accuracy: 0.67
67+
// AUC: 0.75
68+
// F1 Score: 0.33
69+
// Negative Precision: 0.88
70+
// Negative Recall: 0.70
71+
// Positive Precision: 0.25
72+
// Positive Recall: 0.50
73+
74+
// TEST POSITIVE RATIO: 0.1667(2.0 / (2.0 + 10.0))
75+
// Confusion table
76+
// ||======================
77+
// PREDICTED || positive | negative | Recall
78+
// TRUTH ||======================
79+
// positive || 1 | 1 | 0.5000
80+
// negative || 3 | 7 | 0.7000
81+
// ||======================
82+
// Precision || 0.2500 | 0.8750 |
83+
}
84+
85+
private static IEnumerable<BinaryClassificationDataPoint> GenerateRandomBinaryClassificationDataPoints(int count,
86+
int seed = 0)
87+
88+
{
89+
var random = new Random(seed);
90+
float randomFloat() => (float)random.NextDouble();
91+
for (int i = 0; i < count; i++)
92+
{
93+
var label = randomFloat() > 0.5f;
94+
yield return new BinaryClassificationDataPoint
95+
{
96+
Label = label,
97+
// Create random features that are correlated with the label.
98+
// For data points with false label, the feature values are
99+
// slightly increased by adding a constant.
100+
Features = Enumerable.Repeat(label, 50)
101+
.Select(x => x ? randomFloat() : randomFloat() +
102+
0.1f).ToArray()
103+
104+
};
105+
}
106+
}
107+
108+
// Example with label and 50 feature values. A data set is a collection of
109+
// such examples.
110+
private class BinaryClassificationDataPoint
111+
{
112+
public bool Label { get; set; }
113+
114+
[VectorType(50)]
115+
public float[] Features { get; set; }
116+
}
117+
118+
// Class used to capture predictions.
119+
private class Prediction
120+
{
121+
// Original label.
122+
public bool Label { get; set; }
123+
// Predicted label from the trainer.
124+
public bool PredictedLabel { get; set; }
125+
}
126+
127+
// Pretty-print BinaryClassificationMetrics objects.
128+
private static void PrintMetrics(BinaryClassificationMetrics metrics)
129+
{
130+
Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
131+
Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
132+
Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
133+
Console.WriteLine($"Negative Precision: " +
134+
$"{metrics.NegativePrecision:F2}");
135+
136+
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
137+
Console.WriteLine($"Positive Precision: " +
138+
$"{metrics.PositivePrecision:F2}");
139+
140+
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}\n");
141+
Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
142+
}
143+
}
144+
}

docs/samples/Microsoft.ML.AutoML.Samples/Program.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ public static void Main(string[] args)
88
{
99
try
1010
{
11+
AutoMLExperiment.RunAsync().Wait();
12+
1113
RecommendationExperiment.Run();
1214
Console.Clear();
1315

src/Microsoft.ML.AutoML/AutoMLExperiment/AutoMLExperiment.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@
1818

1919
namespace Microsoft.ML.AutoML
2020
{
21+
/// <summary>
22+
/// The class for AutoML experiment
23+
/// </summary>
24+
/// <example>
25+
/// <format type="text/markdown">
26+
/// <![CDATA[
27+
/// [!code-csharp[AutoMLExperiment](~/../docs/samples/docs/samples/Microsoft.ML.AutoML.Samples/AutoMLExperiment.cs)]
28+
/// ]]>
29+
/// </format>
30+
/// </example>
2131
public class AutoMLExperiment
2232
{
2333
internal const string PipelineSearchspaceName = "_pipeline_";

0 commit comments

Comments
 (0)