-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AutoML] Add AutoML example code (#3458)
- Loading branch information
Showing
16 changed files
with
402 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
docs/samples/Microsoft.ML.AutoML.Samples/BinaryClassificationExperiment.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
using System; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.ML.Auto; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public static class BinaryClassificationExperiment | ||
{ | ||
private static string TrainDataPath = "<Path to your train dataset goes here>"; | ||
private static string TestDataPath = "<Path to your test dataset goes here>"; | ||
private static string ModelPath = @"<Desired model output directory goes here>\SentimentModel.zip"; | ||
private static uint ExperimentTime = 60; | ||
|
||
public static void Run() | ||
{ | ||
MLContext mlContext = new MLContext(); | ||
|
||
// STEP 1: Load data | ||
IDataView trainDataView = mlContext.Data.LoadFromTextFile<SentimentIssue>(TrainDataPath, hasHeader: true); | ||
IDataView testDataView = mlContext.Data.LoadFromTextFile<SentimentIssue>(TestDataPath, hasHeader: true); | ||
|
||
// STEP 2: Run AutoML experiment | ||
Console.WriteLine($"Running AutoML binary classification experiment for {ExperimentTime} seconds..."); | ||
ExperimentResult<BinaryClassificationMetrics> experimentResult = mlContext.Auto() | ||
.CreateBinaryClassificationExperiment(ExperimentTime) | ||
.Execute(trainDataView); | ||
|
||
// STEP 3: Print metric from the best model | ||
RunDetail<BinaryClassificationMetrics> bestRun = experimentResult.BestRun; | ||
Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}"); | ||
Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}"); | ||
Console.WriteLine($"Metrics of best model from validation data --"); | ||
PrintMetrics(bestRun.ValidationMetrics); | ||
|
||
// STEP 4: Evaluate test data | ||
IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView); | ||
BinaryClassificationMetrics testMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testDataViewWithBestScore); | ||
Console.WriteLine($"Metrics of best model on test data --"); | ||
PrintMetrics(testMetrics); | ||
|
||
// STEP 5: Save the best model for later deployment and inferencing | ||
using (FileStream fs = File.Create(ModelPath)) | ||
mlContext.Model.Save(bestRun.Model, trainDataView.Schema, fs); | ||
|
||
// STEP 6: Create prediction engine from the best trained model | ||
var predictionEngine = mlContext.Model.CreatePredictionEngine<SentimentIssue, SentimentPrediction>(bestRun.Model); | ||
|
||
// STEP 7: Initialize a new sentiment issue, and get the predicted sentiment | ||
var testSentimentIssue = new SentimentIssue | ||
{ | ||
Text = "I hope this helps." | ||
}; | ||
var prediction = predictionEngine.Predict(testSentimentIssue); | ||
Console.WriteLine($"Predicted sentiment for test issue: {prediction.Prediction}"); | ||
|
||
Console.WriteLine("Press any key to continue..."); | ||
Console.ReadKey(); | ||
} | ||
|
||
private static void PrintMetrics(BinaryClassificationMetrics metrics) | ||
{ | ||
Console.WriteLine($"Accuracy: {metrics.Accuracy}"); | ||
Console.WriteLine($"AreaUnderPrecisionRecallCurve: {metrics.AreaUnderPrecisionRecallCurve}"); | ||
Console.WriteLine($"AreaUnderRocCurve: {metrics.AreaUnderRocCurve}"); | ||
Console.WriteLine($"F1Score: {metrics.F1Score}"); | ||
Console.WriteLine($"NegativePrecision: {metrics.NegativePrecision}"); | ||
Console.WriteLine($"NegativeRecall: {metrics.NegativeRecall}"); | ||
Console.WriteLine($"PositivePrecision: {metrics.PositivePrecision}"); | ||
Console.WriteLine($"PositiveRecall: {metrics.PositiveRecall}"); | ||
} | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/PixelData.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class PixelData | ||
{ | ||
[LoadColumn(0, 63)] | ||
[VectorType(64)] | ||
public float[] PixelValues; | ||
|
||
[LoadColumn(64)] | ||
public float Number; | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/PixelPrediction.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class PixelPrediction | ||
{ | ||
[ColumnName("PredictedLabel")] | ||
public float Prediction; | ||
} | ||
} |
13 changes: 13 additions & 0 deletions
13
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/SentimentIssue.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class SentimentIssue | ||
{ | ||
[LoadColumn(0)] | ||
public bool Label { get; set; } | ||
|
||
[LoadColumn(1)] | ||
public string Text { get; set; } | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/SentimentPrediction.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class SentimentPrediction | ||
{ | ||
// ColumnName attribute is used to change the column name from | ||
// its default value, which is the name of the field. | ||
[ColumnName("PredictedLabel")] | ||
public bool Prediction { get; set; } | ||
|
||
public float Score { get; set; } | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/TaxiTrip.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class TaxiTrip | ||
{ | ||
[LoadColumn(0)] | ||
public string VendorId; | ||
|
||
[LoadColumn(1)] | ||
public float RateCode; | ||
|
||
[LoadColumn(2)] | ||
public float PassengerCount; | ||
|
||
[LoadColumn(3)] | ||
public float TripTimeInSeconds; | ||
|
||
[LoadColumn(4)] | ||
public float TripDistance; | ||
|
||
[LoadColumn(5)] | ||
public string PaymentType; | ||
|
||
[LoadColumn(6)] | ||
public float FareAmount; | ||
} | ||
} |
10 changes: 10 additions & 0 deletions
10
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/TaxiTripFarePrediction.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class TaxiTripFarePrediction | ||
{ | ||
[ColumnName("Score")] | ||
public float FareAmount; | ||
} | ||
} |
12 changes: 12 additions & 0 deletions
12
docs/samples/Microsoft.ML.AutoML.Samples/Microsoft.ML.AutoML.Samples.csproj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
<Project Sdk="Microsoft.NET.Sdk"> | ||
|
||
<PropertyGroup> | ||
<OutputType>Exe</OutputType> | ||
<TargetFramework>netcoreapp2.1</TargetFramework> | ||
</PropertyGroup> | ||
|
||
<ItemGroup> | ||
<ProjectReference Include="..\..\..\src\Microsoft.ML.Auto\Microsoft.ML.Auto.csproj" /> | ||
</ItemGroup> | ||
|
||
</Project> |
71 changes: 71 additions & 0 deletions
71
docs/samples/Microsoft.ML.AutoML.Samples/MulticlassClassificationExperiment.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
using System; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.ML.Auto; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public static class MulticlassClassificationExperiment | ||
{ | ||
private static string TrainDataPath = "<Path to your train dataset goes here>"; | ||
private static string TestDataPath = "<Path to your test dataset goes here>"; | ||
private static string ModelPath = @"<Desired model output directory goes here>\OptDigitsModel.zip"; | ||
private static string LabelColumnName = "Number"; | ||
private static uint ExperimentTime = 60; | ||
|
||
public static void Run() | ||
{ | ||
MLContext mlContext = new MLContext(); | ||
|
||
// STEP 1: Load data | ||
IDataView trainDataView = mlContext.Data.LoadFromTextFile<PixelData>(TrainDataPath, separatorChar: ','); | ||
IDataView testDataView = mlContext.Data.LoadFromTextFile<PixelData>(TestDataPath, separatorChar: ','); | ||
|
||
// STEP 2: Run AutoML experiment | ||
Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds..."); | ||
ExperimentResult<MulticlassClassificationMetrics> experimentResult = mlContext.Auto() | ||
.CreateMulticlassClassificationExperiment(ExperimentTime) | ||
.Execute(trainDataView, LabelColumnName); | ||
|
||
// STEP 3: Print metric from the best model | ||
RunDetail<MulticlassClassificationMetrics> bestRun = experimentResult.BestRun; | ||
Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}"); | ||
Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}"); | ||
Console.WriteLine($"Metrics of best model from validation data --"); | ||
PrintMetrics(bestRun.ValidationMetrics); | ||
|
||
// STEP 4: Evaluate test data | ||
IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView); | ||
MulticlassClassificationMetrics testMetrics = mlContext.MulticlassClassification.Evaluate(testDataViewWithBestScore, labelColumnName: LabelColumnName); | ||
Console.WriteLine($"Metrics of best model on test data --"); | ||
PrintMetrics(testMetrics); | ||
|
||
// STEP 5: Save the best model for later deployment and inferencing | ||
using (FileStream fs = File.Create(ModelPath)) | ||
mlContext.Model.Save(bestRun.Model, trainDataView.Schema, fs); | ||
|
||
// STEP 6: Create prediction engine from the best trained model | ||
var predictionEngine = mlContext.Model.CreatePredictionEngine<PixelData, PixelPrediction>(bestRun.Model); | ||
|
||
// STEP 7: Initialize new pixel data, and get the predicted number | ||
var testPixelData = new PixelData | ||
{ | ||
PixelValues = new float[] { 0, 0, 1, 8, 15, 10, 0, 0, 0, 3, 13, 15, 14, 14, 0, 0, 0, 5, 10, 0, 10, 12, 0, 0, 0, 0, 3, 5, 15, 10, 2, 0, 0, 0, 16, 16, 16, 16, 12, 0, 0, 1, 8, 12, 14, 8, 3, 0, 0, 0, 0, 10, 13, 0, 0, 0, 0, 0, 0, 11, 9, 0, 0, 0 } | ||
}; | ||
var prediction = predictionEngine.Predict(testPixelData); | ||
Console.WriteLine($"Predicted number for test pixels: {prediction.Prediction}"); | ||
|
||
Console.WriteLine("Press any key to continue..."); | ||
Console.ReadKey(); | ||
} | ||
|
||
private static void PrintMetrics(MulticlassClassificationMetrics metrics) | ||
{ | ||
Console.WriteLine($"LogLoss: {metrics.LogLoss}"); | ||
Console.WriteLine($"LogLossReduction: {metrics.LogLossReduction}"); | ||
Console.WriteLine($"MacroAccuracy: {metrics.MacroAccuracy}"); | ||
Console.WriteLine($"MicroAccuracy: {metrics.MicroAccuracy}"); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
using System; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public class Program | ||
{ | ||
public static void Main(string[] args) | ||
{ | ||
try | ||
{ | ||
RegressionExperiment.Run(); | ||
Console.Clear(); | ||
|
||
BinaryClassificationExperiment.Run(); | ||
Console.Clear(); | ||
|
||
MulticlassClassificationExperiment.Run(); | ||
Console.Clear(); | ||
|
||
Console.WriteLine("Done"); | ||
} | ||
catch (Exception ex) | ||
{ | ||
Console.WriteLine($"Exception {ex}"); | ||
} | ||
|
||
Console.ReadLine(); | ||
} | ||
} | ||
} |
76 changes: 76 additions & 0 deletions
76
docs/samples/Microsoft.ML.AutoML.Samples/RegressionExperiment.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
using System; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.ML.Auto; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
public static class RegressionExperiment | ||
{ | ||
private static string TrainDataPath = "<Path to your train dataset goes here>"; | ||
private static string TestDataPath = "<Path to your test dataset goes here>"; | ||
private static string ModelPath = @"<Desired model output directory goes here>\TaxiFareModel.zip"; | ||
private static string LabelColumnName = "FareAmount"; | ||
private static uint ExperimentTime = 60; | ||
|
||
public static void Run() | ||
{ | ||
MLContext mlContext = new MLContext(); | ||
|
||
// STEP 1: Load data | ||
IDataView trainDataView = mlContext.Data.LoadFromTextFile<TaxiTrip>(TrainDataPath, hasHeader: true, separatorChar: ','); | ||
IDataView testDataView = mlContext.Data.LoadFromTextFile<TaxiTrip>(TestDataPath, hasHeader: true, separatorChar: ','); | ||
|
||
// STEP 2: Run AutoML experiment | ||
Console.WriteLine($"Running AutoML regression experiment for {ExperimentTime} seconds..."); | ||
ExperimentResult<RegressionMetrics> experimentResult = mlContext.Auto() | ||
.CreateRegressionExperiment(ExperimentTime) | ||
.Execute(trainDataView, LabelColumnName); | ||
|
||
// STEP 3: Print metric from best model | ||
RunDetail<RegressionMetrics> bestRun = experimentResult.BestRun; | ||
Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}"); | ||
Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}"); | ||
Console.WriteLine($"Metrics of best model from validation data --"); | ||
PrintMetrics(bestRun.ValidationMetrics); | ||
|
||
// STEP 5: Evaluate test data | ||
IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView); | ||
RegressionMetrics testMetrics = mlContext.Regression.Evaluate(testDataViewWithBestScore, labelColumnName: LabelColumnName); | ||
Console.WriteLine($"Metrics of best model on test data --"); | ||
PrintMetrics(testMetrics); | ||
|
||
// STEP 6: Save the best model for later deployment and inferencing | ||
using (FileStream fs = File.Create(ModelPath)) | ||
mlContext.Model.Save(bestRun.Model, trainDataView.Schema, fs); | ||
|
||
// STEP 7: Create prediction engine from the best trained model | ||
var predictionEngine = mlContext.Model.CreatePredictionEngine<TaxiTrip, TaxiTripFarePrediction>(bestRun.Model); | ||
|
||
// STEP 8: Initialize a new test taxi trip, and get the predicted fare | ||
var testTaxiTrip = new TaxiTrip | ||
{ | ||
VendorId = "VTS", | ||
RateCode = 1, | ||
PassengerCount = 1, | ||
TripTimeInSeconds = 1140, | ||
TripDistance = 3.75f, | ||
PaymentType = "CRD" | ||
}; | ||
var prediction = predictionEngine.Predict(testTaxiTrip); | ||
Console.WriteLine($"Predicted fare for test taxi trip: {prediction.FareAmount}"); | ||
|
||
Console.WriteLine("Press any key to continue..."); | ||
Console.ReadKey(); | ||
} | ||
|
||
private static void PrintMetrics(RegressionMetrics metrics) | ||
{ | ||
Console.WriteLine($"MeanAbsoluteError: {metrics.MeanAbsoluteError}"); | ||
Console.WriteLine($"MeanSquaredError: {metrics.MeanSquaredError}"); | ||
Console.WriteLine($"RootMeanSquaredError: {metrics.RootMeanSquaredError}"); | ||
Console.WriteLine($"RSquared: {metrics.RSquared}"); | ||
} | ||
} | ||
} |
Oops, something went wrong.