Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf improvement for TopK Accuracy and return all topK in Classification Evaluator #5395

Merged
merged 29 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
5fbf740
Fix for issue 744
jasallen Sep 8, 2020
1747d3e
cleanup
jasallen Sep 9, 2020
32c244a
fixing report output
jasallen Sep 12, 2020
968b58d
fixedTestReferenceOutputs
jasallen Sep 12, 2020
b7ded43
Fixed test reference outputs for NetCore31
jasallen Sep 12, 2020
685eeb4
change top k acc output string format
jasallen Nov 5, 2020
1eacec7
Ranking algorithm now uses first appearance in dataset rather than wo…
jasallen Nov 6, 2020
ea057ff
fixed benchmark
jasallen Nov 6, 2020
ac08554
various minor changes from code review
jasallen Nov 6, 2020
f0de3ea
limit TopK to OutputTopKAcc parameter
jasallen Nov 6, 2020
30fbd6f
top k output name changes
jasallen Nov 6, 2020
495b4b0
make old TopK readOnly
jasallen Nov 6, 2020
c3afe15
restored old baselineOutputs since respecting outputTopK param means …
jasallen Nov 6, 2020
bfcda22
fix test fails, re-add names parameter
jasallen Nov 6, 2020
563768c
Clean up commented code
jasallen Nov 6, 2020
4a5597a
that'll teach me to edit from the github webpage
jasallen Nov 6, 2020
71390bd
use existing method, fix nits
jasallen Nov 19, 2020
32ab9fa
Slight comment change
jasallen Nov 20, 2020
db2b6b5
Comment change / Touch to kick off build pipeline
jasallen Nov 21, 2020
0d0493b
fix whitespace
jasallen Nov 23, 2020
e6aec98
Merge branch 'master' into jasallenbranch
antoniovs1029 Dec 3, 2020
05e7f91
Added new test
antoniovs1029 Dec 4, 2020
49786ed
Code formatting nits
justinormont Dec 8, 2020
9259031
Code formatting nit
justinormont Dec 8, 2020
98458ba
Fixed undefined rankofCorrectLabel and trailing whitespace warning
antoniovs1029 Dec 8, 2020
86f5c3f
Removed _numUnknownClassInstances and added test for unknown labels
antoniovs1029 Dec 8, 2020
741e9fb
Add weight to seenRanks
antoniovs1029 Dec 8, 2020
dadf793
Nits
antoniovs1029 Dec 9, 2020
9e67751
Removed FastTree import
antoniovs1029 Dec 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Data
Expand Down Expand Up @@ -81,6 +82,11 @@ public sealed class MulticlassClassificationMetrics
/// </summary>
public int TopKPredictionCount { get; }
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved

/// <summary>
/// Gets the top-K accuracy for every K from 1 to the number of classes:
/// element K-1 is the fraction of instances whose true label is among the
/// K highest-scored predictions (so element 0 is the ordinary accuracy).
/// </summary>
public IReadOnlyList<double> TopKAccuracyForAllK { get; }

/// <summary>
/// Gets the log-loss of the classifier for each class. Log-loss measures the performance of a classifier
/// with respect to how much the predicted probabilities diverge from the true class label. Lower
Expand Down Expand Up @@ -114,9 +120,10 @@ internal MulticlassClassificationMetrics(IHost host, DataViewRow overallResult,
MacroAccuracy = FetchDouble(MulticlassClassificationEvaluator.AccuracyMacro);
LogLoss = FetchDouble(MulticlassClassificationEvaluator.LogLoss);
LogLossReduction = FetchDouble(MulticlassClassificationEvaluator.LogLossReduction);
TopKAccuracyForAllK = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.AllTopKAccuracy).DenseValues().ToImmutableArray();
jasallen marked this conversation as resolved.
Show resolved Hide resolved
TopKPredictionCount = topKPredictionCount;
if (topKPredictionCount > 0)
TopKAccuracy = FetchDouble(MulticlassClassificationEvaluator.TopKAccuracy);
TopKAccuracy = TopKAccuracyForAllK[topKPredictionCount-1];

var perClassLogLoss = RowCursorUtils.Fetch<VBuffer<double>>(host, overallResult, MulticlassClassificationEvaluator.PerClassLogLoss);
PerClassLogLoss = perClassLogLoss.DenseValues().ToImmutableArray();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public sealed class Arguments
public const string AccuracyMicro = "Accuracy(micro-avg)";
public const string AccuracyMacro = "Accuracy(macro-avg)";
public const string TopKAccuracy = "Top K accuracy";
public const string AllTopKAccuracy = "Top K accuracy(All K)";
public const string PerClassLogLoss = "Per class log-loss";
public const string LogLoss = "Log-loss";
public const string LogLossReduction = "Log-loss reduction";
Expand All @@ -60,15 +61,13 @@ public enum Metrics
internal const string LoadName = "MultiClassClassifierEvaluator";

