Skip to content

Enabling Ranking Cross Validation #5263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
adding cross validation for ranking
  • Loading branch information
Lynx1820 committed Jul 7, 2020
commit 39f1e6a19ec4edba3a3b8f0882c79877a5b78e6c
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ public static void PrintRegressionFoldsAverageMetrics(IEnumerable<TrainCatalogBa

public static void PrintRankingFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RankingMetrics>> crossValidationResults)
{
var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? metrics.NormalizedDiscountedCumulativeGains.Count-1 : 9;
var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count-1 : 9;
var NDCG = crossValidationResults.Select(r => r.Metrics.NormalizedDiscountedCumulativeGains[max]);
var DCG = crossValidationResults.Select(r => r.Metrics.DiscountedCumulativeGains[max]);
Console.WriteLine($""*************************************************************************************************************"");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ else{#>

public static void PrintRankingFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RankingMetrics>> crossValidationResults)
{
var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? metrics.NormalizedDiscountedCumulativeGains.Count-1 : 9;
var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count-1 : 9;
var NDCG = crossValidationResults.Select(r => r.Metrics.NormalizedDiscountedCumulativeGains[max]);
var DCG = crossValidationResults.Select(r => r.Metrics.DiscountedCumulativeGains[max]);
Console.WriteLine($"*************************************************************************************************************");
Expand Down
25 changes: 25 additions & 0 deletions src/Microsoft.ML.Data/TrainCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,31 @@ public RankingMetrics Evaluate(IDataView data,
var eval = new RankingEvaluator(Environment, options ?? new RankingEvaluatorOptions() { });
return eval.Evaluate(data, labelColumnName, rowGroupColumnName, scoreColumnName);
}

/// <summary>
/// Run cross-validation over <paramref name="numberOfFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
/// and respecting <paramref name="samplingKeyColumnName"/> if provided.
/// Then evaluate each sub-model against <paramref name="labelColumnName"/> and return metrics.
/// </summary>
/// <param name="data">The data to run cross-validation on.</param>
/// <param name="estimator">The estimator to fit.</param>
/// <param name="numberOfFolds">Number of cross-validation folds.</param>
/// <param name="labelColumnName">The label column (for evaluation).</param>
/// <param name="samplingKeyColumnName">Name of a column to use for grouping rows. If two examples share the same value of the <paramref name="samplingKeyColumnName"/>,
/// they are guaranteed to appear in the same subset (train or test). This can be used to ensure no label leakage from the train to the test set.
/// If <see langword="null"/> no row grouping will be performed.</param>
/// <param name="rowGroupColumnName">The name of the groupId column in <paramref name="data"/>.</param>
/// <param name="seed">Seed for the random number generator used to select rows for cross-validation folds.</param>
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
public IReadOnlyList<CrossValidationResult<RankingMetrics>> CrossValidate(
IDataView data, IEstimator<ITransformer> estimator, int numberOfFolds = 5, string labelColumnName = DefaultColumnNames.Label,
string samplingKeyColumnName = DefaultColumnNames.GroupId, string rowGroupColumnName = DefaultColumnNames.GroupId, int ? seed = null)
{
Environment.CheckNonEmpty(labelColumnName, nameof(labelColumnName));
var result = CrossValidateTrain(data, estimator, numberOfFolds, samplingKeyColumnName, seed);
return result.Select(x => new CrossValidationResult<RankingMetrics>(x.Model,
Evaluate(x.Scores, labelColumnName, rowGroupColumnName), x.Scores, x.Fold)).ToArray();
}
}

/// <summary>
Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ internal LightGbmRankingTrainer(IHostEnvironment env, Options options)
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumnName">The name of the label column.</param>
/// <param name="featureColumnName">The name of the feature column.</param>
/// <param name="rowGroupdColumnName">The name of the column containing the group ID. </param>
/// <param name="rowGroupIdColumnName">The name of the column containing the group ID. </param>
/// <param name="weightsColumnName">The name of the optional column containing the initial weights.</param>
/// <param name="numberOfLeaves">The number of leaves to use.</param>
/// <param name="learningRate">The learning rate.</param>
Expand All @@ -188,7 +188,7 @@ internal LightGbmRankingTrainer(IHostEnvironment env, Options options)
internal LightGbmRankingTrainer(IHostEnvironment env,
string labelColumnName = DefaultColumnNames.Label,
string featureColumnName = DefaultColumnNames.Features,
string rowGroupdColumnName = DefaultColumnNames.GroupId,
string rowGroupIdColumnName = DefaultColumnNames.GroupId,
string weightsColumnName = null,
int? numberOfLeaves = null,
int? minimumExampleCountPerLeaf = null,
Expand All @@ -200,14 +200,14 @@ internal LightGbmRankingTrainer(IHostEnvironment env,
LabelColumnName = labelColumnName,
FeatureColumnName = featureColumnName,
ExampleWeightColumnName = weightsColumnName,
RowGroupColumnName = rowGroupdColumnName,
RowGroupColumnName = rowGroupIdColumnName,
NumberOfLeaves = numberOfLeaves,
MinimumExampleCountPerLeaf = minimumExampleCountPerLeaf,
LearningRate = learningRate,
NumberOfIterations = numberOfIterations
})
{
Host.CheckNonEmpty(rowGroupdColumnName, nameof(rowGroupdColumnName));
Host.CheckNonEmpty(rowGroupIdColumnName, nameof(rowGroupIdColumnName));
}

private protected override void CheckDataValid(IChannel ch, RoleMappedData data)
Expand Down
26 changes: 26 additions & 0 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Microsoft.ML.TestFramework.Attributes;
using Microsoft.ML.TestFrameworkCommon;
using Xunit;
using Microsoft.ML.Trainers.LightGbm;
using Xunit.Abstractions;
using static Microsoft.ML.DataOperationsCatalog;

Expand Down Expand Up @@ -156,6 +157,31 @@ public void AutoFitRankingTest()
Assert.True(col.Name == expectedOutputNames[col.Index]);
}

