Skip to content

Commit 68ea969

Browse files
authored
Added support for RankingMetrics with CrossValSummaryRunner (#5386)
* Added support for RankingMetrics with CrossValSummaryRunner * Addressed reviews * Addressed reviews * Edited naming of baselines metrics * Addressed reviews
1 parent 5370692 commit 68ea969

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

src/Microsoft.ML.AutoML/Experiment/Runners/CrossValSummaryRunner.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,31 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
144144
return result as TMetrics;
145145
}
146146

147+
if (typeof(TMetrics) == typeof(RankingMetrics))
148+
{
149+
var newMetrics = metrics.Select(x => x as RankingMetrics);
150+
Contracts.Assert(newMetrics != null);
151+
152+
var result = new RankingMetrics(
153+
dcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.DiscountedCumulativeGains)),
154+
ndcg: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.NormalizedDiscountedCumulativeGains)));
155+
return result as TMetrics;
156+
}
157+
147158
throw new NotImplementedException($"Metric {typeof(TMetrics)} not implemented");
148159
}
149160

161+
/// <summary>
/// Averages a collection of equally sized score sequences position by position.
/// Element <c>i</c> of the returned array is the result of
/// <see cref="GetAverageOfNonNaNScores"/> applied to the <c>i</c>-th score of every inner sequence.
/// </summary>
/// <param name="results">The score sequences to combine; all are expected to have the same length.</param>
/// <returns>An array whose length equals the length of the first inner sequence.</returns>
private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<IEnumerable<double>> results)
{
    // Materialize once so that lazily evaluated inputs are not re-enumerated on every lookup.
    double[][] rows = results.Select(scores => scores.ToArray()).ToArray();

    // Fails, as before, when there are no inner sequences to take the length from.
    double[] arr = new double[rows.ElementAt(0).Length];

    // Every inner sequence must match the first one's length. The previous check indexed the
    // outer collection with the score position, so it inspected the wrong sequences and could
    // run past the end whenever there were fewer sequences than scores.
    Contracts.Assert(rows.All(scores => scores.Length == arr.Length));

    for (int i = 0; i < arr.Length; i++)
    {
        // Copy the loop variable so the deferred query captures a stable position.
        int position = i;
        arr[i] = GetAverageOfNonNaNScores(rows.Select(scores => scores[position]));
    }
    return arr;
}
171+
150172
private static double GetAverageOfNonNaNScores(IEnumerable<double> results)
151173
{
152174
var newResults = results.Where(r => !double.IsNaN(r));

src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ internal DataOperationsCatalog(IHostEnvironment env)
6363
/// results.
6464
/// </summary>
6565
/// <typeparam name="TRow">The user-defined item type.</typeparam>
66-
/// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to an<see cref="IDataView"/>.</param>
66+
/// <param name="data">The enumerable data containing type <typeparamref name="TRow"/> to convert to a <see cref="IDataView"/>.</param>
6767
/// <param name="schemaDefinition">The optional schema definition of the data view to create. If <c>null</c>,
6868
/// the schema definition is inferred from <typeparamref name="TRow"/>.</param>
6969
/// <returns>The constructed <see cref="IDataView"/>.</returns>

test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ public void AutoFitRankingTest()
164164
RunDetail<RankingMetrics> bestRun = experimentResults[i].BestRun;
165165
Assert.True(experimentResults[i].RunDetails.Count() > 0);
166166
Assert.NotNull(bestRun.ValidationMetrics);
167-
Assert.True(experimentResults[i].RunDetails.Max(i => i.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > .4));
168-
Assert.True(experimentResults[i].RunDetails.Max(i => i.ValidationMetrics.DiscountedCumulativeGains.Max() > 20));
167+
Assert.True(bestRun.ValidationMetrics.NormalizedDiscountedCumulativeGains.Last() > 0.4);
168+
Assert.True(bestRun.ValidationMetrics.DiscountedCumulativeGains.Last() > 20);
169169
var outputSchema = bestRun.Model.GetOutputSchema(trainDataView.Schema);
170170
var expectedOutputNames = new string[] { labelColumnName, groupIdColumnName, groupIdColumnName, featuresColumnVectorNameA, featuresColumnVectorNameB,
171171
"Features", scoreColumnName };
@@ -187,6 +187,9 @@ public void AutoFitRankingCVTest()
187187
var reader = new TextLoader(mlContext, GetLoaderArgsRank(labelColumnName, groupIdColumnName,
188188
featuresColumnVectorNameA, featuresColumnVectorNameB));
189189
var trainDataView = reader.Load(DatasetUtil.GetMLSRDataset());
190+
// Take fewer than 1500 rows of data to satisfy CrossValSummaryRunner's
191+
// limit.
192+
trainDataView = mlContext.Data.TakeRows(trainDataView, 1499);
190193

191194
var experiment = mlContext.Auto()
192195
.CreateRankingExperiment(5);
@@ -208,8 +211,8 @@ public void AutoFitRankingCVTest()
208211
while (enumerator.MoveNext())
209212
{
210213
var model = enumerator.Current;
211-
Assert.True(model.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > .4);
212-
Assert.True(model.ValidationMetrics.DiscountedCumulativeGains.Max() > 19);
214+
Assert.True(model.ValidationMetrics.NormalizedDiscountedCumulativeGains.Max() > 0.31);
215+
Assert.True(model.ValidationMetrics.DiscountedCumulativeGains.Max() > 15);
213216
}
214217
}
215218
}

0 commit comments

Comments
 (0)