private readonly int? _outputTopKAcc;
private readonly bool _names;
jasallen marked this conversation as resolved.
Show resolved Hide resolved

public MulticlassClassificationEvaluator(IHostEnvironment env, Arguments args)
: base(env, LoadName)
{
Host.AssertValue(args, "args");
Host.CheckUserArg(args.OutputTopKAcc == null || args.OutputTopKAcc > 0, nameof(args.OutputTopKAcc));
_outputTopKAcc = args.OutputTopKAcc;
_names = args.Names;
}

private protected override void CheckScoreAndLabelTypes(RoleMappedSchema schema)
Expand Down Expand Up @@ -147,6 +146,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
var logLoss = new List<double>();
var logLossRed = new List<double>();
var topKAcc = new List<double>();
var allTopK = new List<double[]>();
var perClassLogLoss = new List<double[]>();
var counts = new List<double[]>();
var weights = new List<double[]>();
Expand All @@ -172,6 +172,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
logLossRed.Add(agg.UnweightedCounters.Reduction);
if (agg.UnweightedCounters.OutputTopKAcc > 0)
topKAcc.Add(agg.UnweightedCounters.TopKAccuracy);
allTopK.Add(agg.UnweightedCounters.AllTopKAccuracy);
jasallen marked this conversation as resolved.
Show resolved Hide resolved
perClassLogLoss.Add(agg.UnweightedCounters.PerClassLogLoss);

confStratCol.AddRange(agg.UnweightedCounters.ConfusionTable.Select(x => stratColKey));
Expand All @@ -189,6 +190,7 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
logLossRed.Add(agg.WeightedCounters.Reduction);
if (agg.WeightedCounters.OutputTopKAcc > 0)
topKAcc.Add(agg.WeightedCounters.TopKAccuracy);
allTopK.Add(agg.WeightedCounters.AllTopKAccuracy);
perClassLogLoss.Add(agg.WeightedCounters.PerClassLogLoss);
weights.AddRange(agg.WeightedCounters.ConfusionTable);
}
Expand All @@ -213,6 +215,11 @@ private protected override void GetAggregatorConsolidationFuncs(Aggregator aggre
overallDvBldr.AddColumn(TopKAccuracy, NumberDataViewType.Double, topKAcc.ToArray());
overallDvBldr.AddColumn(PerClassLogLoss, aggregator.GetSlotNames, NumberDataViewType.Double, perClassLogLoss.ToArray());

ValueGetter<VBuffer<ReadOnlyMemory<char>>> getKSlotNames =
jasallen marked this conversation as resolved.
Show resolved Hide resolved
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
dst = new VBuffer<ReadOnlyMemory<char>>(allTopK.First().Length, Enumerable.Range(1,allTopK.First().Length).Select(i=>new ReadOnlyMemory<char>(($"K={i.ToString()}").ToCharArray())).ToArray());
overallDvBldr.AddColumn(AllTopKAccuracy, getKSlotNames, NumberDataViewType.Double, allTopK.ToArray());

var confDvBldr = new ArrayDataViewBuilder(Host);
if (hasStrats)
{
Expand Down Expand Up @@ -246,9 +253,11 @@ public sealed class Counters
private double _totalLogLoss;
private double _numInstances;
private double _numCorrect;
private double _numCorrectTopK;
private int _numUnknownClassInstances;
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
private readonly double[] _sumWeightsOfClass;
private readonly double[] _totalPerClassLogLoss;
private readonly long[] _seenRanks;

public readonly double[][] ConfusionTable;

public double MicroAvgAccuracy { get { return _numInstances > 0 ? _numCorrect / _numInstances : 0; } }
Expand Down Expand Up @@ -291,7 +300,8 @@ public double Reduction
}
}

public double TopKAccuracy { get { return _numInstances > 0 ? _numCorrectTopK / _numInstances : 0; } }
public double TopKAccuracy => !(OutputTopKAcc is null) ? AllTopKAccuracy[OutputTopKAcc.Value] : 0d;
public double[] AllTopKAccuracy => CumulativeSum(_seenRanks.Select(l => l / (double)(_numInstances - _numUnknownClassInstances))).ToArray();

// The per class average log loss is calculated by dividing the weighted sum of the log loss of examples
// in each class by the total weight of examples in that class.
Expand All @@ -316,14 +326,12 @@ public Counters(int numClasses, int? outputTopKAcc)
ConfusionTable = new double[numClasses][];
for (int i = 0; i < ConfusionTable.Length; i++)
ConfusionTable[i] = new double[numClasses];

