Skip to content

Commit

Permalink
Added support for RankingMetrics with CrossValSummaryRunner (#5386)
Browse files Browse the repository at this point in the history
* Added support for RankingMetrics with CrossValSummaryRunner

* Addressed reviews

* Addressed reviews

* Edited naming of baselines metrics

* Addressed reviews
  • Loading branch information
mstfbl authored Sep 10, 2020
1 parent 5370692 commit 68ea969
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,31 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
return result as TMetrics;
}

if (typeof(TMetrics) == typeof(RankingMetrics))
{
var newMetrics = metrics.Select(x => x as RankingMetrics);
Contracts.Assert(newMetrics != null);

var result = new RankingMetrics(
dcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.DiscountedCumulativeGains)),
ndcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.NormalizedDiscountedCumulativeGains)));
return result as TMetrics;
}

throw new NotImplementedException($"Metric {typeof(TMetrics)} not implemented");
}

private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<IEnumerable<double>> results)
{
double[] arr = new double[results.ElementAt(0).Count()];
for (int i = 0; i < arr.Length; i++)
{
Contracts.Assert(arr.Length == results.ElementAt(i).Count());
arr[i] = GetAverageOfNonNaNScores(results.Select(x => x.ElementAt(i)));
}
return arr;
}

private static double GetAverageOfNonNaNScores(IEnumerable<double> results)
{
var newResults = results.Where(r => !double.IsNaN(r));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ internal DataOperationsCatalog(IHostEnvironment env)
/// results.
/// </summary>
/// <typeparam name="TRow">The user-defined item type.</typeparam>
/// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to an<see cref="IDataView"/>.</param>
/// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to a <see cref="IDataView"/>.</param>
/// <param name="schemaDefinition">The optional schema definition of the data view to create. If <c>null</c>,
/// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
/// <returns>The constructed <see cref="IDataView"/>.</returns>
Expand Down
11 changes: 7 additions & 4 deletions test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ public void AutoFitRankingTest()
RunDetail<RankingMetrics> bestRun = experimentResults[i].BestRun;
Assert.True(experimentResults[i].RunDetails.Count() > 0);
Assert.NotNull(bestRun.ValidationMetrics);
Assert.True(experimentResults[i].RunDetails.Max(i => i.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > .4));
Assert.True(experimentResults[i].RunDetails.Max(i => i.ValidationMetrics.DiscountedCumulativeGains.Max() > 20));
Assert.True(bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Last() > 0.4);
Assert.True(bestRun.ValidationMetrics.DiscountedCumulativeGains.Last() > 20);
var outputSchema = bestRun.Model.GetOutputSchema(trainDataView.Schema);
var expectedOutputNames = new string[] { labelColumnName, groupIdColumnName, groupIdColumnName, featuresColumnVectorNameA, featuresColumnVectorNameB,
"Features", scoreColumnName };
Expand All @@ -187,6 +187,9 @@ public void AutoFitRankingCVTest()
var reader = new TextLoader(mlContext, GetLoaderArgsRank(labelColumnName, groupIdColumnName,
featuresColumnVectorNameA, featuresColumnVectorNameB));
var trainDataView = reader.Load(DatasetUtil.GetMLSRDataset());
// Take less than 1500 rows of data to satisfy CrossValSummaryRunner's
// limit.
trainDataView = mlContext.Data.TakeRows(trainDataView, 1499);

var experiment = mlContext.Auto()
.CreateRankingExperiment(5);
Expand All @@ -208,8 +211,8 @@ public void AutoFitRankingCVTest()
while (enumerator.MoveNext())
{
var model = enumerator.Current;
Assert.True(model.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > .4);
Assert.True(model.ValidationMetrics.DiscountedCumulativeGains.Max() > 19);
Assert.True(model.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > 0.31);
Assert.True(model.ValidationMetrics.DiscountedCumulativeGains.Max() > 15);
}
}
}
Expand Down

0 comments on commit 68ea969

Please sign in to comment.