-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
AutoML Add Recommendation Task (#4246)
Trains Recommendation models able to predict rating for existing users
- Loading branch information
1 parent
d531ea8
commit ee8418a
Showing
36 changed files
with
687 additions
and
53 deletions.
There are no files selected for viewing
20 changes: 20 additions & 0 deletions
20
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/Movie.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples.DataStructures | ||
{ | ||
/// <summary>
/// Input row type for the movie recommendation sample. Each field is bound
/// to a column of the loaded text file by its zero-based column index.
/// </summary>
public class Movie
{
    /// <summary>User identifier, read from column 0.</summary>
    [LoadColumn(0)] public string UserId;

    /// <summary>Movie identifier, read from column 1.</summary>
    [LoadColumn(1)] public string MovieId;

    /// <summary>Rating value, read from column 2.</summary>
    [LoadColumn(2)] public float Rating;
}
} |
14 changes: 14 additions & 0 deletions
14
docs/samples/Microsoft.ML.AutoML.Samples/DataStructures/MovieRatingPrediction.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
/// <summary>
/// Prediction output type for the movie recommendation sample.
/// </summary>
public class MovieRatingPrediction
{
    /// <summary>Predicted rating, populated from the "Score" output column.</summary>
    [ColumnName("Score")] public float Rating;
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
docs/samples/Microsoft.ML.AutoML.Samples/RecommendationExperiment.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.IO; | ||
using System.Linq; | ||
using Microsoft.ML.AutoML.Samples.DataStructures; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML.Samples | ||
{ | ||
/// <summary>
/// Sample showing how to run an AutoML recommendation experiment: load the data,
/// run the experiment, report metrics, save the best model and score single examples.
/// </summary>
public static class RecommendationExperiment
{
    private static readonly string TrainDataPath = "<Path to your train dataset goes here>";
    private static readonly string TestDataPath = "<Path to your test dataset goes here>";
    private static readonly string ModelDirectory = "<Desired model output directory goes here>";
    private static readonly string ModelFileName = "Model.zip";
    private static readonly string LabelColumnName = "Rating";
    private static readonly string UserColumnName = "UserId";
    private static readonly string ItemColumnName = "MovieId";
    private static readonly uint ExperimentTime = 60;

    /// <summary>
    /// Runs the sample end to end and writes progress and metrics to the console.
    /// </summary>
    public static void Run()
    {
        MLContext mlContext = new MLContext();

        // STEP 1: Load data
        IDataView trainDataView = mlContext.Data.LoadFromTextFile<Movie>(TrainDataPath, hasHeader: true, separatorChar: ',');
        IDataView testDataView = mlContext.Data.LoadFromTextFile<Movie>(TestDataPath, hasHeader: true, separatorChar: ',');

        // STEP 2: Run AutoML experiment
        Console.WriteLine($"Running AutoML recommendation experiment for {ExperimentTime} seconds...");
        ExperimentResult<RegressionMetrics> experimentResult = mlContext.Auto()
            .CreateRecommendationExperiment(new RecommendationExperimentSettings() { MaxExperimentTimeInSeconds = ExperimentTime })
            .Execute(trainDataView, testDataView,
                new ColumnInformation()
                {
                    LabelColumnName = LabelColumnName,
                    UserIdColumnName = UserColumnName,
                    ItemIdColumnName = ItemColumnName
                });

        // STEP 3: Print metric from best model
        RunDetail<RegressionMetrics> bestRun = experimentResult.BestRun;
        Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}");
        Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}");
        Console.WriteLine("Metrics of best model from validation data --");
        PrintMetrics(bestRun.ValidationMetrics);

        // STEP 4: Evaluate test data
        IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView);
        RegressionMetrics testMetrics = mlContext.Recommendation().Evaluate(testDataViewWithBestScore, labelColumnName: LabelColumnName);
        Console.WriteLine("Metrics of best model on test data --");
        PrintMetrics(testMetrics);

        // STEP 5: Save the best model for later deployment and inferencing.
        // Path.Combine picks the directory separator of the current OS; a hard-coded
        // backslash would not be treated as a separator on Linux or macOS.
        string modelPath = Path.Combine(ModelDirectory, ModelFileName);
        mlContext.Model.Save(bestRun.Model, trainDataView.Schema, modelPath);

        // STEP 6: Create prediction engine from the best trained model.
        // The engine is disposable, so it is released as soon as the predictions are done.
        using (var predictionEngine = mlContext.Model.CreatePredictionEngine<Movie, MovieRatingPrediction>(bestRun.Model))
        {
            // STEP 7: Initialize a new test, and get the prediction
            var testMovie = new Movie
            {
                UserId = "1",
                MovieId = "1097",
            };
            var prediction = predictionEngine.Predict(testMovie);
            Console.WriteLine($"Predicted rating for user {testMovie.UserId} and movie {testMovie.MovieId}: {prediction.Rating}");

            // The sample expects NaN when the user id was not part of the training data.
            testMovie = new Movie
            {
                UserId = "612", // new user
                MovieId = "2940"
            };
            prediction = predictionEngine.Predict(testMovie);
            Console.WriteLine($"Expected Rating NaN for unknown user, Predicted: {prediction.Rating}");
        }

        Console.WriteLine("Press any key to continue...");
        Console.ReadKey();
    }

    /// <summary>
    /// Writes the four regression metrics of <paramref name="metrics"/> to the console, one per line.
    /// </summary>
    /// <param name="metrics">Metrics to print.</param>
    private static void PrintMetrics(RegressionMetrics metrics)
    {
        Console.WriteLine($"MeanAbsoluteError: {metrics.MeanAbsoluteError}");
        Console.WriteLine($"MeanSquaredError: {metrics.MeanSquaredError}");
        Console.WriteLine($"RootMeanSquaredError: {metrics.RootMeanSquaredError}");
        Console.WriteLine($"RSquared: {metrics.RSquared}");
    }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Microsoft.ML.AutoML | ||
{ | ||
/// <summary>
/// Configuration for an AutoML experiment that trains recommendation models.
/// </summary>
public sealed class RecommendationExperimentSettings : ExperimentSettings
{
    /// <summary>
    /// The metric the experiment optimizes while searching for the best model.
    /// </summary>
    /// <value>Defaults to <see cref="RegressionMetric.RSquared"/>.</value>
    public RegressionMetric OptimizingMetric { get; set; }

    /// <summary>
    /// The trainers the experiment is allowed to use.
    /// </summary>
    /// <value>Defaults to a collection holding every value of <see cref="RecommendationTrainer" />.</value>
    public ICollection<RecommendationTrainer> Trainers { get; }

    /// <summary>
    /// Creates a <see cref="RecommendationExperimentSettings"/> with the default
    /// optimizing metric and all available trainers enabled.
    /// </summary>
    public RecommendationExperimentSettings()
    {
        OptimizingMetric = RegressionMetric.RSquared;

        // Enum.GetValues returns an array of the enum type, so it can seed the list directly.
        var everyTrainer = (RecommendationTrainer[])Enum.GetValues(typeof(RecommendationTrainer));
        Trainers = new List<RecommendationTrainer>(everyTrainer);
    }
}
|
||
/// <summary>
/// The ML.NET recommendation trainers that AutoML can use.
/// </summary>
public enum RecommendationTrainer
{
    /// <summary>
    /// Selects the matrix factorization trainer.
    /// </summary>
    MatrixFactorization = 0
}
|
||
/// <summary>
/// An AutoML experiment that runs on recommendation datasets and reports regression metrics.
/// </summary>
/// <example>
/// <format type="text/markdown">
/// <![CDATA[
/// [!code-csharp[RecommendationExperiment](~/../docs/samples/docs/samples/Microsoft.ML.AutoML.Samples/RecommendationExperiment.cs)]
/// ]]></format>
/// </example>
public sealed class RecommendationExperiment : ExperimentBase<RegressionMetrics, RecommendationExperimentSettings>
{
    /// <summary>
    /// Builds the experiment from <paramref name="settings"/>: the optimizing metric drives both the
    /// metrics agent and the metric info, and the configured trainers are passed on by name.
    /// </summary>
    internal RecommendationExperiment(MLContext context, RecommendationExperimentSettings settings)
        : base(
            context,
            new RegressionMetricsAgent(context, settings.OptimizingMetric),
            new OptimizingMetricInfo(settings.OptimizingMetric),
            settings,
            TaskKind.Recommendation,
            TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
    {
    }

    /// <summary>
    /// Picks the best cross-validation run from <paramref name="results"/> using the
    /// experiment's metrics agent and the direction of the optimizing metric.
    /// </summary>
    private protected override CrossValidationRunDetail<RegressionMetrics> GetBestCrossValRun(IEnumerable<CrossValidationRunDetail<RegressionMetrics>> results)
        => BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);

    /// <summary>
    /// Picks the best run from <paramref name="results"/> using the
    /// experiment's metrics agent and the direction of the optimizing metric.
    /// </summary>
    private protected override RunDetail<RegressionMetrics> GetBestRun(IEnumerable<RunDetail<RegressionMetrics>> results)
        => BestResultUtil.GetBestRun(results, MetricsAgent, OptimizingMetricInfo.IsMaximizing);
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,5 +9,6 @@ internal enum TaskKind | |
BinaryClassification, | ||
MulticlassClassification, | ||
Regression, | ||
Recommendation | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.