_seenRanks = new long[numClasses + 1];
}

public void Update(int[] indices, double loglossCurr, int label, float weight)
public void Update(int seenRank, int assigned, double loglossCurr, int label, float weight)
{
Contracts.Assert(Utils.Size(indices) == _numClasses);

int assigned = indices[0];

_numInstances += weight;

if (label < _numClasses)
Expand All @@ -334,23 +342,34 @@ public void Update(int[] indices, double loglossCurr, int label, float weight)
if (label < _numClasses)
_totalPerClassLogLoss[label] += loglossCurr * weight;

if (assigned == label)
_seenRanks[seenRank]++;
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved

if (seenRank == 0) //prediction matched label
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
{
_numCorrect += weight;
ConfusionTable[label][label] += weight;
_numCorrectTopK += weight;
}
else if (label < _numClasses)
{
if (OutputTopKAcc > 0)
{
int idx = Array.IndexOf(indices, label);
if (0 <= idx && idx < OutputTopKAcc)
_numCorrectTopK += weight;
}
ConfusionTable[label][assigned] += weight;
}
else
{
_numUnknownClassInstances++;
}
}

/// <summary>
/// Lazily yields the running (prefix) sum of <paramref name="s"/>: the i-th element
/// of the result is the sum of the first i+1 elements of the input.
/// </summary>
private static IEnumerable<double> CumulativeSum(IEnumerable<double> s)
{
    double sum = 0;
    foreach (var x in s)
    {
        sum += x;
        yield return sum;
    }
}

}

private ValueGetter<float> _labelGetter;
Expand All @@ -359,7 +378,6 @@ public void Update(int[] indices, double loglossCurr, int label, float weight)

private VBuffer<float> _scores;
private readonly float[] _scoresArr;
private int[] _indicesArr;

private const float Epsilon = (float)1e-15;

Expand All @@ -380,6 +398,7 @@ public Aggregator(IHostEnvironment env, ReadOnlyMemory<char>[] classNames, int s
Host.Assert(Utils.Size(classNames) == scoreVectorSize);

_scoresArr = new float[scoreVectorSize];

UnweightedCounters = new Counters(scoreVectorSize, outputTopKAcc);
Weighted = weighted;
WeightedCounters = Weighted ? new Counters(scoreVectorSize, outputTopKAcc) : null;
Expand All @@ -400,6 +419,7 @@ internal override void InitializeNextPass(DataViewRow row, RoleMappedSchema sche

if (schema.Weight.HasValue)
_weightGetter = row.GetGetter<float>(schema.Weight.Value);

}

public override void ProcessRow()
Expand Down Expand Up @@ -437,16 +457,12 @@ public override void ProcessRow()
}
}

// Sort classes by prediction strength.
// Use stable OrderBy instead of Sort(), which may give different results on different machines.
if (Utils.Size(_indicesArr) < _scoresArr.Length)
_indicesArr = new int[_scoresArr.Length];
int j = 0;
foreach (var index in Enumerable.Range(0, _scoresArr.Length).OrderByDescending(i => _scoresArr[i]))
_indicesArr[j++] = index;

var intLabel = (int)label;

var assigned = Array.IndexOf(_scoresArr, _scoresArr.Max()); //perf could be improved

var wasKnownLabel = true;

// log-loss
double logloss;
if (intLabel < _scoresArr.Length)
Expand All @@ -461,11 +477,21 @@ public override void ProcessRow()
// Penalize logloss if the label was not seen during training
logloss = -Math.Log(Epsilon);
_numUnknownClassInstances++;
wasKnownLabel = false;
}

UnweightedCounters.Update(_indicesArr, logloss, intLabel, 1);
// Get the probability that the CORRECT label has: (best case is that it's the highest probability):
var correctProba = !wasKnownLabel ? 0 : _scoresArr[intLabel];

// Find the rank of the *correct* label (in Scores[]). If 0 => Good, correct. And the lower the better.
jasallen marked this conversation as resolved.
Show resolved Hide resolved
// The rank will be from 0 to N. (Not N-1).
jasallen marked this conversation as resolved.
Show resolved Hide resolved
// Problem: What if we have probabilities that are equal to the correct prediction (eg, .6 .1 .1 .1 .1).
// This actually happens a lot with some models. Here we assign the worst rank in the case of a tie (so 4 in this example)
var correctRankWorstCase = !wasKnownLabel ? _scoresArr.Length : _scoresArr.Count(score => score >= correctProba) - 1;

UnweightedCounters.Update(correctRankWorstCase, assigned, logloss, intLabel, 1);
if (WeightedCounters != null)
WeightedCounters.Update(_indicesArr, logloss, intLabel, weight);
WeightedCounters.Update(correctRankWorstCase, assigned, logloss, intLabel, weight);
}