[LightGBMFact]
public void AutoFitRankingCVTest()
Copy link
Contributor Author

@Lynx1820 Lynx1820 Jun 26, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the way experiments are used within codegen.
Review: Should I add cross validation tests to all other experiments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add cross validation tests to all other experiments?

If I recall it correctly, if your dataset has less than 15000 lines of data, AutoML will run CrossValidation automatically, if you have more than 15000 piece of data, it will use train-test split instead. So the rest of tests in AutoFitTests should all be CV runs considering that the dataset it uses is really small. (@justinormont correct me if I'm wrong)

tests start with AutoFit should test AutoML ranking experiment API, so you shouldn't have to create your pipeline from scratch in this test, If you just want to test Ranking.CrossValidation command, considering rename it more specifically.

Copy link
Member

@antoniovs1029 antoniovs1029 Jul 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the other way around. If it has less than 15000 it runs train test split automatically on one of the Execute overloads, if it has more it runs CV. This only happens on 1 overload, but I believe Keren isn't using that overload on her tests.


In reply to: 447156630 [](ancestors = 447156630)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added CV testing for ranking only. I think it would be good to add testing for other task as well in the future.

{
string labelColumnName = "Label";
string groupIdColumnName = "GroupId";
string featuresColumnVectorNameA = "FeatureVectorA";
string featuresColumnVectorNameB = "FeatureVectorB";
int numFolds = 3;

var mlContext = new MLContext(1);
var dataProcessPipeline = mlContext.Transforms.Concatenate("Features", new[] { "FeatureVectorA", "FeatureVectorB" }).Append(
mlContext.Transforms.Conversion.Hash("GroupId", "GroupId"));

var trainer = mlContext.Ranking.Trainers.LightGbm(new LightGbmRankingTrainer.Options() { RowGroupColumnName = "GroupId", LabelColumnName = "Label", FeatureColumnName = "Features" });
var reader = new TextLoader(mlContext, GetLoaderArgsRank(labelColumnName, groupIdColumnName, featuresColumnVectorNameA, featuresColumnVectorNameB));
var trainDataView = reader.Load(new MultiFileSource(DatasetUtil.GetMLSRDataset()));
var trainingPipeline = dataProcessPipeline.Append(trainer);
var result = mlContext.Ranking.CrossValidate(trainDataView, trainingPipeline, numberOfFolds: numFolds);
for (int i = 0; i < numFolds; i++)
{
Assert.True(result[i].Metrics.NormalizedDiscountedCumulativeGains.Max() > .4);
Assert.True(result[i].Metrics.DiscountedCumulativeGains.Max() > 16);
}
}

[Fact]
public void AutoFitRecommendationTest()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using TestNamespace.Model;
using Microsoft.ML.Trainers.LightGbm;

namespace TestNamespace.ConsoleApp
{
Expand Down Expand Up @@ -58,7 +57,7 @@ namespace TestNamespace.ConsoleApp
// Data process configuration with pipeline data transformations
var dataProcessPipeline = mlContext.Transforms.Conversion.Hash("GroupId", "GroupId");
// Set the training algorithm
var trainer = mlContext.Ranking.Trainers.LightGbm(new LightGbmRankingTrainer.Options() { rowGroupColumnName = "GroupId", LabelColumnName = "Label", FeatureColumnName = "Features" });
var trainer = mlContext.Ranking.Trainers.LightGbm(rowGroupColumnName: "GroupId", labelColumnName: "Label", featureColumnName: "Features");

var trainingPipeline = dataProcessPipeline.Append(trainer);

Expand Down Expand Up @@ -115,7 +114,7 @@ namespace TestNamespace.ConsoleApp

public static void PrintRankingFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<RankingMetrics>> crossValidationResults)
{
var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? metrics.NormalizedDiscountedCumulativeGains.Count - 1 : 9;
var max = (crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count < 10) ? crossValidationResults.First().Metrics.NormalizedDiscountedCumulativeGains.Count - 1 : 9;
var NDCG = crossValidationResults.Select(r => r.Metrics.NormalizedDiscountedCumulativeGains[max]);
var DCG = crossValidationResults.Select(r => r.Metrics.DiscountedCumulativeGains[max]);
Console.WriteLine($"*************************************************************************************************************");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,10 +623,9 @@ private CodeGenerator PrepareForRecommendationTask()
{
if (_mockedPipeline == null)
{
MLContext context = new MLContext();
var hyperParam = new Dictionary<string, object>()
{
{"rowGroupColumnName","GroupId" },
{"RowGroupColumnName","GroupId" },
{"LabelColumnName","Label" },
};
var hashPipelineNode = new PipelineNode(nameof(EstimatorName.Hashing), PipelineNodeType.Transform, "GroupId", "GroupId");
Expand Down