Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf improvement for TopK Accuracy and return all topK in Classification Evaluator #5395

Merged
merged 29 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5fbf740
Fix for issue 744
jasallen Sep 8, 2020
1747d3e
cleanup
jasallen Sep 9, 2020
32c244a
fixing report output
jasallen Sep 12, 2020
968b58d
fixedTestReferenceOutputs
jasallen Sep 12, 2020
b7ded43
Fixed test reference outputs for NetCore31
jasallen Sep 12, 2020
685eeb4
change top k acc output string format
jasallen Nov 5, 2020
1eacec7
Ranking algorithm now uses first appearance in dataset rather than wo…
jasallen Nov 6, 2020
ea057ff
fixed benchmark
jasallen Nov 6, 2020
ac08554
various minor changes from code review
jasallen Nov 6, 2020
f0de3ea
limit TopK to OutputTopKAcc parameter
jasallen Nov 6, 2020
30fbd6f
top k output name changes
jasallen Nov 6, 2020
495b4b0
make old TopK readOnly
jasallen Nov 6, 2020
c3afe15
restored old baselineOutputs since respecting outputTopK param means …
jasallen Nov 6, 2020
bfcda22
fix test fails, re-add names parameter
jasallen Nov 6, 2020
563768c
Clean up commented code
jasallen Nov 6, 2020
4a5597a
that'll teach me to edit from the github webpage
jasallen Nov 6, 2020
71390bd
use existing method, fix nits
jasallen Nov 19, 2020
32ab9fa
Slight comment change
jasallen Nov 20, 2020
db2b6b5
Comment change / Touch to kick off build pipeline
jasallen Nov 21, 2020
0d0493b
fix whitespace
jasallen Nov 23, 2020
e6aec98
Merge branch 'master' into jasallenbranch
antoniovs1029 Dec 3, 2020
05e7f91
Added new test
antoniovs1029 Dec 4, 2020
49786ed
Code formatting nits
justinormont Dec 8, 2020
9259031
Code formatting nit
justinormont Dec 8, 2020
98458ba
Fixed undefined rankofCorrectLabel and trailing whitespace warning
antoniovs1029 Dec 8, 2020
86f5c3f
Removed _numUnknownClassInstances and added test for unknown labels
antoniovs1029 Dec 8, 2020
741e9fb
Add weight to seenRanks
antoniovs1029 Dec 8, 2020
dadf793
Nits
antoniovs1029 Dec 9, 2020
9e67751
Removed FastTree import
antoniovs1029 Dec 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ private static TMetrics GetAverageMetrics(IEnumerable<TMetrics> metrics, TMetric
logLoss: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLoss)),
logLossReduction: GetAverageOfNonNaNScores(newMetrics.Select(x => x.LogLossReduction)),
topKPredictionCount: newMetrics.ElementAt(0).TopKPredictionCount,
topKAccuracy: GetAverageOfNonNaNScores(newMetrics.Select(x => x.TopKAccuracy)),
// Return PerClassLogLoss and ConfusionMatrix from the fold closest to average score
topKAccuracies: GetAverageOfNonNaNScoresInNestedEnumerable(newMetrics.Select(x => x.TopKAccuracyForAllK)),
perClassLogLoss: (metricsClosestToAvg as MulticlassClassificationMetrics).PerClassLogLoss.ToArray(),
confusionMatrix: (metricsClosestToAvg as MulticlassClassificationMetrics).ConfusionMatrix);
return result as TMetrics;
Expand Down Expand Up @@ -163,7 +162,6 @@ private static double[] GetAverageOfNonNaNScoresInNestedEnumerable(IEnumerable<I
double[] arr = new double[results.ElementAt(0).Count()];
for (int i = 0; i < arr.Length; i++)
{
Contracts.Assert(arr.Length == results.ElementAt(i).Count());
arr[i] = GetAverageOfNonNaNScores(results.Select(x => x.ElementAt(i)));
}
return arr;
Expand Down
8 changes: 7 additions & 1 deletion src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,13 @@ private static List<string> GetMetricNames(IChannel ch, DataViewSchema schema, D
names = editor.Commit();
}
foreach (var name in names.Items(all: true))
metricNames.Add(string.Format("{0}{1}", metricName, name.Value));
{
var tryNaming = string.Format(metricName, name.Value);
if (tryNaming == metricName) // metricName wasn't a format string, so just append slotname
tryNaming = (string.Format("{0}{1}", metricName, name.Value));

metricNames.Add(tryNaming);
}
}
}
ch.Assert(metricNames.Count == metricCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Data
Expand Down Expand Up @@ -71,16 +73,22 @@ public sealed class MulticlassClassificationMetrics
public double MicroAccuracy { get; }

/// <summary>
/// If <see cref="TopKPredictionCount"/> is positive, this is the relative number of examples where
/// the true label is one of the top-k predicted labels by the predictor.
/// Convenience method for "TopKAccuracyForAllK[TopKPredictionCount - 1]". If <see cref="TopKPredictionCount"/> is positive,
/// this is the relative number of examples where
/// the true label is one of the top K predicted labels by the predictor.
/// </summary>
public double TopKAccuracy { get; }
public double TopKAccuracy => TopKAccuracyForAllK?.LastOrDefault() ?? 0;
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// If positive, this indicates the K in <see cref="TopKAccuracy"/>.
/// If positive, this indicates the K in <see cref="TopKAccuracy"/> and <see cref="TopKAccuracyForAllK"/>.
/// </summary>
public int TopKPredictionCount { get; }
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Returns the top K accuracy for all K from 1 to the value of TopKPredictionCount.
/// </summary>
public IReadOnlyList<double> TopKAccuracyForAllK { get; }

/// <summary>
/// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier
/// with respect to how much the predicted probabilities diverge from the true class label. Lower
Expand Down Expand Up @@ -115,29 +123,30 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss);
LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction);
TopKPredictionCount = topKPredictionCount;

if (topKPredictionCount > 0)
TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy);
TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray();

var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss);
PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray();
ConfusionMatrix = MetricWriter.GetConfusionMatrix(host, confusionMatrix, binary: false, perClassLogLoss.Length);
}

internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss)
int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss)
{
MicroAccuracy = accuracyMicro;
MacroAccuracy = accuracyMacro;
LogLoss = logLoss;
LogLossReduction = logLossReduction;
TopKPredictionCount = topKPredictionCount;
TopKAccuracy = topKAccuracy;
TopKAccuracyForAllK = topKAccuracies;
PerClassLogLoss = perClassLogLoss.ToImmutableArray();
}

internal MulticlassClassificationMetrics(double accuracyMicro, double accuracyMacro, double logLoss, double logLossReduction,
int topKPredictionCount, double topKAccuracy, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
: this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracy, perClassLogLoss)
int topKPredictionCount, double[] topKAccuracies, double[] perClassLogLoss, ConfusionMatrix confusionMatrix)
: this(accuracyMicro, accuracyMacro, logLoss, logLossReduction, topKPredictionCount, topKAccuracies, perClassLogLoss)
{
ConfusionMatrix = confusionMatrix;
}
Expand Down
Loading