protected override List<string> GetWarningsCore()
Expand Down Expand Up @@ -909,6 +935,7 @@ private protected override IDataView CombineOverallMetricsCore(IDataView[] metri
for (int i = 0; i < metrics.Length; i++)
{
var idv = metrics[i];
idv = DropAllTopKColumn(idv);
if (!_outputPerClass)
idv = DropPerClassColumn(idv);

Expand Down Expand Up @@ -964,6 +991,15 @@ private IDataView DropPerClassColumn(IDataView input)
return input;
}

/// <summary>
/// Drops the "Top K accuracy(All K)" vector column from <paramref name="input"/> when it is
/// present, so downstream per-fold metric combining only deals with the scalar columns.
/// </summary>
private IDataView DropAllTopKColumn(IDataView input)
{
    // Only the column's existence matters; discard the index instead of binding an
    // unused (and non-camelCase) local.
    if (input.Schema.TryGetColumnIndex(MulticlassClassificationEvaluator.AllTopKAccuracy, out _))
    {
        input = ColumnSelectingTransformer.CreateDrop(Host, input, MulticlassClassificationEvaluator.AllTopKAccuracy);
    }
    return input;
}

public override IEnumerable<MetricColumn> GetOverallMetricColumns()
{
yield return new MetricColumn("AccuracyMicro", MulticlassClassificationEvaluator.AccuracyMicro);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ Accuracy(micro-avg): 0.936709
Accuracy(macro-avg): 0.942857
Log-loss: 0.285741
Log-loss reduction: 0.737254
Top K accuracy(All K)K=1: 0.936709
antoniovs1029 marked this conversation as resolved.
Show resolved Hide resolved
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

Confusion table
||========================
Expand All @@ -37,6 +41,10 @@ Accuracy(micro-avg): 0.957746
Accuracy(macro-avg): 0.953030
Log-loss: 0.160970
Log-loss reduction: 0.851729
Top K accuracy(All K)K=1: 0.957746
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ Accuracy(micro-avg): 0.936709
Accuracy(macro-avg): 0.942857
Log-loss: 0.285741
Log-loss reduction: 0.737254
Top K accuracy(All K)K=1: 0.936709
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

Confusion table
||========================
Expand All @@ -37,6 +41,10 @@ Accuracy(micro-avg): 0.957746
Accuracy(macro-avg): 0.953030
Log-loss: 0.160970
Log-loss reduction: 0.851729
Top K accuracy(All K)K=1: 0.957746
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Accuracy(micro-avg): 0.973333
Accuracy(macro-avg): 0.973333
Log-loss: 0.161048
Log-loss reduction: 0.853408
Top K accuracy(All K)K=1: 0.973333
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Accuracy(micro-avg): 0.973333
Accuracy(macro-avg): 0.973333
Log-loss: 0.161048
Log-loss reduction: 0.853408
Top K accuracy(All K)K=1: 0.973333
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Accuracy(micro-avg): 0.629834
Accuracy(macro-avg): 0.500000
Log-loss: 34.538776
Log-loss reduction: -51.407404
Top K accuracy(All K)K=1: 0.629834
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000

Confusion table
||======================
Expand All @@ -29,6 +32,9 @@ Accuracy(micro-avg): 0.682493
Accuracy(macro-avg): 0.500000
Log-loss: 34.538776
Log-loss reduction: -54.264136
Top K accuracy(All K)K=1: 0.682493
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Accuracy(micro-avg): 0.655222
Accuracy(macro-avg): 0.500000
Log-loss: 34.538776
Log-loss reduction: -52.618809
Top K accuracy(All K)K=1: 0.655222
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Accuracy(micro-avg): 0.962025
Accuracy(macro-avg): 0.965079
Log-loss: 0.129858
Log-loss reduction: 0.880592
Top K accuracy(All K)K=1: 0.962025
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

Confusion table
||========================
Expand All @@ -39,6 +43,10 @@ Accuracy(micro-avg): 0.971831
Accuracy(macro-avg): 0.966667
Log-loss: 0.125563
Log-loss reduction: 0.884343
Top K accuracy(All K)K=1: 0.971831
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ Accuracy(micro-avg): 0.980000
Accuracy(macro-avg): 0.980000
Log-loss: 0.095534
Log-loss reduction: 0.913041
Top K accuracy(All K)K=1: 0.980000
Top K accuracy(All K)K=2: 1.000000
Top K accuracy(All K)K=3: 1.000000
Top K accuracy(All K)K=4: 1.000000

OVERALL RESULTS
---------------------------------------
Expand Down
Loading